Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/query_utils.py: 35%

233 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1""" 

2Various data structures used in query construction. 

3 

4Factored out from plain.models.query to avoid making the main module very 

5large and/or so that they can be used by other modules without getting into 

6circular import difficulties. 

7""" 

8 

9import functools 

10import inspect 

11import logging 

12from collections import namedtuple 

13 

14from plain.exceptions import FieldError 

15from plain.models.constants import LOOKUP_SEP 

16from plain.models.db import DEFAULT_DB_ALIAS, DatabaseError, connections 

17from plain.utils import tree 

18 

logger = logging.getLogger("plain.models")

# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation). The join_field is the field backing the relation.
PathInfo = namedtuple(
    "PathInfo",
    [
        "from_opts",
        "to_opts",
        "target_fields",
        "join_field",
        "m2m",
        "direct",
        "filtered_relation",
    ],
)

28 

29 

def subclasses(cls):
    """
    Yield ``cls`` followed by every class that (transitively) subclasses it.

    Traversal is depth-first pre-order, following ``__subclasses__()``. A
    class reachable through several parents is yielded once per path.
    """
    pending = [cls]
    while pending:
        current = pending.pop()
        yield current
        # Push children reversed so the first direct subclass is visited next.
        pending.extend(reversed(current.__subclasses__()))

34 

35 

class Q(tree.Node):
    """
    Encapsulate filters as objects that can then be combined logically (using
    `&`, `|` and `^`) and negated (using `~`).
    """

    # Connection types
    AND = "AND"
    OR = "OR"
    XOR = "XOR"
    default = AND
    # Flag checked by _combine() on the other operand: only objects whose
    # ``conditional`` attribute is not False can be combined with a Q.
    conditional = True

    def __init__(self, *args, _connector=None, _negated=False, **kwargs):
        # Keyword lookups are sorted by name, giving a deterministic child
        # order regardless of the order the kwargs were passed in.
        keyword_children = sorted(kwargs.items())
        super().__init__(
            children=[*args, *keyword_children],
            connector=_connector,
            negated=_negated,
        )

    def _combine(self, other, conn):
        """
        Return a new node joining ``self`` and ``other`` with connector ``conn``.

        Raise TypeError when ``other`` is not conditional. If ``self`` is empty,
        a copy of ``other`` is returned; if ``other`` is an empty Q, a copy of
        ``self`` is returned.
        """
        if getattr(other, "conditional", False) is False:
            raise TypeError(other)
        if not self:
            return other.copy()
        if not other and isinstance(other, Q):
            return self.copy()

        combined = self.create(connector=conn)
        for operand in (self, other):
            combined.add(operand, conn)
        return combined

    def __or__(self, other):
        return self._combine(other, self.OR)

    def __and__(self, other):
        return self._combine(other, self.AND)

    def __xor__(self, other):
        return self._combine(other, self.XOR)

    def __invert__(self):
        # Negate a copy so the original Q is left untouched.
        negated = self.copy()
        negated.negate()
        return negated

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve this Q against ``query`` and return the resulting clause."""
        # Any new joins are promoted to left outer joins so that, when Q is
        # used as an expression, rows aren't filtered out by the joins.
        clause, joins = query._add_q(
            self,
            reuse,
            allow_joins=allow_joins,
            split_subq=False,
            check_filterable=False,
            summarize=summarize,
        )
        query.promote_joins(joins)
        return clause

    def flatten(self):
        """
        Recursively yield this Q object and all subexpressions, in depth-first
        order.
        """
        yield self
        for child in self.children:
            # (lookup, value) tuples contribute their value part.
            node = child[1] if isinstance(child, tuple) else child
            if hasattr(node, "flatten"):
                yield from node.flatten()
            else:
                yield node

    def check(self, against, using=DEFAULT_DB_ALIAS):
        """
        Run a database query to check whether this Q instance matches the
        values given in ``against`` (a mapping of names to values or
        expressions).

        Return True if a row comes back, False otherwise. If the database
        raises DatabaseError, log a warning and return True.
        """
        # Avoid circular imports.
        from plain.models.expressions import Value
        from plain.models.fields import BooleanField
        from plain.models.functions import Coalesce
        from plain.models.sql import Query
        from plain.models.sql.constants import SINGLE

        query = Query(None)
        for name, value in against.items():
            # Plain Python values are wrapped so they can be annotated.
            expression = value if hasattr(value, "resolve_expression") else Value(value)
            query.add_annotation(expression, name, select=False)
        query.add_annotation(Value(1), "_check")
        if connections[using].features.supports_comparing_boolean_expr:
            condition = Q(Coalesce(self, True, output_field=BooleanField()))
        else:
            condition = self
        # This will raise a FieldError if a field is missing in "against".
        query.add_q(condition)
        compiler = query.get_compiler(using=using)
        try:
            row = compiler.execute_sql(SINGLE)
        except DatabaseError as e:
            logger.warning("Got a database error calling check() on %r: %s", self, e)
            return True
        return row is not None

    def deconstruct(self):
        """
        Return ``(path, args, kwargs)`` describing how to rebuild this Q.

        Classes from this module are reported under the shorter
        ``plain.models`` path. ``_connector`` and ``_negated`` are only
        included when they differ from the defaults.
        """
        klass = self.__class__
        path = f"{klass.__module__}.{klass.__name__}"
        internal_module = "plain.models.query_utils"
        if path.startswith(internal_module):
            path = path.replace(internal_module, "plain.models")
        kwargs = {}
        if self.connector != self.default:
            kwargs["_connector"] = self.connector
        if self.negated:
            kwargs["_negated"] = True
        return path, tuple(self.children), kwargs
75 def __xor__(self, other): 

76 return self._combine(other, self.XOR) 

77 

78 def __invert__(self): 

79 obj = self.copy() 

80 obj.negate() 

81 return obj 

82 

83 def resolve_expression( 

84 self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False 

85 ): 

86 # We must promote any new joins to left outer joins so that when Q is 

87 # used as an expression, rows aren't filtered due to joins. 

88 clause, joins = query._add_q( 

89 self, 

90 reuse, 

91 allow_joins=allow_joins, 

92 split_subq=False, 

93 check_filterable=False, 

94 summarize=summarize, 

95 ) 

96 query.promote_joins(joins) 

97 return clause 

98 

99 def flatten(self): 

100 """ 

101 Recursively yield this Q object and all subexpressions, in depth-first 

102 order. 

103 """ 

104 yield self 

105 for child in self.children: 

106 if isinstance(child, tuple): 

107 # Use the lookup. 

108 child = child[1] 

109 if hasattr(child, "flatten"): 

110 yield from child.flatten() 

111 else: 

112 yield child 

113 

114 def check(self, against, using=DEFAULT_DB_ALIAS): 

115 """ 

116 Do a database query to check if the expressions of the Q instance 

117 matches against the expressions. 

118 """ 

119 # Avoid circular imports. 

120 from plain.models.expressions import Value 

121 from plain.models.fields import BooleanField 

122 from plain.models.functions import Coalesce 

123 from plain.models.sql import Query 

124 from plain.models.sql.constants import SINGLE 

125 

126 query = Query(None) 

127 for name, value in against.items(): 

128 if not hasattr(value, "resolve_expression"): 

129 value = Value(value) 

130 query.add_annotation(value, name, select=False) 

131 query.add_annotation(Value(1), "_check") 

132 # This will raise a FieldError if a field is missing in "against". 

133 if connections[using].features.supports_comparing_boolean_expr: 

134 query.add_q(Q(Coalesce(self, True, output_field=BooleanField()))) 

135 else: 

136 query.add_q(self) 

137 compiler = query.get_compiler(using=using) 

138 try: 

139 return compiler.execute_sql(SINGLE) is not None 

140 except DatabaseError as e: 

141 logger.warning("Got a database error calling check() on %r: %s", self, e) 

142 return True 

143 

144 def deconstruct(self): 

145 path = f"{self.__class__.__module__}.{self.__class__.__name__}" 

146 if path.startswith("plain.models.query_utils"): 

147 path = path.replace("plain.models.query_utils", "plain.models") 

148 args = tuple(self.children) 

149 kwargs = {} 

150 if self.connector != self.default: 

151 kwargs["_connector"] = self.connector 

152 if self.negated: 

153 kwargs["_negated"] = True 

154 return path, args, kwargs 

155 

156 

157class DeferredAttribute: 

158 """ 

159 A wrapper for a deferred-loading field. When the value is read from this 

160 object the first time, the query is executed. 

161 """ 

162 

163 def __init__(self, field): 

164 self.field = field 

165 

166 def __get__(self, instance, cls=None): 

167 """ 

168 Retrieve and caches the value from the datastore on the first lookup. 

169 Return the cached value. 

170 """ 

171 if instance is None: 

172 return self 

173 data = instance.__dict__ 

174 field_name = self.field.attname 

175 if field_name not in data: 

176 # Let's see if the field is part of the parent chain. If so we 

177 # might be able to reuse the already loaded value. Refs #18343. 

178 val = self._check_parent_chain(instance) 

179 if val is None: 

180 instance.refresh_from_db(fields=[field_name]) 

181 else: 

182 data[field_name] = val 

183 return data[field_name] 

184 

185 def _check_parent_chain(self, instance): 

186 """ 

187 Check if the field value can be fetched from a parent field already 

188 loaded in the instance. This can be done if the to-be fetched 

189 field is a primary key field. 

190 """ 

191 opts = instance._meta 

192 link_field = opts.get_ancestor_link(self.field.model) 

193 if self.field.primary_key and self.field != link_field: 

194 return getattr(instance, link_field.attname) 

195 return None 

196 

197 

198class class_or_instance_method: 

199 """ 

200 Hook used in RegisterLookupMixin to return partial functions depending on 

201 the caller type (instance or class of models.Field). 

202 """ 

203 

204 def __init__(self, class_method, instance_method): 

205 self.class_method = class_method 

206 self.instance_method = instance_method 

207 

208 def __get__(self, instance, owner): 

209 if instance is None: 

210 return functools.partial(self.class_method, owner) 

211 return functools.partial(self.instance_method, instance) 

212 

213 

214class RegisterLookupMixin: 

215 def _get_lookup(self, lookup_name): 

216 return self.get_lookups().get(lookup_name, None) 

217 

218 @functools.cache 

219 def get_class_lookups(cls): 

220 class_lookups = [ 

221 parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls) 

222 ] 

223 return cls.merge_dicts(class_lookups) 

224 

225 def get_instance_lookups(self): 

226 class_lookups = self.get_class_lookups() 

227 if instance_lookups := getattr(self, "instance_lookups", None): 

228 return {**class_lookups, **instance_lookups} 

229 return class_lookups 

230 

231 get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups) 

232 get_class_lookups = classmethod(get_class_lookups) 

233 

234 def get_lookup(self, lookup_name): 

235 from plain.models.lookups import Lookup 

236 

237 found = self._get_lookup(lookup_name) 

238 if found is None and hasattr(self, "output_field"): 

239 return self.output_field.get_lookup(lookup_name) 

240 if found is not None and not issubclass(found, Lookup): 

241 return None 

242 return found 

243 

244 def get_transform(self, lookup_name): 

245 from plain.models.lookups import Transform 

246 

247 found = self._get_lookup(lookup_name) 

248 if found is None and hasattr(self, "output_field"): 

249 return self.output_field.get_transform(lookup_name) 

250 if found is not None and not issubclass(found, Transform): 

251 return None 

252 return found 

253 

254 @staticmethod 

255 def merge_dicts(dicts): 

256 """ 

257 Merge dicts in reverse to preference the order of the original list. e.g., 

258 merge_dicts([a, b]) will preference the keys in 'a' over those in 'b'. 

259 """ 

260 merged = {} 

261 for d in reversed(dicts): 

262 merged.update(d) 

263 return merged 

264 

265 @classmethod 

266 def _clear_cached_class_lookups(cls): 

267 for subclass in subclasses(cls): 

268 subclass.get_class_lookups.cache_clear() 

269 

270 def register_class_lookup(cls, lookup, lookup_name=None): 

271 if lookup_name is None: 

272 lookup_name = lookup.lookup_name 

273 if "class_lookups" not in cls.__dict__: 

274 cls.class_lookups = {} 

275 cls.class_lookups[lookup_name] = lookup 

276 cls._clear_cached_class_lookups() 

277 return lookup 

278 

279 def register_instance_lookup(self, lookup, lookup_name=None): 

280 if lookup_name is None: 

281 lookup_name = lookup.lookup_name 

282 if "instance_lookups" not in self.__dict__: 

283 self.instance_lookups = {} 

284 self.instance_lookups[lookup_name] = lookup 

285 return lookup 

286 

287 register_lookup = class_or_instance_method( 

288 register_class_lookup, register_instance_lookup 

289 ) 

290 register_class_lookup = classmethod(register_class_lookup) 

291 

292 def _unregister_class_lookup(cls, lookup, lookup_name=None): 

293 """ 

294 Remove given lookup from cls lookups. For use in tests only as it's 

295 not thread-safe. 

296 """ 

297 if lookup_name is None: 

298 lookup_name = lookup.lookup_name 

299 del cls.class_lookups[lookup_name] 

300 cls._clear_cached_class_lookups() 

301 

302 def _unregister_instance_lookup(self, lookup, lookup_name=None): 

303 """ 

304 Remove given lookup from instance lookups. For use in tests only as 

305 it's not thread-safe. 

306 """ 

307 if lookup_name is None: 

308 lookup_name = lookup.lookup_name 

309 del self.instance_lookups[lookup_name] 

310 

311 _unregister_lookup = class_or_instance_method( 

312 _unregister_class_lookup, _unregister_instance_lookup 

313 ) 

314 _unregister_class_lookup = classmethod(_unregister_class_lookup) 

315 

316 

317def select_related_descend(field, restricted, requested, select_mask, reverse=False): 

318 """ 

319 Return True if this field should be used to descend deeper for 

320 select_related() purposes. Used by both the query construction code 

321 (compiler.get_related_selections()) and the model instance creation code 

322 (compiler.klass_info). 

323 

324 Arguments: 

325 * field - the field to be checked 

326 * restricted - a boolean field, indicating if the field list has been 

327 manually restricted using a requested clause) 

328 * requested - The select_related() dictionary. 

329 * select_mask - the dictionary of selected fields. 

330 * reverse - boolean, True if we are checking a reverse select related 

331 """ 

332 if not field.remote_field: 

333 return False 

334 if field.remote_field.parent_link and not reverse: 

335 return False 

336 if restricted: 

337 if reverse and field.related_query_name() not in requested: 

338 return False 

339 if not reverse and field.name not in requested: 

340 return False 

341 if not restricted and field.null: 

342 return False 

343 if ( 

344 restricted 

345 and select_mask 

346 and field.name in requested 

347 and field not in select_mask 

348 ): 

349 raise FieldError( 

350 f"Field {field.model._meta.object_name}.{field.name} cannot be both " 

351 "deferred and traversed using select_related at the same time." 

352 ) 

353 return True 

354 

355 

356def refs_expression(lookup_parts, annotations): 

357 """ 

358 Check if the lookup_parts contains references to the given annotations set. 

359 Because the LOOKUP_SEP is contained in the default annotation names, check 

360 each prefix of the lookup_parts for a match. 

361 """ 

362 for n in range(1, len(lookup_parts) + 1): 

363 level_n_lookup = LOOKUP_SEP.join(lookup_parts[0:n]) 

364 if annotations.get(level_n_lookup): 

365 return level_n_lookup, lookup_parts[n:] 

366 return None, () 

367 

368 

369def check_rel_lookup_compatibility(model, target_opts, field): 

370 """ 

371 Check that self.model is compatible with target_opts. Compatibility 

372 is OK if: 

373 1) model and opts match (where proxy inheritance is removed) 

374 2) model is parent of opts' model or the other way around 

375 """ 

376 

377 def check(opts): 

378 return ( 

379 model._meta.concrete_model == opts.concrete_model 

380 or opts.concrete_model in model._meta.get_parent_list() 

381 or model in opts.get_parent_list() 

382 ) 

383 

384 # If the field is a primary key, then doing a query against the field's 

385 # model is ok, too. Consider the case: 

386 # class Restaurant(models.Model): 

387 # place = OneToOneField(Place, primary_key=True): 

388 # Restaurant.objects.filter(pk__in=Restaurant.objects.all()). 

389 # If we didn't have the primary key check, then pk__in (== place__in) would 

390 # give Place's opts as the target opts, but Restaurant isn't compatible 

391 # with that. This logic applies only to primary keys, as when doing __in=qs, 

392 # we are going to turn this into __in=qs.values('pk') later on. 

393 return check(target_opts) or ( 

394 getattr(field, "primary_key", False) and check(field.model._meta) 

395 ) 

396 

397 

398class FilteredRelation: 

399 """Specify custom filtering in the ON clause of SQL joins.""" 

400 

401 def __init__(self, relation_name, *, condition=Q()): 

402 if not relation_name: 

403 raise ValueError("relation_name cannot be empty.") 

404 self.relation_name = relation_name 

405 self.alias = None 

406 if not isinstance(condition, Q): 

407 raise ValueError("condition argument must be a Q() instance.") 

408 self.condition = condition 

409 self.path = [] 

410 

411 def __eq__(self, other): 

412 if not isinstance(other, self.__class__): 

413 return NotImplemented 

414 return ( 

415 self.relation_name == other.relation_name 

416 and self.alias == other.alias 

417 and self.condition == other.condition 

418 ) 

419 

420 def clone(self): 

421 clone = FilteredRelation(self.relation_name, condition=self.condition) 

422 clone.alias = self.alias 

423 clone.path = self.path[:] 

424 return clone 

425 

426 def resolve_expression(self, *args, **kwargs): 

427 """ 

428 QuerySet.annotate() only accepts expression-like arguments 

429 (with a resolve_expression() method). 

430 """ 

431 raise NotImplementedError("FilteredRelation.resolve_expression() is unused.") 

432 

433 def as_sql(self, compiler, connection): 

434 # Resolve the condition in Join.filtered_relation. 

435 query = compiler.query 

436 where = query.build_filtered_relation_q(self.condition, reuse=set(self.path)) 

437 return compiler.compile(where)