Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/sql/where.py: 46%
211 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1"""
2Code to manage the creation and SQL rendering of 'where' constraints.
3"""
5import operator
6from functools import reduce
8from plain.exceptions import EmptyResultSet, FullResultSet
9from plain.models.expressions import Case, When
10from plain.models.lookups import Exact
11from plain.utils import tree
12from plain.utils.functional import cached_property
# Connector values a WhereNode uses to combine its children.
AND, OR, XOR = "AND", "OR", "XOR"
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either a
    relabeled_clone() method or relabel_aliases() and clone() methods, plus
    contains_aggregate and contains_over_clause attributes.

    ``self.connector`` is one of AND, OR or XOR and joins the children.
    """

    default = AND
    # Set on the clone returned by resolve_expression(); makes as_sql() wrap
    # even a single-child result in parentheses.
    resolved = False
    conditional = True

    def split_having_qualify(self, negated=False, must_group_by=False):
        """
        Split this node by the clause each part belongs in.

        Return a ``(where, having, qualify)`` triple of possibly None nodes:
        the parts of self that belong in the WHERE clause, the parts that
        must go in the HAVING clause (they contain aggregates), and the parts
        that refer to window functions.

        ``negated`` is the negation state inherited from enclosing nodes;
        ``must_group_by`` is forwarded to children and, when true, forbids
        pushing a mixed WHERE/window disjunction down to QUALIFY.

        Raise NotImplementedError for heterogeneous disjunctive predicates
        against window functions combined with ``must_group_by``.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or not must_group_by:
                return None, None, self
            # In theory this should only be enforced when dealing with
            # where_parts containing predicates against multi-valued
            # relationships that could affect aggregation results but this
            # is complex to infer properly.
            raise NotImplementedError(
                "Heterogeneous disjunctive predicates against window functions are "
                "not implemented when performing conditional aggregation."
            )
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)`` for this where clause.

        Raise FullResultSet if this node matches everything and
        EmptyResultSet if it can't match anything; callers use these
        exceptions instead of receiving an empty SQL string.
        """
        result = []
        result_params = []
        # How many children must be "full" / "empty" to decide the node:
        # AND is full only if every child is full and empty if any child is;
        # OR (and XOR) are the mirror image.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
            # An n-ary XOR is true when an odd number of operands are true;
            # comparing the plain sum with 1 would instead only match rows
            # where exactly one operand is true.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            if len(self.children) > 2:
                # With at most two operands "odd" and "exactly one" coincide,
                # so the modulo is only emitted when it changes the result.
                rhs_sum = rhs_sum % 2
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child compiling to empty SQL imposes no restriction.
                    full_needed -= 1
            # Decide the node as soon as enough children are known to be
            # empty/full. Checking per child guarantees the counters are seen
            # exactly when they hit zero and never step past it unnoticed.
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                raise FullResultSet
        conn = f" {self.connector} "
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if self.negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = f"NOT ({sql_string})"
        elif len(result) > 1 or self.resolved:
            sql_string = f"({sql_string})"
        return sql_string, result_params

    def get_group_by_cols(self):
        """Return the concatenated GROUP BY columns of all children."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self):
        """Return a shallow copy of the children list."""
        return self.children[:]

    def set_source_expressions(self, children):
        """Replace the children; the count must match the current one."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map):
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.

        Children with relabel_aliases() are updated in place; children with
        only relabeled_clone() are replaced by their relabeled clone. Other
        children are left untouched.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self):
        """Return a copy of this node, cloning every child that supports it."""
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map):
        """Return a clone with aliases relabeled according to change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements):
        """
        Return this node's replacement from ``replacements`` if it has a
        truthy one, otherwise a copy whose children had their expressions
        replaced recursively.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self):
        """Return the union of every child's get_refs() set."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj):
        # Recurse through nested tree nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self):
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj):
        # Recurse through nested tree nodes; leaves report their own flag.
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self):
        return self._contains_over_clause(self)

    @property
    def is_summary(self):
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr, query, *args, **kwargs):
        # Only expression-like objects are resolved; plain values pass through.
        if hasattr(expr, "resolve_expression"):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node, query, *args, **kwargs):
        # Resolve nested children first, then this node's lhs/rhs in place.
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args, **kwargs):
        """Return a resolved clone; the original node is not modified."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self):
        from plain.models.fields import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self):
        return self.output_field

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection):
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def leaves(self):
        """Yield every non-WhereNode descendant, depth-first in child order."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
class NothingNode:
    """
    Condition that can never be satisfied.

    Compiling it always raises EmptyResultSet. It reports no aggregates and
    no window expressions.
    """

    contains_over_clause = False
    contains_aggregate = False

    def as_sql(self, compiler=None, connection=None):
        # Signal "no rows can match" to the caller instead of emitting SQL.
        raise EmptyResultSet
class ExtraWhere:
    """Raw SQL fragments that are parenthesized and joined with AND."""

    # The SQL is opaque to the ORM, so it is assumed to contain neither
    # aggregates nor window functions.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls, params):
        self.sqls = sqls
        self.params = params

    def as_sql(self, compiler=None, connection=None):
        """Return ``(sql, params)``; a falsy ``params`` yields an empty list."""
        joined = " AND ".join(f"({fragment})" for fragment in self.sqls)
        params = list(self.params) if self.params else []
        return joined, params
class SubqueryConstraint:
    """
    WHERE condition built from a subquery.

    SQL generation is delegated to the subquery compiler's
    ``as_subquery_condition()``, given the outer ``alias`` and ``columns``.
    """

    # Aggregates or window functions inside the subquery don't matter to the
    # outer query's WHERE/HAVING/QUALIFY split, so report neither.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias, columns, targets, query_object):
        self.alias, self.columns, self.targets = alias, columns, targets
        # Drop any ordering on the subquery, including the default ordering.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        """Restrict the subquery to ``targets`` and compile it as a condition."""
        subquery = self.query_object
        subquery.set_values(self.targets)
        subquery_compiler = subquery.get_compiler(connection=connection)
        return subquery_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )