Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/sql/where.py: 29%

211 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1""" 

2Code to manage the creation and SQL rendering of 'where' constraints. 

3""" 

4 

5import operator 

6from functools import reduce 

7 

8from plain.exceptions import EmptyResultSet, FullResultSet 

9from plain.models.expressions import Case, When 

10from plain.models.lookups import Exact 

11from plain.utils import tree 

12from plain.utils.functional import cached_property 

13 

# Connector keywords: how the children of a where-node are joined together.
AND = "AND"
OR = "OR"
XOR = "XOR"

18 

19 

class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    # Connector used when none is given explicitly.
    default = AND
    # Set to True on the clone returned by resolve_expression().
    resolved = False
    conditional = True

    def split_having_qualify(self, negated=False, must_group_by=False):
        """
        Return three possibly None nodes: one for those parts of self that
        should be included in the WHERE clause, one for those parts of self
        that must be included in the HAVING clause, and one for those parts
        that refer to window functions.

        `negated` is the accumulated negation of the enclosing nodes;
        `must_group_by` is forwarded unchanged to child nodes.

        Raise NotImplementedError when window-function predicates would have
        to stay connected to plain WHERE predicates while `must_group_by` is
        set.
        """
        if not self.contains_aggregate and not self.contains_over_clause:
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantics.
        must_remain_connected = (
            (in_negated and self.connector == AND)
            or (not in_negated and self.connector == OR)
            or self.connector == XOR
        )
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts = []
        having_parts = []
        qualify_parts = []
        for c in self.children:
            if hasattr(c, "split_having_qualify"):
                where_part, having_part, qualify_part = c.split_having_qualify(
                    in_negated, must_group_by
                )
                if where_part is not None:
                    where_parts.append(where_part)
                if having_part is not None:
                    having_parts.append(having_part)
                if qualify_part is not None:
                    qualify_parts.append(qualify_part)
            elif c.contains_over_clause:
                qualify_parts.append(c)
            elif c.contains_aggregate:
                having_parts.append(c)
            else:
                where_parts.append(c)
        if must_remain_connected and qualify_parts:
            if where_parts and must_group_by:
                # In theory this should only be enforced when dealing with
                # where_parts containing predicates against multi-valued
                # relationships that could affect aggregation results but this
                # is complex to infer properly.
                raise NotImplementedError(
                    "Heterogeneous disjunctive predicates against window functions are "
                    "not implemented when performing conditional aggregation."
                )
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            return None, None, self
        where_node = (
            self.create(where_parts, self.connector, self.negated)
            if where_parts
            else None
        )
        having_node = (
            self.create(having_parts, self.connector, self.negated)
            if having_parts
            else None
        )
        qualify_node = (
            self.create(qualify_parts, self.connector, self.negated)
            if qualify_parts
            else None
        )
        return where_node, having_node, qualify_node

    def as_sql(self, compiler, connection):
        """
        Return the SQL version of the where clause and the list of values to
        be substituted in.

        Raise FullResultSet if this node matches everything (including when
        no child produces any SQL) and EmptyResultSet if this node can't
        match anything.
        """
        result = []
        result_params = []
        # Number of "full" / "empty" children needed before the whole node is
        # known to be full / empty.
        if self.connector == AND:
            full_needed, empty_needed = len(self.children), 1
        else:
            full_needed, empty_needed = 1, len(self.children)

        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND (a + b + c + ...) % 2 == 1
            # A chained XOR is true when an odd number of operands are true,
            # so the sum is reduced modulo 2. For one or two operands the sum
            # can only be odd when it equals 1, so the modulo is skipped.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            if len(self.children) > 2:
                # NOTE(review): relies on expressions supporting `%` the same
                # way they support `+` above -- confirm against
                # plain.models.expressions.
                rhs_sum = rhs_sum % 2
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child that compiles to no SQL matches everything.
                    full_needed -= 1
            # After each child, use the remaining counts to check whether
            # this node is already known to match nothing or everything.
            if empty_needed == 0:
                if self.negated:
                    raise FullResultSet
                else:
                    raise EmptyResultSet
            if full_needed == 0:
                if self.negated:
                    raise EmptyResultSet
                else:
                    raise FullResultSet
        conn = f" {self.connector} "
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if self.negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = f"NOT ({sql_string})"
        elif len(result) > 1 or self.resolved:
            sql_string = f"({sql_string})"
        return sql_string, result_params

    def get_group_by_cols(self):
        """Return the concatenated get_group_by_cols() of every child."""
        cols = []
        for child in self.children:
            cols.extend(child.get_group_by_cols())
        return cols

    def get_source_expressions(self):
        """Return a shallow copy of the list of children."""
        return self.children[:]

    def set_source_expressions(self, children):
        """Replace the children; the new list must have the same length."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map):
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.
        """
        for pos, child in enumerate(self.children):
            if hasattr(child, "relabel_aliases"):
                # For example another WhereNode
                child.relabel_aliases(change_map)
            elif hasattr(child, "relabeled_clone"):
                self.children[pos] = child.relabeled_clone(change_map)

    def clone(self):
        """
        Return a copy of this node with the same connector and negation;
        children that have a clone() method are cloned, others are shared.
        """
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            if hasattr(child, "clone"):
                child = child.clone()
            clone.children.append(child)
        return clone

    def relabeled_clone(self, change_map):
        """Return a clone whose aliases were relabeled using change_map."""
        clone = self.clone()
        clone.relabel_aliases(change_map)
        return clone

    def replace_expressions(self, replacements):
        """
        Return the replacement registered for this node if there is a truthy
        one; otherwise return a new node whose children are the result of
        calling replace_expressions() on each child.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.create(connector=self.connector, negated=self.negated)
        for child in self.children:
            clone.children.append(child.replace_expressions(replacements))
        return clone

    def get_refs(self):
        """Return the union of get_refs() of every child."""
        refs = set()
        for child in self.children:
            refs |= child.get_refs()
        return refs

    @classmethod
    def _contains_aggregate(cls, obj):
        """Recurse through tree nodes; leaves report contains_aggregate."""
        if isinstance(obj, tree.Node):
            return any(cls._contains_aggregate(c) for c in obj.children)
        return obj.contains_aggregate

    @cached_property
    def contains_aggregate(self):
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj):
        """Recurse through tree nodes; leaves report contains_over_clause."""
        if isinstance(obj, tree.Node):
            return any(cls._contains_over_clause(c) for c in obj.children)
        return obj.contains_over_clause

    @cached_property
    def contains_over_clause(self):
        return self._contains_over_clause(self)

    @property
    def is_summary(self):
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr, query, *args, **kwargs):
        """Resolve expr against query if it is resolvable, else return it."""
        if hasattr(expr, "resolve_expression"):
            expr = expr.resolve_expression(query, *args, **kwargs)
        return expr

    @classmethod
    def _resolve_node(cls, node, query, *args, **kwargs):
        """Resolve, in place, the lhs/rhs of node and of all its descendants."""
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        if hasattr(node, "lhs"):
            node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
        if hasattr(node, "rhs"):
            node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)

    def resolve_expression(self, *args, **kwargs):
        """Return a resolved clone; the original node is left untouched."""
        clone = self.clone()
        clone._resolve_node(clone, *args, **kwargs)
        clone.resolved = True
        return clone

    @cached_property
    def output_field(self):
        # Imported here rather than at module level.
        from plain.models.fields import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self):
        return self.output_field

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection):
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def leaves(self):
        """Yield every non-WhereNode descendant, depth first."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child

313 

314 

class NothingNode:
    """Where-clause child standing for a condition that no row satisfies."""

    # This node never involves aggregates or window functions.
    contains_over_clause = False
    contains_aggregate = False

    def as_sql(self, compiler=None, connection=None):
        """Always raise EmptyResultSet; no SQL is ever returned."""
        raise EmptyResultSet

323 

324 

class ExtraWhere:
    """Where-clause child made of raw SQL fragments and their parameters."""

    # The fragments are opaque, so they are treated as containing neither
    # aggregates nor window functions.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, sqls, params):
        self.params = params
        self.sqls = sqls

    def as_sql(self, compiler=None, connection=None):
        """
        Return every fragment wrapped in parentheses and joined with AND,
        together with the parameters as a new list (empty when params is
        falsy).
        """
        joined = " AND ".join(f"({fragment})" for fragment in self.sqls)
        bound = list(self.params) if self.params else []
        return joined, bound

337 

338 

class SubqueryConstraint:
    """Where-clause child whose SQL is delegated to a subquery's compiler."""

    # Aggregates or window functions inside the subquery are of no concern
    # to the outer query.
    contains_over_clause = False
    contains_aggregate = False

    def __init__(self, alias, columns, targets, query_object):
        self.alias, self.columns, self.targets = alias, columns, targets
        # Ordering is dropped from the subquery, default ordering included.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        """
        Set the stored targets as the subquery's values, then return whatever
        the subquery compiler's as_subquery_condition() produces for the
        stored alias and columns and the outer compiler.
        """
        subquery = self.query_object
        subquery.set_values(self.targets)
        inner_compiler = subquery.get_compiler(connection=connection)
        return inner_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )