Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/query.py: 15%
1237 statements
coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
5import copy
6import operator
7import warnings
8from itertools import chain, islice
10import plain.runtime
11from plain import exceptions
12from plain.exceptions import ValidationError
13from plain.models import (
14 sql,
15 transaction,
16)
17from plain.models.constants import LOOKUP_SEP, OnConflict
18from plain.models.db import (
19 PLAIN_VERSION_PICKLE_KEY,
20 IntegrityError,
21 NotSupportedError,
22 connections,
23 router,
24)
25from plain.models.expressions import Case, F, Value, When
26from plain.models.fields import (
27 AutoField,
28 DateField,
29 DateTimeField,
30 Field,
31)
32from plain.models.functions import Cast, Trunc
33from plain.models.query_utils import FilteredRelation, Q
34from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
35from plain.models.utils import (
36 AltersData,
37 create_namedtuple_class,
38 resolve_callables,
39)
40from plain.utils import timezone
41from plain.utils.functional import cached_property, partition
43# The maximum number of results to fetch in a get() query.
44MAX_GET_RESULTS = 21
46# The maximum number of items to display in a QuerySet.__repr__
47REPR_OUTPUT_SIZE = 20
50class BaseIterable:
51 def __init__(
52 self, queryset, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
53 ):
54 self.queryset = queryset
55 self.chunked_fetch = chunked_fetch
56 self.chunk_size = chunk_size
59class ModelIterable(BaseIterable):
60 """Iterable that yields a model instance for each row."""
62 def __iter__(self):
63 queryset = self.queryset
64 db = queryset.db
65 compiler = queryset.query.get_compiler(using=db)
66 # Execute the query. This will also fill compiler.select, klass_info,
67 # and annotations.
68 results = compiler.execute_sql(
69 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
70 )
71 select, klass_info, annotation_col_map = (
72 compiler.select,
73 compiler.klass_info,
74 compiler.annotation_col_map,
75 )
76 model_cls = klass_info["model"]
77 select_fields = klass_info["select_fields"]
78 model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
79 init_list = [
80 f[0].target.attname for f in select[model_fields_start:model_fields_end]
81 ]
82 related_populators = get_related_populators(klass_info, select, db)
83 known_related_objects = [
84 (
85 field,
86 related_objs,
87 operator.attrgetter(
88 *[
89 field.attname
90 if from_field == "self"
91 else queryset.model._meta.get_field(from_field).attname
92 for from_field in field.from_fields
93 ]
94 ),
95 )
96 for field, related_objs in queryset._known_related_objects.items()
97 ]
98 for row in compiler.results_iter(results):
99 obj = model_cls.from_db(
100 db, init_list, row[model_fields_start:model_fields_end]
101 )
102 for rel_populator in related_populators:
103 rel_populator.populate(row, obj)
104 if annotation_col_map:
105 for attr_name, col_pos in annotation_col_map.items():
106 setattr(obj, attr_name, row[col_pos])
108 # Add the known related objects to the model.
109 for field, rel_objs, rel_getter in known_related_objects:
110 # Avoid overwriting objects loaded by, e.g., select_related().
111 if field.is_cached(obj):
112 continue
113 rel_obj_id = rel_getter(obj)
114 try:
115 rel_obj = rel_objs[rel_obj_id]
116 except KeyError:
117 pass # May happen in qs1 | qs2 scenarios.
118 else:
119 setattr(obj, field.name, rel_obj)
121 yield obj
124class RawModelIterable(BaseIterable):
125 """
126 Iterable that yields a model instance for each row from a raw queryset.
127 """
129 def __iter__(self):
130 # Cache some things for performance reasons outside the loop.
131 db = self.queryset.db
132 query = self.queryset.query
133 connection = connections[db]
134 compiler = connection.ops.compiler("SQLCompiler")(query, connection, db)
135 query_iterator = iter(query)
137 try:
138 (
139 model_init_names,
140 model_init_pos,
141 annotation_fields,
142 ) = self.queryset.resolve_model_init_order()
143 model_cls = self.queryset.model
144 if model_cls._meta.pk.attname not in model_init_names:
145 raise exceptions.FieldDoesNotExist(
146 "Raw query must include the primary key"
147 )
148 fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
149 converters = compiler.get_converters(
150 [f.get_col(f.model._meta.db_table) if f else None for f in fields]
151 )
152 if converters:
153 query_iterator = compiler.apply_converters(query_iterator, converters)
154 for values in query_iterator:
155 # Associate fields to values
156 model_init_values = [values[pos] for pos in model_init_pos]
157 instance = model_cls.from_db(db, model_init_names, model_init_values)
158 if annotation_fields:
159 for column, pos in annotation_fields:
160 setattr(instance, column, values[pos])
161 yield instance
162 finally:
163 # Done iterating the Query. If it has its own cursor, close it.
164 if hasattr(query, "cursor") and query.cursor:
165 query.cursor.close()
168class ValuesIterable(BaseIterable):
169 """
170 Iterable returned by QuerySet.values() that yields a dict for each row.
171 """
173 def __iter__(self):
174 queryset = self.queryset
175 query = queryset.query
176 compiler = query.get_compiler(queryset.db)
178 # extra(select=...) cols are always at the start of the row.
179 names = [
180 *query.extra_select,
181 *query.values_select,
182 *query.annotation_select,
183 ]
184 indexes = range(len(names))
185 for row in compiler.results_iter(
186 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
187 ):
188 yield {names[i]: row[i] for i in indexes}
191class ValuesListIterable(BaseIterable):
192 """
193 Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
194 for each row.
195 """
197 def __iter__(self):
198 queryset = self.queryset
199 query = queryset.query
200 compiler = query.get_compiler(queryset.db)
202 if queryset._fields:
203 # extra(select=...) cols are always at the start of the row.
204 names = [
205 *query.extra_select,
206 *query.values_select,
207 *query.annotation_select,
208 ]
209 fields = [
210 *queryset._fields,
211 *(f for f in query.annotation_select if f not in queryset._fields),
212 ]
213 if fields != names:
214 # Reorder according to fields.
215 index_map = {name: idx for idx, name in enumerate(names)}
216 rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
217 return map(
218 rowfactory,
219 compiler.results_iter(
220 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
221 ),
222 )
223 return compiler.results_iter(
224 tuple_expected=True,
225 chunked_fetch=self.chunked_fetch,
226 chunk_size=self.chunk_size,
227 )
230class NamedValuesListIterable(ValuesListIterable):
231 """
232 Iterable returned by QuerySet.values_list(named=True) that yields a
233 namedtuple for each row.
234 """
236 def __iter__(self):
237 queryset = self.queryset
238 if queryset._fields:
239 names = queryset._fields
240 else:
241 query = queryset.query
242 names = [
243 *query.extra_select,
244 *query.values_select,
245 *query.annotation_select,
246 ]
247 tuple_class = create_namedtuple_class(*names)
248 new = tuple.__new__
249 for row in super().__iter__():
250 yield new(tuple_class, row)
253class FlatValuesListIterable(BaseIterable):
254 """
255 Iterable returned by QuerySet.values_list(flat=True) that yields single
256 values.
257 """
259 def __iter__(self):
260 queryset = self.queryset
261 compiler = queryset.query.get_compiler(queryset.db)
262 for row in compiler.results_iter(
263 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
264 ):
265 yield row[0]
268class QuerySet(AltersData):
269 """Represent a lazy database lookup for a set of objects."""
271 def __init__(self, model=None, query=None, using=None, hints=None):
272 self.model = model
273 self._db = using
274 self._hints = hints or {}
275 self._query = query or sql.Query(self.model)
276 self._result_cache = None
277 self._sticky_filter = False
278 self._for_write = False
279 self._prefetch_related_lookups = ()
280 self._prefetch_done = False
281 self._known_related_objects = {} # {rel_field: {pk: rel_obj}}
282 self._iterable_class = ModelIterable
283 self._fields = None
284 self._defer_next_filter = False
285 self._deferred_filter = None
287 @property
288 def query(self):
289 if self._deferred_filter:
290 negate, args, kwargs = self._deferred_filter
291 self._filter_or_exclude_inplace(negate, args, kwargs)
292 self._deferred_filter = None
293 return self._query
295 @query.setter
296 def query(self, value):
297 if value.values_select:
298 self._iterable_class = ValuesIterable
299 self._query = value
301 def as_manager(cls):
302 # Address the circular dependency between `QuerySet` and `Manager`.
303 from plain.models.manager import Manager
305 manager = Manager.from_queryset(cls)()
306 manager._built_with_as_manager = True
307 return manager
309 as_manager.queryset_only = True
310 as_manager = classmethod(as_manager)
312 ########################
313 # PYTHON MAGIC METHODS #
314 ########################
316 def __deepcopy__(self, memo):
317 """Don't populate the QuerySet's cache."""
318 obj = self.__class__()
319 for k, v in self.__dict__.items():
320 if k == "_result_cache":
321 obj.__dict__[k] = None
322 else:
323 obj.__dict__[k] = copy.deepcopy(v, memo)
324 return obj
326 def __getstate__(self):
327 # Force the cache to be fully populated.
328 self._fetch_all()
329 return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}
331 def __setstate__(self, state):
332 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
333 if pickled_version:
334 if pickled_version != plain.runtime.__version__:
335 warnings.warn(
336 f"Pickled queryset instance's Plain version {pickled_version} does not "
337 f"match the current version {plain.runtime.__version__}.",
338 RuntimeWarning,
339 stacklevel=2,
340 )
341 else:
342 warnings.warn(
343 "Pickled queryset instance's Plain version is not specified.",
344 RuntimeWarning,
345 stacklevel=2,
346 )
347 self.__dict__.update(state)
349 def __repr__(self):
350 data = list(self[: REPR_OUTPUT_SIZE + 1])
351 if len(data) > REPR_OUTPUT_SIZE:
352 data[-1] = "...(remaining elements truncated)..."
353 return f"<{self.__class__.__name__} {data!r}>"
355 def __len__(self):
356 self._fetch_all()
357 return len(self._result_cache)
359 def __iter__(self):
360 """
361 The queryset iterator protocol uses three nested iterators in the
362 default case:
363 1. sql.compiler.execute_sql()
364 - Returns 100 rows at a time (constants.GET_ITERATOR_CHUNK_SIZE)
365 using cursor.fetchmany(). This part is responsible for
366 doing some column masking, and returning the rows in chunks.
367 2. sql.compiler.results_iter()
368 - Returns one row at a time. At this point the rows are still just
369 tuples. In some cases the return values are converted to
370 Python values at this location.
371 3. self.iterator()
372 - Responsible for turning the rows into model objects.
373 """
374 self._fetch_all()
375 return iter(self._result_cache)
377 def __bool__(self):
378 self._fetch_all()
379 return bool(self._result_cache)
381 def __getitem__(self, k):
382 """Retrieve an item or slice from the set of results."""
383 if not isinstance(k, int | slice):
384 raise TypeError(
385 f"QuerySet indices must be integers or slices, not {type(k).__name__}."
386 )
387 if (isinstance(k, int) and k < 0) or (
388 isinstance(k, slice)
389 and (
390 (k.start is not None and k.start < 0)
391 or (k.stop is not None and k.stop < 0)
392 )
393 ):
394 raise ValueError("Negative indexing is not supported.")
396 if self._result_cache is not None:
397 return self._result_cache[k]
399 if isinstance(k, slice):
400 qs = self._chain()
401 if k.start is not None:
402 start = int(k.start)
403 else:
404 start = None
405 if k.stop is not None:
406 stop = int(k.stop)
407 else:
408 stop = None
409 qs.query.set_limits(start, stop)
410 return list(qs)[:: k.step] if k.step else qs
412 qs = self._chain()
413 qs.query.set_limits(k, k + 1)
414 qs._fetch_all()
415 return qs._result_cache[0]
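# Illustrative sketch of the slicing behavior above, assuming a hypothetical
# `Book` model: slices become SQL LIMIT/OFFSET and stay lazy unless a step is
# given, while integer indexing fetches a single row immediately.
#   first_ten = Book.objects.all()[:10]        # lazy; LIMIT 10
#   every_other = Book.objects.all()[:10:2]    # step given, so evaluated to a list
#   third = Book.objects.order_by("title")[2]  # one query; LIMIT 1 OFFSET 2
#   Book.objects.all()[-1]                     # raises ValueError (negative indexing)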
417 def __class_getitem__(cls, *args, **kwargs):
418 return cls
420 def __and__(self, other):
421 self._check_operator_queryset(other, "&")
422 self._merge_sanity_check(other)
423 if isinstance(other, EmptyQuerySet):
424 return other
425 if isinstance(self, EmptyQuerySet):
426 return self
427 combined = self._chain()
428 combined._merge_known_related_objects(other)
429 combined.query.combine(other.query, sql.AND)
430 return combined
432 def __or__(self, other):
433 self._check_operator_queryset(other, "|")
434 self._merge_sanity_check(other)
435 if isinstance(self, EmptyQuerySet):
436 return other
437 if isinstance(other, EmptyQuerySet):
438 return self
439 query = (
440 self
441 if self.query.can_filter()
442 else self.model._base_manager.filter(pk__in=self.values("pk"))
443 )
444 combined = query._chain()
445 combined._merge_known_related_objects(other)
446 if not other.query.can_filter():
447 other = other.model._base_manager.filter(pk__in=other.values("pk"))
448 combined.query.combine(other.query, sql.OR)
449 return combined
451 def __xor__(self, other):
452 self._check_operator_queryset(other, "^")
453 self._merge_sanity_check(other)
454 if isinstance(self, EmptyQuerySet):
455 return other
456 if isinstance(other, EmptyQuerySet):
457 return self
458 query = (
459 self
460 if self.query.can_filter()
461 else self.model._base_manager.filter(pk__in=self.values("pk"))
462 )
463 combined = query._chain()
464 combined._merge_known_related_objects(other)
465 if not other.query.can_filter():
466 other = other.model._base_manager.filter(pk__in=other.values("pk"))
467 combined.query.combine(other.query, sql.XOR)
468 return combined
470 ####################################
471 # METHODS THAT DO DATABASE QUERIES #
472 ####################################
474 def _iterator(self, use_chunked_fetch, chunk_size):
475 iterable = self._iterable_class(
476 self,
477 chunked_fetch=use_chunked_fetch,
478 chunk_size=chunk_size or 2000,
479 )
480 if not self._prefetch_related_lookups or chunk_size is None:
481 yield from iterable
482 return
484 iterator = iter(iterable)
485 while results := list(islice(iterator, chunk_size)):
486 prefetch_related_objects(results, *self._prefetch_related_lookups)
487 yield from results
489 def iterator(self, chunk_size=None):
490 """
491 An iterator over the results from applying this QuerySet to the
492 database. chunk_size must be provided for QuerySets that prefetch
493 related objects. Otherwise, a default chunk_size of 2000 is supplied.
494 """
495 if chunk_size is None:
496 if self._prefetch_related_lookups:
497 raise ValueError(
498 "chunk_size must be provided when using QuerySet.iterator() after "
499 "prefetch_related()."
500 )
501 elif chunk_size <= 0:
502 raise ValueError("Chunk size must be strictly positive.")
503 use_chunked_fetch = not connections[self.db].settings_dict.get(
504 "DISABLE_SERVER_SIDE_CURSORS"
505 )
506 return self._iterator(use_chunked_fetch, chunk_size)
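# Usage sketch for iterator(), assuming a hypothetical `Book` model: results
# stream in chunks instead of filling the result cache, and chunk_size becomes
# mandatory once prefetch_related() is involved.
#   for book in Book.objects.all().iterator():  # default chunk_size of 2000
#       ...
#   for book in Book.objects.prefetch_related("authors").iterator(chunk_size=500):
#       ...  # prefetching is performed per 500-row chunk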
508 def aggregate(self, *args, **kwargs):
509 """
510 Return a dictionary containing the calculations (aggregation)
511 over the current queryset.
513 If args are present, each expression is passed as a kwarg using
514 the Aggregate object's default alias.
515 """
516 if self.query.distinct_fields:
517 raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
518 self._validate_values_are_expressions(
519 (*args, *kwargs.values()), method_name="aggregate"
520 )
521 for arg in args:
522 # The default_alias property raises TypeError if default_alias
523 # can't be set automatically or AttributeError if it isn't an
524 # attribute.
525 try:
526 arg.default_alias
527 except (AttributeError, TypeError):
528 raise TypeError("Complex aggregates require an alias")
529 kwargs[arg.default_alias] = arg
531 return self.query.chain().get_aggregation(self.db, kwargs)
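# Usage sketch for aggregate(), assuming a hypothetical `Book` model with a
# `price` field and assuming aggregate classes such as Avg/Max are importable
# from plain.models (mirroring the expression imports above); positional
# arguments are keyed by their default alias, keyword arguments by name.
#   Book.objects.aggregate(Avg("price"))             # {'price__avg': ...}
#   Book.objects.aggregate(max_price=Max("price"))   # {'max_price': ...}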
533 def count(self):
534 """
535 Perform a SELECT COUNT() and return the number of records as an
536 integer.
538 If the QuerySet is already fully cached, return the length of the
539 cached results set to avoid multiple SELECT COUNT(*) calls.
540 """
541 if self._result_cache is not None:
542 return len(self._result_cache)
544 return self.query.get_count(using=self.db)
546 def get(self, *args, **kwargs):
547 """
548 Perform the query and return a single object matching the given
549 keyword arguments.
550 """
551 if self.query.combinator and (args or kwargs):
552 raise NotSupportedError(
553 f"Calling QuerySet.get(...) with filters after {self.query.combinator}() is not "
554 "supported."
555 )
556 clone = self._chain() if self.query.combinator else self.filter(*args, **kwargs)
557 if self.query.can_filter() and not self.query.distinct_fields:
558 clone = clone.order_by()
559 limit = None
560 if (
561 not clone.query.select_for_update
562 or connections[clone.db].features.supports_select_for_update_with_limit
563 ):
564 limit = MAX_GET_RESULTS
565 clone.query.set_limits(high=limit)
566 num = len(clone)
567 if num == 1:
568 return clone._result_cache[0]
569 if not num:
570 raise self.model.DoesNotExist(
571 f"{self.model._meta.object_name} matching query does not exist."
572 )
573 raise self.model.MultipleObjectsReturned(
574 "get() returned more than one {} -- it returned {}!".format(
575 self.model._meta.object_name,
576 num if not limit or num < limit else "more than %s" % (limit - 1),
577 )
578 )
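# Usage sketch for get(), assuming a hypothetical `Book` model: exactly one
# match returns the instance; zero or multiple matches raise the model-specific
# exceptions referenced above.
#   try:
#       book = Book.objects.get(isbn="9780000000000")
#   except Book.DoesNotExist:
#       ...
#   except Book.MultipleObjectsReturned:
#       ...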
580 def create(self, **kwargs):
581 """
582 Create a new object with the given kwargs, saving it to the database
583 and returning the created object.
584 """
585 obj = self.model(**kwargs)
586 self._for_write = True
587 obj.save(force_insert=True, using=self.db)
588 return obj
590 def _prepare_for_bulk_create(self, objs):
591 for obj in objs:
592 if obj.pk is None:
593 # Populate new PK values.
594 obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
595 obj._prepare_related_fields_for_save(operation_name="bulk_create")
597 def _check_bulk_create_options(
598 self, ignore_conflicts, update_conflicts, update_fields, unique_fields
599 ):
600 if ignore_conflicts and update_conflicts:
601 raise ValueError(
602 "ignore_conflicts and update_conflicts are mutually exclusive."
603 )
604 db_features = connections[self.db].features
605 if ignore_conflicts:
606 if not db_features.supports_ignore_conflicts:
607 raise NotSupportedError(
608 "This database backend does not support ignoring conflicts."
609 )
610 return OnConflict.IGNORE
611 elif update_conflicts:
612 if not db_features.supports_update_conflicts:
613 raise NotSupportedError(
614 "This database backend does not support updating conflicts."
615 )
616 if not update_fields:
617 raise ValueError(
618 "Fields that will be updated when a row insertion fails "
619 "on conflicts must be provided."
620 )
621 if unique_fields and not db_features.supports_update_conflicts_with_target:
622 raise NotSupportedError(
623 "This database backend does not support updating "
624 "conflicts with specifying unique fields that can trigger "
625 "the upsert."
626 )
627 if not unique_fields and db_features.supports_update_conflicts_with_target:
628 raise ValueError(
629 "Unique fields that can trigger the upsert must be provided."
630 )
631 # Updating primary keys and non-concrete fields is forbidden.
632 if any(not f.concrete or f.many_to_many for f in update_fields):
633 raise ValueError(
634 "bulk_create() can only be used with concrete fields in "
635 "update_fields."
636 )
637 if any(f.primary_key for f in update_fields):
638 raise ValueError(
639 "bulk_create() cannot be used with primary keys in "
640 "update_fields."
641 )
642 if unique_fields:
643 if any(not f.concrete or f.many_to_many for f in unique_fields):
644 raise ValueError(
645 "bulk_create() can only be used with concrete fields "
646 "in unique_fields."
647 )
648 return OnConflict.UPDATE
649 return None
651 def bulk_create(
652 self,
653 objs,
654 batch_size=None,
655 ignore_conflicts=False,
656 update_conflicts=False,
657 update_fields=None,
658 unique_fields=None,
659 ):
660 """
661 Insert each of the instances into the database. Do *not* call
662 save() on each of the instances, do not send any pre/post_save
663 signals, and do not set the primary key attribute if it is an
664 autoincrement field (except if features.can_return_rows_from_bulk_insert=True).
665 Multi-table models are not supported.
666 """
667 # When you bulk insert you don't get the primary keys back (if it's an
668 # autoincrement, except if can_return_rows_from_bulk_insert=True), so
669 # you can't insert into the child tables which reference it. There
670 # are two workarounds:
671 # 1) This could be implemented if you didn't have an autoincrement pk
672 # 2) You could do it by doing O(n) normal inserts into the parent
673 # tables to get the primary keys back and then doing a single bulk
674 # insert into the childmost table.
675 # We currently set the primary keys on the objects when using
676 # PostgreSQL via the RETURNING ID clause. It should be possible for
677 # Oracle as well, but the semantics for extracting the primary keys are
678 # trickier, so it's not done yet.
679 if batch_size is not None and batch_size <= 0:
680 raise ValueError("Batch size must be a positive integer.")
681 # Check that the parents share the same concrete model with our
682 # model to detect the inheritance pattern ConcreteGrandParent ->
683 # MultiTableParent -> ProxyChild. Simply checking self.model._meta.proxy
684 # would not identify that case as involving multiple tables.
685 for parent in self.model._meta.get_parent_list():
686 if parent._meta.concrete_model is not self.model._meta.concrete_model:
687 raise ValueError("Can't bulk create a multi-table inherited model")
688 if not objs:
689 return objs
690 opts = self.model._meta
691 if unique_fields:
692 # Primary key is allowed in unique_fields.
693 unique_fields = [
694 self.model._meta.get_field(opts.pk.name if name == "pk" else name)
695 for name in unique_fields
696 ]
697 if update_fields:
698 update_fields = [self.model._meta.get_field(name) for name in update_fields]
699 on_conflict = self._check_bulk_create_options(
700 ignore_conflicts,
701 update_conflicts,
702 update_fields,
703 unique_fields,
704 )
705 self._for_write = True
706 fields = opts.concrete_fields
707 objs = list(objs)
708 self._prepare_for_bulk_create(objs)
709 with transaction.atomic(using=self.db, savepoint=False):
710 objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
711 if objs_with_pk:
712 returned_columns = self._batched_insert(
713 objs_with_pk,
714 fields,
715 batch_size,
716 on_conflict=on_conflict,
717 update_fields=update_fields,
718 unique_fields=unique_fields,
719 )
720 for obj_with_pk, results in zip(objs_with_pk, returned_columns):
721 for result, field in zip(results, opts.db_returning_fields):
722 if field != opts.pk:
723 setattr(obj_with_pk, field.attname, result)
724 for obj_with_pk in objs_with_pk:
725 obj_with_pk._state.adding = False
726 obj_with_pk._state.db = self.db
727 if objs_without_pk:
728 fields = [f for f in fields if not isinstance(f, AutoField)]
729 returned_columns = self._batched_insert(
730 objs_without_pk,
731 fields,
732 batch_size,
733 on_conflict=on_conflict,
734 update_fields=update_fields,
735 unique_fields=unique_fields,
736 )
737 connection = connections[self.db]
738 if (
739 connection.features.can_return_rows_from_bulk_insert
740 and on_conflict is None
741 ):
742 assert len(returned_columns) == len(objs_without_pk)
743 for obj_without_pk, results in zip(objs_without_pk, returned_columns):
744 for result, field in zip(results, opts.db_returning_fields):
745 setattr(obj_without_pk, field.attname, result)
746 obj_without_pk._state.adding = False
747 obj_without_pk._state.db = self.db
749 return objs
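# Usage sketch for bulk_create() upserts, assuming a hypothetical `Book` model
# with a unique `isbn` field and a backend that supports conflict targets.
#   Book.objects.bulk_create(
#       [Book(isbn="9780000000000", price=10)],
#       update_conflicts=True,
#       unique_fields=["isbn"],
#       update_fields=["price"],
#   )
#   # ignore_conflicts=True is the mutually exclusive alternative checked above.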
751 def bulk_update(self, objs, fields, batch_size=None):
752 """
753 Update the given fields in each of the given objects in the database.
754 """
755 if batch_size is not None and batch_size <= 0:
756 raise ValueError("Batch size must be a positive integer.")
757 if not fields:
758 raise ValueError("Field names must be given to bulk_update().")
759 objs = tuple(objs)
760 if any(obj.pk is None for obj in objs):
761 raise ValueError("All bulk_update() objects must have a primary key set.")
762 fields = [self.model._meta.get_field(name) for name in fields]
763 if any(not f.concrete or f.many_to_many for f in fields):
764 raise ValueError("bulk_update() can only be used with concrete fields.")
765 if any(f.primary_key for f in fields):
766 raise ValueError("bulk_update() cannot be used with primary key fields.")
767 if not objs:
768 return 0
769 for obj in objs:
770 obj._prepare_related_fields_for_save(
771 operation_name="bulk_update", fields=fields
772 )
773 # PK is used twice in the resulting update query, once in the filter
774 # and once in the WHEN. Each field will also have one CAST.
775 self._for_write = True
776 connection = connections[self.db]
777 max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
778 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
779 requires_casting = connection.features.requires_casted_case_in_updates
780 batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
781 updates = []
782 for batch_objs in batches:
783 update_kwargs = {}
784 for field in fields:
785 when_statements = []
786 for obj in batch_objs:
787 attr = getattr(obj, field.attname)
788 if not hasattr(attr, "resolve_expression"):
789 attr = Value(attr, output_field=field)
790 when_statements.append(When(pk=obj.pk, then=attr))
791 case_statement = Case(*when_statements, output_field=field)
792 if requires_casting:
793 case_statement = Cast(case_statement, output_field=field)
794 update_kwargs[field.attname] = case_statement
795 updates.append(([obj.pk for obj in batch_objs], update_kwargs))
796 rows_updated = 0
797 queryset = self.using(self.db)
798 with transaction.atomic(using=self.db, savepoint=False):
799 for pks, update_kwargs in updates:
800 rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs)
801 return rows_updated
803 bulk_update.alters_data = True
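# Usage sketch for bulk_update(), assuming hypothetical `Book` instances that
# were modified in memory; each batch becomes a single UPDATE with CASE/WHEN.
#   books = list(Book.objects.filter(on_sale=True))
#   for book in books:
#       book.price -= 1
#   rows = Book.objects.bulk_update(books, ["price"], batch_size=100)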
805 def get_or_create(self, defaults=None, **kwargs):
806 """
807 Look up an object with the given kwargs, creating one if necessary.
808 Return a tuple of (object, created), where created is a boolean
809 specifying whether an object was created.
810 """
811 # The get() needs to be targeted at the write database in order
812 # to avoid potential transaction consistency problems.
813 self._for_write = True
814 try:
815 return self.get(**kwargs), False
816 except self.model.DoesNotExist:
817 params = self._extract_model_params(defaults, **kwargs)
818 # Try to create an object using passed params.
819 try:
820 with transaction.atomic(using=self.db):
821 params = dict(resolve_callables(params))
822 return self.create(**params), True
823 except (IntegrityError, ValidationError):
824 # Since create() also validates by default,
825 # we can get any kind of ValidationError here,
826 # or it can flow through and get an IntegrityError from the database.
827 # The main thing we're concerned about is uniqueness failures,
828 # but ValidationError could include other things too.
829 # In all cases though it should be fine to try the get() again
830 # and return an existing object.
831 try:
832 return self.get(**kwargs), False
833 except self.model.DoesNotExist:
834 pass
835 raise
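# Usage sketch for get_or_create(), assuming a hypothetical `Book` model;
# `defaults` only applies when the row has to be created, and callables are
# resolved via resolve_callables() before create().
#   book, created = Book.objects.get_or_create(
#       isbn="9780000000000",
#       defaults={"title": "Untitled", "added_on": timezone.now},
#   )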
837 def update_or_create(self, defaults=None, create_defaults=None, **kwargs):
838 """
839 Look up an object with the given kwargs, updating one with defaults
840 if it exists, otherwise create a new one. Optionally, an object can
841 be created with different values than defaults by using
842 create_defaults.
843 Return a tuple (object, created), where created is a boolean
844 specifying whether an object was created.
845 """
846 if create_defaults is None:
847 update_defaults = create_defaults = defaults or {}
848 else:
849 update_defaults = defaults or {}
850 self._for_write = True
851 with transaction.atomic(using=self.db):
852 # Lock the row so that a concurrent update is blocked until
853 # update_or_create() has performed its save.
854 obj, created = self.select_for_update().get_or_create(
855 create_defaults, **kwargs
856 )
857 if created:
858 return obj, created
859 for k, v in resolve_callables(update_defaults):
860 setattr(obj, k, v)
862 update_fields = set(update_defaults)
863 concrete_field_names = self.model._meta._non_pk_concrete_field_names
864 # update_fields does not support non-concrete fields.
865 if concrete_field_names.issuperset(update_fields):
866 # Add fields which are set on pre_save(), e.g. auto_now fields.
867 # This is to maintain backward compatibility as these fields
868 # are not updated unless explicitly specified in the
869 # update_fields list.
870 for field in self.model._meta.local_concrete_fields:
871 if not (
872 field.primary_key or field.__class__.pre_save is Field.pre_save
873 ):
874 update_fields.add(field.name)
875 if field.name != field.attname:
876 update_fields.add(field.attname)
877 obj.save(using=self.db, update_fields=update_fields)
878 else:
879 obj.save(using=self.db)
880 return obj, False
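# Usage sketch for update_or_create(), assuming a hypothetical `Book` model;
# `create_defaults` (when given) is used only for the insert path, while
# `defaults` is applied to an existing row under select_for_update().
#   book, created = Book.objects.update_or_create(
#       isbn="9780000000000",
#       defaults={"price": 12},
#       create_defaults={"price": 12, "title": "Untitled"},
#   )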
882 def _extract_model_params(self, defaults, **kwargs):
883 """
884 Prepare `params` for creating a model instance based on the given
885 kwargs; for use by get_or_create().
886 """
887 defaults = defaults or {}
888 params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
889 params.update(defaults)
890 property_names = self.model._meta._property_names
891 invalid_params = []
892 for param in params:
893 try:
894 self.model._meta.get_field(param)
895 except exceptions.FieldDoesNotExist:
896 # It's okay to use a model's property if it has a setter.
897 if not (param in property_names and getattr(self.model, param).fset):
898 invalid_params.append(param)
899 if invalid_params:
900 raise exceptions.FieldError(
901 "Invalid field name(s) for model {}: '{}'.".format(
902 self.model._meta.object_name,
903 "', '".join(sorted(invalid_params)),
904 )
905 )
906 return params
908 def _earliest(self, *fields):
909 """
910 Return the earliest object according to fields (if given) or by the
911 model's Meta.get_latest_by.
912 """
913 if fields:
914 order_by = fields
915 else:
916 order_by = getattr(self.model._meta, "get_latest_by")
917 if order_by and not isinstance(order_by, tuple | list):
918 order_by = (order_by,)
919 if order_by is None:
920 raise ValueError(
921 "earliest() and latest() require either fields as positional "
922 "arguments or 'get_latest_by' in the model's Meta."
923 )
924 obj = self._chain()
925 obj.query.set_limits(high=1)
926 obj.query.clear_ordering(force=True)
927 obj.query.add_ordering(*order_by)
928 return obj.get()
930 def earliest(self, *fields):
931 if self.query.is_sliced:
932 raise TypeError("Cannot change a query once a slice has been taken.")
933 return self._earliest(*fields)
935 def latest(self, *fields):
936 """
937 Return the latest object according to fields (if given) or by the
938 model's Meta.get_latest_by.
939 """
940 if self.query.is_sliced:
941 raise TypeError("Cannot change a query once a slice has been taken.")
942 return self.reverse()._earliest(*fields)
944 def first(self):
945 """Return the first object of a query or None if no match is found."""
946 if self.ordered:
947 queryset = self
948 else:
949 self._check_ordering_first_last_queryset_aggregation(method="first")
950 queryset = self.order_by("pk")
951 for obj in queryset[:1]:
952 return obj
954 def last(self):
955 """Return the last object of a query or None if no match is found."""
956 if self.ordered:
957 queryset = self.reverse()
958 else:
959 self._check_ordering_first_last_queryset_aggregation(method="last")
960 queryset = self.order_by("-pk")
961 for obj in queryset[:1]:
962 return obj
964 def in_bulk(self, id_list=None, *, field_name="pk"):
965 """
966 Return a dictionary mapping each of the given IDs to the object with
967 that ID. If `id_list` isn't provided, evaluate the entire QuerySet.
968 """
969 if self.query.is_sliced:
970 raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().")
971 opts = self.model._meta
972 unique_fields = [
973 constraint.fields[0]
974 for constraint in opts.total_unique_constraints
975 if len(constraint.fields) == 1
976 ]
977 if (
978 field_name != "pk"
979 and not opts.get_field(field_name).unique
980 and field_name not in unique_fields
981 and self.query.distinct_fields != (field_name,)
982 ):
983 raise ValueError(
984 f"in_bulk()'s field_name must be a unique field but {field_name!r} isn't."
985 )
986 if id_list is not None:
987 if not id_list:
988 return {}
989 filter_key = f"{field_name}__in"
990 batch_size = connections[self.db].features.max_query_params
991 id_list = tuple(id_list)
992 # If the database has a limit on the number of query parameters
993 # (e.g. SQLite), retrieve objects in batches if necessary.
994 if batch_size and batch_size < len(id_list):
995 qs = ()
996 for offset in range(0, len(id_list), batch_size):
997 batch = id_list[offset : offset + batch_size]
998 qs += tuple(self.filter(**{filter_key: batch}))
999 else:
1000 qs = self.filter(**{filter_key: id_list})
1001 else:
1002 qs = self._chain()
1003 return {getattr(obj, field_name): obj for obj in qs}
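# Usage sketch for in_bulk(), assuming a hypothetical `Book` model with a
# unique `isbn` field; the result maps the requested field to instances.
#   Book.objects.in_bulk([1, 2, 3])                              # {1: <Book>, 2: <Book>, ...}
#   Book.objects.in_bulk(["9780000000000"], field_name="isbn")   # keyed by isbn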
1005 def delete(self):
1006 """Delete the records in the current QuerySet."""
1007 self._not_support_combined_queries("delete")
1008 if self.query.is_sliced:
1009 raise TypeError("Cannot use 'limit' or 'offset' with delete().")
1010 if self.query.distinct or self.query.distinct_fields:
1011 raise TypeError("Cannot call delete() after .distinct().")
1012 if self._fields is not None:
1013 raise TypeError("Cannot call delete() after .values() or .values_list()")
1015 del_query = self._chain()
1017 # The delete is actually 2 queries - one to find related objects,
1018 # and one to delete. Make sure that the discovery of related
1019 # objects is performed on the same database as the deletion.
1020 del_query._for_write = True
1022 # Disable non-supported fields.
1023 del_query.query.select_for_update = False
1024 del_query.query.select_related = False
1025 del_query.query.clear_ordering(force=True)
1027 from plain.models.deletion import Collector
1029 collector = Collector(using=del_query.db, origin=self)
1030 collector.collect(del_query)
1031 deleted, _rows_count = collector.delete()
1033 # Clear the result cache, in case this QuerySet gets reused.
1034 self._result_cache = None
1035 return deleted, _rows_count
1037 delete.alters_data = True
1038 delete.queryset_only = True
1040 def _raw_delete(self, using):
1041 """
1042 Delete objects found from the given queryset in single direct SQL
1043 query. No signals are sent and there is no protection for cascades.
1044 """
1045 query = self.query.clone()
1046 query.__class__ = sql.DeleteQuery
1047 cursor = query.get_compiler(using).execute_sql(CURSOR)
1048 if cursor:
1049 with cursor:
1050 return cursor.rowcount
1051 return 0
1053 _raw_delete.alters_data = True
1055 def update(self, **kwargs):
1056 """
1057 Update all elements in the current QuerySet, setting all the given
1058 fields to the appropriate values.
1059 """
1060 self._not_support_combined_queries("update")
1061 if self.query.is_sliced:
1062 raise TypeError("Cannot update a query once a slice has been taken.")
1063 self._for_write = True
1064 query = self.query.chain(sql.UpdateQuery)
1065 query.add_update_values(kwargs)
1067 # Inline annotations in order_by(), if possible.
1068 new_order_by = []
1069 for col in query.order_by:
1070 alias = col
1071 descending = False
1072 if isinstance(alias, str) and alias.startswith("-"):
1073 alias = alias.removeprefix("-")
1074 descending = True
1075 if annotation := query.annotations.get(alias):
1076 if getattr(annotation, "contains_aggregate", False):
1077 raise exceptions.FieldError(
1078 f"Cannot update when ordering by an aggregate: {annotation}"
1079 )
1080 if descending:
1081 annotation = annotation.desc()
1082 new_order_by.append(annotation)
1083 else:
1084 new_order_by.append(col)
1085 query.order_by = tuple(new_order_by)
1087 # Clear any annotations so that they won't be present in subqueries.
1088 query.annotations = {}
1089 with transaction.mark_for_rollback_on_error(using=self.db):
1090 rows = query.get_compiler(self.db).execute_sql(CURSOR)
1091 self._result_cache = None
1092 return rows
1094 update.alters_data = True
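# Usage sketch for update(), assuming a hypothetical `Book` model; F() (imported
# above) keeps the arithmetic in SQL, and the call returns the matched row count.
#   rows = Book.objects.filter(on_sale=True).update(price=F("price") - 1)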
1096 def _update(self, values):
1097 """
1098 A version of update() that accepts field objects instead of field names.
1099 Used primarily for model saving and not intended for use by general
1100 code (it requires too much poking around at model internals to be
1101 useful at that level).
1102 """
1103 if self.query.is_sliced:
1104 raise TypeError("Cannot update a query once a slice has been taken.")
1105 query = self.query.chain(sql.UpdateQuery)
1106 query.add_update_fields(values)
1107 # Clear any annotations so that they won't be present in subqueries.
1108 query.annotations = {}
1109 self._result_cache = None
1110 return query.get_compiler(self.db).execute_sql(CURSOR)
1112 _update.alters_data = True
1113 _update.queryset_only = False
1115 def exists(self):
1116 """
1117 Return True if the QuerySet would have any results, False otherwise.
1118 """
1119 if self._result_cache is None:
1120 return self.query.has_results(using=self.db)
1121 return bool(self._result_cache)
1123 def contains(self, obj):
1124 """
1125 Return True if the QuerySet contains the provided obj,
1126 False otherwise.
1127 """
1128 self._not_support_combined_queries("contains")
1129 if self._fields is not None:
1130 raise TypeError(
1131 "Cannot call QuerySet.contains() after .values() or .values_list()."
1132 )
1133 try:
1134 if obj._meta.concrete_model != self.model._meta.concrete_model:
1135 return False
1136 except AttributeError:
1137 raise TypeError("'obj' must be a model instance.")
1138 if obj.pk is None:
1139 raise ValueError("QuerySet.contains() cannot be used on unsaved objects.")
1140 if self._result_cache is not None:
1141 return obj in self._result_cache
1142 return self.filter(pk=obj.pk).exists()
1144 def _prefetch_related_objects(self):
1145 # This method can only be called once the result cache has been filled.
1146 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1147 self._prefetch_done = True
1149 def explain(self, *, format=None, **options):
1150 """
1151 Runs an EXPLAIN on the SQL query this QuerySet would perform, and
1152 returns the results.
1153 """
1154 return self.query.explain(using=self.db, format=format, **options)
1156 ##################################################
1157 # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
1158 ##################################################
1160 def raw(self, raw_query, params=(), translations=None, using=None):
1161 if using is None:
1162 using = self.db
1163 qs = RawQuerySet(
1164 raw_query,
1165 model=self.model,
1166 params=params,
1167 translations=translations,
1168 using=using,
1169 )
1170 qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
1171 return qs
1173 def _values(self, *fields, **expressions):
1174 clone = self._chain()
1175 if expressions:
1176 clone = clone.annotate(**expressions)
1177 clone._fields = fields
1178 clone.query.set_values(fields)
1179 return clone
1181 def values(self, *fields, **expressions):
1182 fields += tuple(expressions)
1183 clone = self._values(*fields, **expressions)
1184 clone._iterable_class = ValuesIterable
1185 return clone
1187 def values_list(self, *fields, flat=False, named=False):
1188 if flat and named:
1189 raise TypeError("'flat' and 'named' can't be used together.")
1190 if flat and len(fields) > 1:
1191 raise TypeError(
1192 "'flat' is not valid when values_list is called with more than one "
1193 "field."
1194 )
1196 field_names = {f for f in fields if not hasattr(f, "resolve_expression")}
1197 _fields = []
1198 expressions = {}
1199 counter = 1
1200 for field in fields:
1201 if hasattr(field, "resolve_expression"):
1202 field_id_prefix = getattr(
1203 field, "default_alias", field.__class__.__name__.lower()
1204 )
1205 while True:
1206 field_id = field_id_prefix + str(counter)
1207 counter += 1
1208 if field_id not in field_names:
1209 break
1210 expressions[field_id] = field
1211 _fields.append(field_id)
1212 else:
1213 _fields.append(field)
1215 clone = self._values(*_fields, **expressions)
1216 clone._iterable_class = (
1217 NamedValuesListIterable
1218 if named
1219 else FlatValuesListIterable
1220 if flat
1221 else ValuesListIterable
1222 )
1223 return clone
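# Usage sketch contrasting values() and values_list(), assuming a hypothetical
# `Book` model; the flags map to the iterable classes selected above.
#   Book.objects.values("id", "title")                   # dicts (ValuesIterable)
#   Book.objects.values_list("id", "title")              # tuples (ValuesListIterable)
#   Book.objects.values_list("id", flat=True)            # scalars (FlatValuesListIterable)
#   Book.objects.values_list("id", "title", named=True)  # namedtuples (NamedValuesListIterable)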
1225 def dates(self, field_name, kind, order="ASC"):
1226 """
1227 Return a list of date objects representing all available dates for
1228 the given field_name, scoped to 'kind'.
1229 """
1230 if kind not in ("year", "month", "week", "day"):
1231 raise ValueError("'kind' must be one of 'year', 'month', 'week', or 'day'.")
1232 if order not in ("ASC", "DESC"):
1233 raise ValueError("'order' must be either 'ASC' or 'DESC'.")
1234 return (
1235 self.annotate(
1236 datefield=Trunc(field_name, kind, output_field=DateField()),
1237 plain_field=F(field_name),
1238 )
1239 .values_list("datefield", flat=True)
1240 .distinct()
1241 .filter(plain_field__isnull=False)
1242 .order_by(("-" if order == "DESC" else "") + "datefield")
1243 )
1245 def datetimes(self, field_name, kind, order="ASC", tzinfo=None):
1246 """
1247 Return a list of datetime objects representing all available
1248 datetimes for the given field_name, scoped to 'kind'.
1249 """
1250 if kind not in ("year", "month", "week", "day", "hour", "minute", "second"):
1251 raise ValueError(
1252 "'kind' must be one of 'year', 'month', 'week', 'day', "
1253 "'hour', 'minute', or 'second'."
1254 )
1255 if order not in ("ASC", "DESC"):
1256 raise ValueError("'order' must be either 'ASC' or 'DESC'.")
1258 if tzinfo is None:
1259 tzinfo = timezone.get_current_timezone()
1261 return (
1262 self.annotate(
1263 datetimefield=Trunc(
1264 field_name,
1265 kind,
1266 output_field=DateTimeField(),
1267 tzinfo=tzinfo,
1268 ),
1269 plain_field=F(field_name),
1270 )
1271 .values_list("datetimefield", flat=True)
1272 .distinct()
1273 .filter(plain_field__isnull=False)
1274 .order_by(("-" if order == "DESC" else "") + "datetimefield")
1275 )
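# Usage sketch for dates()/datetimes(), assuming a hypothetical `Book` model
# with a `published_at` DateTimeField; both return truncated, distinct, ordered
# values via the Trunc annotation built above.
#   Book.objects.dates("published_at", "month")                  # date objects
#   Book.objects.datetimes("published_at", "day", order="DESC")  # tz-aware datetimes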
1277 def none(self):
1278 """Return an empty QuerySet."""
1279 clone = self._chain()
1280 clone.query.set_empty()
1281 return clone
1283 ##################################################################
1284 # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
1285 ##################################################################
1287 def all(self):
1288 """
1289 Return a new QuerySet that is a copy of the current one. This allows a
1290 QuerySet to proxy for a model manager in some cases.
1291 """
1292 return self._chain()
1294 def filter(self, *args, **kwargs):
1295 """
1296 Return a new QuerySet instance with the args ANDed to the existing
1297 set.
1298 """
1299 self._not_support_combined_queries("filter")
1300 return self._filter_or_exclude(False, args, kwargs)
1302 def exclude(self, *args, **kwargs):
1303 """
1304 Return a new QuerySet instance with NOT (args) ANDed to the existing
1305 set.
1306 """
1307 self._not_support_combined_queries("exclude")
1308 return self._filter_or_exclude(True, args, kwargs)
1310 def _filter_or_exclude(self, negate, args, kwargs):
1311 if (args or kwargs) and self.query.is_sliced:
1312 raise TypeError("Cannot filter a query once a slice has been taken.")
1313 clone = self._chain()
1314 if self._defer_next_filter:
1315 self._defer_next_filter = False
1316 clone._deferred_filter = negate, args, kwargs
1317 else:
1318 clone._filter_or_exclude_inplace(negate, args, kwargs)
1319 return clone
1321 def _filter_or_exclude_inplace(self, negate, args, kwargs):
1322 if negate:
1323 self._query.add_q(~Q(*args, **kwargs))
1324 else:
1325 self._query.add_q(Q(*args, **kwargs))
1327 def complex_filter(self, filter_obj):
1328 """
1329 Return a new QuerySet instance with filter_obj added to the filters.
1331 filter_obj can be a Q object or a dictionary of keyword lookup
1332 arguments.
1334 This exists to support framework features such as 'limit_choices_to',
1335 and usually it will be more natural to use other methods.
1336 """
1337 if isinstance(filter_obj, Q):
1338 clone = self._chain()
1339 clone.query.add_q(filter_obj)
1340 return clone
1341 else:
1342 return self._filter_or_exclude(False, args=(), kwargs=filter_obj)
1344 def _combinator_query(self, combinator, *other_qs, all=False):
1345 # Clone the query to inherit the select list and everything
1346 clone = self._chain()
1347 # Clear limits and ordering so they can be reapplied
1348 clone.query.clear_ordering(force=True)
1349 clone.query.clear_limits()
1350 clone.query.combined_queries = (self.query,) + tuple(
1351 qs.query for qs in other_qs
1352 )
1353 clone.query.combinator = combinator
1354 clone.query.combinator_all = all
1355 return clone
1357 def union(self, *other_qs, all=False):
1358 # If the query is an EmptyQuerySet, combine all nonempty querysets.
1359 if isinstance(self, EmptyQuerySet):
1360 qs = [q for q in other_qs if not isinstance(q, EmptyQuerySet)]
1361 if not qs:
1362 return self
1363 if len(qs) == 1:
1364 return qs[0]
1365 return qs[0]._combinator_query("union", *qs[1:], all=all)
1366 return self._combinator_query("union", *other_qs, all=all)
1368 def intersection(self, *other_qs):
1369 # If any query is an EmptyQuerySet, return it.
1370 if isinstance(self, EmptyQuerySet):
1371 return self
1372 for other in other_qs:
1373 if isinstance(other, EmptyQuerySet):
1374 return other
1375 return self._combinator_query("intersection", *other_qs)
1377 def difference(self, *other_qs):
1378 # If the query is an EmptyQuerySet, return it.
1379 if isinstance(self, EmptyQuerySet):
1380 return self
1381 return self._combinator_query("difference", *other_qs)
1383 def select_for_update(self, nowait=False, skip_locked=False, of=(), no_key=False):
1384 """
1385 Return a new QuerySet instance that will select objects with a
1386 FOR UPDATE lock.
1387 """
1388 if nowait and skip_locked:
1389 raise ValueError("The nowait option cannot be used with skip_locked.")
1390 obj = self._chain()
1391 obj._for_write = True
1392 obj.query.select_for_update = True
1393 obj.query.select_for_update_nowait = nowait
1394 obj.query.select_for_update_skip_locked = skip_locked
1395 obj.query.select_for_update_of = of
1396 obj.query.select_for_no_key_update = no_key
1397 return obj
1399 def select_related(self, *fields):
1400 """
1401 Return a new QuerySet instance that will select related objects.
1403 If fields are specified, they must be ForeignKey fields and only those
1404 related objects are included in the selection.
1406 If select_related(None) is called, clear the list.
1407 """
1408 self._not_support_combined_queries("select_related")
1409 if self._fields is not None:
1410 raise TypeError(
1411 "Cannot call select_related() after .values() or .values_list()"
1412 )
1414 obj = self._chain()
1415 if fields == (None,):
1416 obj.query.select_related = False
1417 elif fields:
1418 obj.query.add_select_related(fields)
1419 else:
1420 obj.query.select_related = True
1421 return obj
1423 def prefetch_related(self, *lookups):
1424 """
1425 Return a new QuerySet instance that will prefetch the specified
1426 Many-To-One and Many-To-Many related objects when the QuerySet is
1427 evaluated.
1429 When prefetch_related() is called more than once, append to the list of
1430 prefetch lookups. If prefetch_related(None) is called, clear the list.
1431 """
1432 self._not_support_combined_queries("prefetch_related")
1433 clone = self._chain()
1434 if lookups == (None,):
1435 clone._prefetch_related_lookups = ()
1436 else:
1437 for lookup in lookups:
1438 if isinstance(lookup, Prefetch):
1439 lookup = lookup.prefetch_to
1440 lookup = lookup.split(LOOKUP_SEP, 1)[0]
1441 if lookup in self.query._filtered_relations:
1442 raise ValueError(
1443 "prefetch_related() is not supported with FilteredRelation."
1444 )
1445 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1446 return clone
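# Usage sketch contrasting select_related() and prefetch_related(), assuming a
# hypothetical `Book` model with a foreign key `publisher` and a many-to-many
# `authors`; the former joins in SQL, the latter issues an extra query per lookup.
#   Book.objects.select_related("publisher")
#   Book.objects.prefetch_related("authors")
#   Book.objects.prefetch_related(None)  # clears previously added lookups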
1448 def annotate(self, *args, **kwargs):
1449 """
1450 Return a query set in which the returned objects have been annotated
1451 with extra data or aggregations.
1452 """
1453 self._not_support_combined_queries("annotate")
1454 return self._annotate(args, kwargs, select=True)
1456 def alias(self, *args, **kwargs):
1457 """
1458 Return a query set with added aliases for extra data or aggregations.
1459 """
1460 self._not_support_combined_queries("alias")
1461 return self._annotate(args, kwargs, select=False)
1463 def _annotate(self, args, kwargs, select=True):
1464 self._validate_values_are_expressions(
1465 args + tuple(kwargs.values()), method_name="annotate"
1466 )
1467 annotations = {}
1468 for arg in args:
1469 # The default_alias property may raise a TypeError.
1470 try:
1471 if arg.default_alias in kwargs:
1472 raise ValueError(
1473 f"The named annotation '{arg.default_alias}' conflicts with the "
1474 "default name for another annotation."
1475 )
1476 except TypeError:
1477 raise TypeError("Complex annotations require an alias")
1478 annotations[arg.default_alias] = arg
1479 annotations.update(kwargs)
1481 clone = self._chain()
1482 names = self._fields
1483 if names is None:
1484 names = set(
1485 chain.from_iterable(
1486 (field.name, field.attname)
1487 if hasattr(field, "attname")
1488 else (field.name,)
1489 for field in self.model._meta.get_fields()
1490 )
1491 )
1493 for alias, annotation in annotations.items():
1494 if alias in names:
1495 raise ValueError(
1496 f"The annotation '{alias}' conflicts with a field on " "the model."
1497 )
1498 if isinstance(annotation, FilteredRelation):
1499 clone.query.add_filtered_relation(annotation, alias)
1500 else:
1501 clone.query.add_annotation(
1502 annotation,
1503 alias,
1504 select=select,
1505 )
1506 for alias, annotation in clone.query.annotations.items():
1507 if alias in annotations and annotation.contains_aggregate:
1508 if clone._fields is None:
1509 clone.query.group_by = True
1510 else:
1511 clone.query.set_group_by()
1512 break
1514 return clone
1516 def order_by(self, *field_names):
1517 """Return a new QuerySet instance with the ordering changed."""
1518 if self.query.is_sliced:
1519 raise TypeError("Cannot reorder a query once a slice has been taken.")
1520 obj = self._chain()
1521 obj.query.clear_ordering(force=True, clear_default=False)
1522 obj.query.add_ordering(*field_names)
1523 return obj
1525 def distinct(self, *field_names):
1526 """
1527 Return a new QuerySet instance that will select only distinct results.
1528 """
1529 self._not_support_combined_queries("distinct")
1530 if self.query.is_sliced:
1531 raise TypeError(
1532 "Cannot create distinct fields once a slice has been taken."
1533 )
1534 obj = self._chain()
1535 obj.query.add_distinct_fields(*field_names)
1536 return obj
1538 def extra(
1539 self,
1540 select=None,
1541 where=None,
1542 params=None,
1543 tables=None,
1544 order_by=None,
1545 select_params=None,
1546 ):
1547 """Add extra SQL fragments to the query."""
1548 self._not_support_combined_queries("extra")
1549 if self.query.is_sliced:
1550 raise TypeError("Cannot change a query once a slice has been taken.")
1551 clone = self._chain()
1552 clone.query.add_extra(select, select_params, where, params, tables, order_by)
1553 return clone
1555 def reverse(self):
1556 """Reverse the ordering of the QuerySet."""
1557 if self.query.is_sliced:
1558 raise TypeError("Cannot reverse a query once a slice has been taken.")
1559 clone = self._chain()
1560 clone.query.standard_ordering = not clone.query.standard_ordering
1561 return clone
1563 def defer(self, *fields):
1564 """
1565 Defer the loading of data for certain fields until they are accessed.
1566 Add the set of deferred fields to any existing set of deferred fields.
1567 The only exception to this is if None is passed in as the only
1568 parameter, in which case all deferrals are removed.
1569 """
1570 self._not_support_combined_queries("defer")
1571 if self._fields is not None:
1572 raise TypeError("Cannot call defer() after .values() or .values_list()")
1573 clone = self._chain()
1574 if fields == (None,):
1575 clone.query.clear_deferred_loading()
1576 else:
1577 clone.query.add_deferred_loading(fields)
1578 return clone
1580 def only(self, *fields):
1581 """
1582 Essentially, the opposite of defer(). Only the fields passed into this
1583 method and that are not already specified as deferred are loaded
1584 immediately when the queryset is evaluated.
1585 """
1586 self._not_support_combined_queries("only")
1587 if self._fields is not None:
1588 raise TypeError("Cannot call only() after .values() or .values_list()")
1589 if fields == (None,):
1590 # Can only pass None to defer(), not only(), as the rest option.
1591 # That won't stop people trying to do this, so let's be explicit.
1592 raise TypeError("Cannot pass None as an argument to only().")
1593 for field in fields:
1594 field = field.split(LOOKUP_SEP, 1)[0]
1595 if field in self.query._filtered_relations:
1596 raise ValueError("only() is not supported with FilteredRelation.")
1597 clone = self._chain()
1598 clone.query.add_immediate_loading(fields)
1599 return clone
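# Usage sketch for defer()/only(), assuming a hypothetical `Book` model with a
# large `body` text field; deferred columns are loaded on first attribute access.
#   Book.objects.defer("body")          # everything except body now
#   Book.objects.only("id", "title")    # just these columns now
#   Book.objects.defer(None)            # clear all deferrals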
1601 def using(self, alias):
1602 """Select which database this QuerySet should execute against."""
1603 clone = self._chain()
1604 clone._db = alias
1605 return clone
1607 ###################################
1608 # PUBLIC INTROSPECTION ATTRIBUTES #
1609 ###################################
1611 @property
1612 def ordered(self):
1613 """
1614 Return True if the QuerySet is ordered -- i.e. has an order_by()
1615 clause or a default ordering on the model (or is empty).
1616 """
1617 if isinstance(self, EmptyQuerySet):
1618 return True
1619 if self.query.extra_order_by or self.query.order_by:
1620 return True
1621 elif (
1622 self.query.default_ordering
1623 and self.query.get_meta().ordering
1624 and
1625 # A default ordering doesn't affect GROUP BY queries.
1626 not self.query.group_by
1627 ):
1628 return True
1629 else:
1630 return False
1632 @property
1633 def db(self):
1634 """Return the database used if this query is executed now."""
1635 if self._for_write:
1636 return self._db or router.db_for_write(self.model, **self._hints)
1637 return self._db or router.db_for_read(self.model, **self._hints)
1639 ###################
1640 # PRIVATE METHODS #
1641 ###################
1643 def _insert(
1644 self,
1645 objs,
1646 fields,
1647 returning_fields=None,
1648 raw=False,
1649 using=None,
1650 on_conflict=None,
1651 update_fields=None,
1652 unique_fields=None,
1653 ):
1654 """
1655 Insert a new record for the given model. This provides an interface to
1656 the InsertQuery class and is how Model.save() is implemented.
1657 """
1658 self._for_write = True
1659 if using is None:
1660 using = self.db
1661 query = sql.InsertQuery(
1662 self.model,
1663 on_conflict=on_conflict,
1664 update_fields=update_fields,
1665 unique_fields=unique_fields,
1666 )
1667 query.insert_values(fields, objs, raw=raw)
1668 return query.get_compiler(using=using).execute_sql(returning_fields)
1670 _insert.alters_data = True
1671 _insert.queryset_only = False
1673 def _batched_insert(
1674 self,
1675 objs,
1676 fields,
1677 batch_size,
1678 on_conflict=None,
1679 update_fields=None,
1680 unique_fields=None,
1681 ):
1682 """
1683 Helper method for bulk_create() to insert objs one batch at a time.
1684 """
1685 connection = connections[self.db]
1686 ops = connection.ops
1687 max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
1688 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1689 inserted_rows = []
1690 bulk_return = connection.features.can_return_rows_from_bulk_insert
1691 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1692 if bulk_return and on_conflict is None:
1693 inserted_rows.extend(
1694 self._insert(
1695 item,
1696 fields=fields,
1697 using=self.db,
1698 returning_fields=self.model._meta.db_returning_fields,
1699 )
1700 )
1701 else:
1702 self._insert(
1703 item,
1704 fields=fields,
1705 using=self.db,
1706 on_conflict=on_conflict,
1707 update_fields=update_fields,
1708 unique_fields=unique_fields,
1709 )
1710 return inserted_rows
1712 def _chain(self):
1713 """
1714 Return a copy of the current QuerySet that's ready for another
1715 operation.
1716 """
1717 obj = self._clone()
1718 if obj._sticky_filter:
1719 obj.query.filter_is_sticky = True
1720 obj._sticky_filter = False
1721 return obj
1723 def _clone(self):
1724 """
1725 Return a copy of the current QuerySet. A lightweight alternative
1726 to deepcopy().
1727 """
1728 c = self.__class__(
1729 model=self.model,
1730 query=self.query.chain(),
1731 using=self._db,
1732 hints=self._hints,
1733 )
1734 c._sticky_filter = self._sticky_filter
1735 c._for_write = self._for_write
1736 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1737 c._known_related_objects = self._known_related_objects
1738 c._iterable_class = self._iterable_class
1739 c._fields = self._fields
1740 return c
1742 def _fetch_all(self):
1743 if self._result_cache is None:
1744 self._result_cache = list(self._iterable_class(self))
1745 if self._prefetch_related_lookups and not self._prefetch_done:
1746 self._prefetch_related_objects()
1748 def _next_is_sticky(self):
1749 """
1750 Indicate that the next filter call and the one following that should
1751 be treated as a single filter. This is only important when it comes to
1752 determining when to reuse tables for many-to-many filters. Required so
1753 that we can filter naturally on the results of related managers.
1755 This doesn't return a clone of the current QuerySet (it returns
1756 "self"). The method is only used internally and should be immediately
1757 followed by a filter() that does create a clone.
1758 """
1759 self._sticky_filter = True
1760 return self
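# Illustrative sketch (added, not original): related managers (e.g. for
# many-to-many relations) use this so that the implicit relation filter and the
# user's next filter() share table aliases, roughly:
#
#     qs = queryset._next_is_sticky().filter(**core_filters)
#     # a subsequent .filter() on ``qs`` reuses the same join instead of adding one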
1762 def _merge_sanity_check(self, other):
1763 """Check that two QuerySet classes may be merged."""
1764 if self._fields is not None and (
1765 set(self.query.values_select) != set(other.query.values_select)
1766 or set(self.query.extra_select) != set(other.query.extra_select)
1767 or set(self.query.annotation_select) != set(other.query.annotation_select)
1768 ):
1769 raise TypeError(
1770 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1771 )
1773 def _merge_known_related_objects(self, other):
1774 """
1775 Keep track of all known related objects from either QuerySet instance.
1776 """
1777 for field, objects in other._known_related_objects.items():
1778 self._known_related_objects.setdefault(field, {}).update(objects)
1780 def resolve_expression(self, *args, **kwargs):
1781 if self._fields and len(self._fields) > 1:
1782 # values() queryset can only be used as nested queries
1783 # if they are set up to select only a single field.
1784 raise TypeError("Cannot use multi-field values as a filter value.")
1785 query = self.query.resolve_expression(*args, **kwargs)
1786 query._db = self._db
1787 return query
1789 resolve_expression.queryset_only = True
1791 def _add_hints(self, **hints):
1792 """
1793 Update hinting information for use by routers. Add new key/values or
1794 overwrite existing key/values.
1795 """
1796 self._hints.update(hints)
1798 def _has_filters(self):
1799 """
1800 Check if this QuerySet has any filtering going on. This isn't
1801 equivalent to checking whether all objects are present in the results;
1802 for example, qs[1:]._has_filters() -> False.
1803 """
1804 return self.query.has_filters()
1806 @staticmethod
1807 def _validate_values_are_expressions(values, method_name):
1808 invalid_args = sorted(
1809 str(arg) for arg in values if not hasattr(arg, "resolve_expression")
1810 )
1811 if invalid_args:
1812 raise TypeError(
1813 "QuerySet.{}() received non-expression(s): {}.".format(
1814 method_name,
1815 ", ".join(invalid_args),
1816 )
1817 )
1819 def _not_support_combined_queries(self, operation_name):
1820 if self.query.combinator:
1821 raise NotSupportedError(
1822 f"Calling QuerySet.{operation_name}() after {self.query.combinator}() is not supported."
1823 )
1825 def _check_operator_queryset(self, other, operator_):
1826 if self.query.combinator or other.query.combinator:
1827 raise TypeError(f"Cannot use {operator_} operator with combined queryset.")
1829 def _check_ordering_first_last_queryset_aggregation(self, method):
1830 if isinstance(self.query.group_by, tuple) and not any(
1831 col.output_field is self.model._meta.pk for col in self.query.group_by
1832 ):
1833 raise TypeError(
1834 f"Cannot use QuerySet.{method}() on an unordered queryset performing "
1835 f"aggregation. Add an ordering with order_by()."
1836 )
1839class InstanceCheckMeta(type):
1840 def __instancecheck__(self, instance):
1841 return isinstance(instance, QuerySet) and instance.query.is_empty()
1844class EmptyQuerySet(metaclass=InstanceCheckMeta):
1845 """
1846 Marker class for checking whether a queryset is empty via .none():
1847 isinstance(qs.none(), EmptyQuerySet) -> True
1848 """
1850 def __init__(self, *args, **kwargs):
1851 raise TypeError("EmptyQuerySet can't be instantiated")
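# Illustrative usage (added): the metaclass makes the isinstance() check work
# on the result of .none() even though EmptyQuerySet itself is never
# instantiated. With a hypothetical ``Entry`` model:
#
#     qs = Entry.objects.none()
#     isinstance(qs, EmptyQuerySet)   # -> True
#     EmptyQuerySet()                 # raises TypeError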
1854class RawQuerySet:
1855 """
1856 Provide an iterator which converts the results of raw SQL queries into
1857 annotated model instances.
1858 """
1860 def __init__(
1861 self,
1862 raw_query,
1863 model=None,
1864 query=None,
1865 params=(),
1866 translations=None,
1867 using=None,
1868 hints=None,
1869 ):
1870 self.raw_query = raw_query
1871 self.model = model
1872 self._db = using
1873 self._hints = hints or {}
1874 self.query = query or sql.RawQuery(sql=raw_query, using=self.db, params=params)
1875 self.params = params
1876 self.translations = translations or {}
1877 self._result_cache = None
1878 self._prefetch_related_lookups = ()
1879 self._prefetch_done = False
1881 def resolve_model_init_order(self):
1882 """Resolve the init field names and value positions."""
1883 converter = connections[self.db].introspection.identifier_converter
1884 model_init_fields = [
1885 f for f in self.model._meta.fields if converter(f.column) in self.columns
1886 ]
1887 annotation_fields = [
1888 (column, pos)
1889 for pos, column in enumerate(self.columns)
1890 if column not in self.model_fields
1891 ]
1892 model_init_order = [
1893 self.columns.index(converter(f.column)) for f in model_init_fields
1894 ]
1895 model_init_names = [f.attname for f in model_init_fields]
1896 return model_init_names, model_init_order, annotation_fields
1898 def prefetch_related(self, *lookups):
1899 """Same as QuerySet.prefetch_related()"""
1900 clone = self._clone()
1901 if lookups == (None,):
1902 clone._prefetch_related_lookups = ()
1903 else:
1904 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1905 return clone
1907 def _prefetch_related_objects(self):
1908 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1909 self._prefetch_done = True
1911 def _clone(self):
1912 """Same as QuerySet._clone()"""
1913 c = self.__class__(
1914 self.raw_query,
1915 model=self.model,
1916 query=self.query,
1917 params=self.params,
1918 translations=self.translations,
1919 using=self._db,
1920 hints=self._hints,
1921 )
1922 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1923 return c
1925 def _fetch_all(self):
1926 if self._result_cache is None:
1927 self._result_cache = list(self.iterator())
1928 if self._prefetch_related_lookups and not self._prefetch_done:
1929 self._prefetch_related_objects()
1931 def __len__(self):
1932 self._fetch_all()
1933 return len(self._result_cache)
1935 def __bool__(self):
1936 self._fetch_all()
1937 return bool(self._result_cache)
1939 def __iter__(self):
1940 self._fetch_all()
1941 return iter(self._result_cache)
1943 def iterator(self):
1944 yield from RawModelIterable(self)
1946 def __repr__(self):
1947 return f"<{self.__class__.__name__}: {self.query}>"
1949 def __getitem__(self, k):
1950 return list(self)[k]
1952 @property
1953 def db(self):
1954 """Return the database used if this query is executed now."""
1955 return self._db or router.db_for_read(self.model, **self._hints)
1957 def using(self, alias):
1958 """Select the database this RawQuerySet should execute against."""
1959 return RawQuerySet(
1960 self.raw_query,
1961 model=self.model,
1962 query=self.query.chain(using=alias),
1963 params=self.params,
1964 translations=self.translations,
1965 using=alias,
1966 )
1968 @cached_property
1969 def columns(self):
1970 """
1971 A list of model field names in the order they'll appear in the
1972 query results.
1973 """
1974 columns = self.query.get_columns()
1975 # Adjust any column names which don't match field names
1976 for query_name, model_name in self.translations.items():
1977 # Ignore translations for nonexistent column names
1978 try:
1979 index = columns.index(query_name)
1980 except ValueError:
1981 pass
1982 else:
1983 columns[index] = model_name
1984 return columns
1986 @cached_property
1987 def model_fields(self):
1988 """A dict mapping column names to model field names."""
1989 converter = connections[self.db].introspection.identifier_converter
1990 model_fields = {}
1991 for field in self.model._meta.fields:
1992 name, column = field.get_attname_column()
1993 model_fields[converter(column)] = field
1994 return model_fields
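# Illustrative sketch (added, not part of the original module): a RawQuerySet
# maps result columns onto model fields, and ``translations`` renames query
# columns that don't match field names. Assuming a hypothetical ``Entry`` model
# backed by an ``entries`` table:
#
#     qs = Entry.objects.raw(
#         "SELECT id, title AS headline FROM entries",
#         translations={"headline": "title"},
#     )
#     for entry in qs:
#         print(entry.title)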
1997class Prefetch:
1998 def __init__(self, lookup, queryset=None, to_attr=None):
1999 # `prefetch_through` is the path we traverse to perform the prefetch.
2000 self.prefetch_through = lookup
2001 # `prefetch_to` is the path to the attribute that stores the result.
2002 self.prefetch_to = lookup
2003 if queryset is not None and (
2004 isinstance(queryset, RawQuerySet)
2005 or (
2006 hasattr(queryset, "_iterable_class")
2007 and not issubclass(queryset._iterable_class, ModelIterable)
2008 )
2009 ):
2010 raise ValueError(
2011 "Prefetch querysets cannot use raw(), values(), and values_list()."
2012 )
2013 if to_attr:
2014 self.prefetch_to = LOOKUP_SEP.join(
2015 lookup.split(LOOKUP_SEP)[:-1] + [to_attr]
2016 )
2018 self.queryset = queryset
2019 self.to_attr = to_attr
2021 def __getstate__(self):
2022 obj_dict = self.__dict__.copy()
2023 if self.queryset is not None:
2024 queryset = self.queryset._chain()
2025 # Prevent the QuerySet from being evaluated
2026 queryset._result_cache = []
2027 queryset._prefetch_done = True
2028 obj_dict["queryset"] = queryset
2029 return obj_dict
2031 def add_prefix(self, prefix):
2032 self.prefetch_through = prefix + LOOKUP_SEP + self.prefetch_through
2033 self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to
2035 def get_current_prefetch_to(self, level):
2036 return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1])
2038 def get_current_to_attr(self, level):
2039 parts = self.prefetch_to.split(LOOKUP_SEP)
2040 to_attr = parts[level]
2041 as_attr = self.to_attr and level == len(parts) - 1
2042 return to_attr, as_attr
2044 def get_current_queryset(self, level):
2045 if self.get_current_prefetch_to(level) == self.prefetch_to:
2046 return self.queryset
2047 return None
2049 def __eq__(self, other):
2050 if not isinstance(other, Prefetch):
2051 return NotImplemented
2052 return self.prefetch_to == other.prefetch_to
2054 def __hash__(self):
2055 return hash((self.__class__, self.prefetch_to))
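# Illustrative usage (added for clarity), assuming hypothetical ``Author`` and
# ``Book`` models where authors expose a ``books`` related manager:
#
#     Author.objects.prefetch_related(
#         Prefetch(
#             "books",
#             queryset=Book.objects.filter(published=True),
#             to_attr="published_books",
#         )
#     )
#     # each author then carries ``author.published_books`` as a plain list,
#     # while ``author.books.all()`` would still hit the database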
2058def normalize_prefetch_lookups(lookups, prefix=None):
2059 """Normalize lookups into Prefetch objects."""
2060 ret = []
2061 for lookup in lookups:
2062 if not isinstance(lookup, Prefetch):
2063 lookup = Prefetch(lookup)
2064 if prefix:
2065 lookup.add_prefix(prefix)
2066 ret.append(lookup)
2067 return ret
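# Illustrative example (added, not original): string lookups become Prefetch
# objects, and a prefix re-roots them under an outer lookup, e.g. with
# hypothetical ``books``/``awards`` relations:
#
#     normalize_prefetch_lookups(["books", Prefetch("awards")], prefix="author")
#     # -> [Prefetch("author__books"), Prefetch("author__awards")]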
2070def prefetch_related_objects(model_instances, *related_lookups):
2071 """
2072 Populate prefetched object caches for a list of model instances based on
2073 the lookups/Prefetch instances given.
2074 """
2075 if not model_instances:
2076 return # nothing to do
2078 # We need to be able to dynamically add to the list of prefetch_related
2079 # lookups that we look up (see below), so we need some bookkeeping to
2080 # ensure we don't do duplicate work.
2081 done_queries = {} # dictionary of things like 'foo__bar': [results]
2083 auto_lookups = set() # we add to this as we go through.
2084 followed_descriptors = set() # recursion protection
2086 all_lookups = normalize_prefetch_lookups(reversed(related_lookups))
2087 while all_lookups:
2088 lookup = all_lookups.pop()
2089 if lookup.prefetch_to in done_queries:
2090 if lookup.queryset is not None:
2091 raise ValueError(
2092 f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
2093 "You may need to adjust the ordering of your lookups."
2094 )
2096 continue
2098 # At the top level, the list of objects to decorate is the result cache
2099 # from the primary QuerySet; at deeper levels it is replaced as we descend.
2100 obj_list = model_instances
2102 through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
2103 for level, through_attr in enumerate(through_attrs):
2104 # Prepare main instances
2105 if not obj_list:
2106 break
2108 prefetch_to = lookup.get_current_prefetch_to(level)
2109 if prefetch_to in done_queries:
2110 # Skip any prefetching, and any object preparation
2111 obj_list = done_queries[prefetch_to]
2112 continue
2114 # Prepare objects:
2115 good_objects = True
2116 for obj in obj_list:
2117 # Since prefetching can re-use instances, it is possible to have
2118 # the same instance multiple times in obj_list, so obj might
2119 # already be prepared.
2120 if not hasattr(obj, "_prefetched_objects_cache"):
2121 try:
2122 obj._prefetched_objects_cache = {}
2123 except (AttributeError, TypeError):
2124 # Must be an immutable object from
2125 # values_list(flat=True), for example (TypeError) or
2126 # a QuerySet subclass that isn't returning Model
2127 # instances (AttributeError), either in Plain or a 3rd
2128 # party. prefetch_related() doesn't make sense, so quit.
2129 good_objects = False
2130 break
2131 if not good_objects:
2132 break
2134 # Descend down tree
2136 # We assume that objects retrieved are homogeneous (which is the premise
2137 # of prefetch_related), so what applies to the first object applies to all.
2138 first_obj = obj_list[0]
2139 to_attr = lookup.get_current_to_attr(level)[0]
2140 prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
2141 first_obj, through_attr, to_attr
2142 )
2144 if not attr_found:
2145 raise AttributeError(
2146 f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} object, '{lookup.prefetch_through}' is an invalid "
2147 "parameter to prefetch_related()"
2148 )
2150 if level == len(through_attrs) - 1 and prefetcher is None:
2151 # Last one, this *must* resolve to something that supports
2152 # prefetching, otherwise there is no point adding it and the
2153 # developer asking for it has made a mistake.
2154 raise ValueError(
2155 f"'{lookup.prefetch_through}' does not resolve to an item that supports "
2156 "prefetching - this is an invalid parameter to "
2157 "prefetch_related()."
2158 )
2160 obj_to_fetch = None
2161 if prefetcher is not None:
2162 obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]
2164 if obj_to_fetch:
2165 obj_list, additional_lookups = prefetch_one_level(
2166 obj_to_fetch,
2167 prefetcher,
2168 lookup,
2169 level,
2170 )
2171 # We need to ensure we don't keep adding lookups from the
2172 # same relationships to stop infinite recursion. So, if we
2173 # are already on an automatically added lookup, don't add
2174 # the new lookups from relationships we've seen already.
2175 if not (
2176 prefetch_to in done_queries
2177 and lookup in auto_lookups
2178 and descriptor in followed_descriptors
2179 ):
2180 done_queries[prefetch_to] = obj_list
2181 new_lookups = normalize_prefetch_lookups(
2182 reversed(additional_lookups), prefetch_to
2183 )
2184 auto_lookups.update(new_lookups)
2185 all_lookups.extend(new_lookups)
2186 followed_descriptors.add(descriptor)
2187 else:
2188 # Either a singly related object that has already been fetched
2189 # (e.g. via select_related), or hopefully some other property
2190 # that doesn't support prefetching but needs to be traversed.
2192 # We replace the current list of parent objects with the list
2193 # of related objects, filtering out empty or missing values so
2194 # that we can continue with nullable or reverse relations.
2195 new_obj_list = []
2196 for obj in obj_list:
2197 if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
2198 # If related objects have been prefetched, use the
2199 # cache rather than the object's through_attr.
2200 new_obj = list(obj._prefetched_objects_cache.get(through_attr))
2201 else:
2202 try:
2203 new_obj = getattr(obj, through_attr)
2204 except exceptions.ObjectDoesNotExist:
2205 continue
2206 if new_obj is None:
2207 continue
2208 # We special-case `list` rather than something more generic
2209 # like `Iterable` because we don't want to accidentally match
2210 # user models that define __iter__.
2211 if isinstance(new_obj, list):
2212 new_obj_list.extend(new_obj)
2213 else:
2214 new_obj_list.append(new_obj)
2215 obj_list = new_obj_list
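# Illustrative usage (added): prefetch_related_objects() can also be called
# directly on instances that have already been fetched, e.g. to prefetch after
# the fact. With hypothetical ``Author``/``Book`` models:
#
#     authors = list(Author.objects.all())
#     prefetch_related_objects(authors, "books", Prefetch("awards"))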
2218def get_prefetcher(instance, through_attr, to_attr):
2219 """
2220 For the attribute 'through_attr' on the given instance, find
2221 an object that has a get_prefetch_queryset().
2222 Return a 4-tuple containing:
2223 (the object with get_prefetch_queryset (or None),
2224 the descriptor object representing this relationship (or None),
2225 a boolean that is False if the attribute was not found at all,
2226 a function that takes an instance and returns a boolean that is True if
2227 the attribute has already been fetched for that instance)
2228 """
2230 def has_to_attr_attribute(instance):
2231 return hasattr(instance, to_attr)
2233 prefetcher = None
2234 is_fetched = has_to_attr_attribute
2236 # For singly related objects, we have to avoid getting the attribute
2237 # from the object, as this will trigger the query. So we first try
2238 # on the class, in order to get the descriptor object.
2239 rel_obj_descriptor = getattr(instance.__class__, through_attr, None)
2240 if rel_obj_descriptor is None:
2241 attr_found = hasattr(instance, through_attr)
2242 else:
2243 attr_found = True
2244 if rel_obj_descriptor:
2245 # singly related object, descriptor object has the
2246 # get_prefetch_queryset() method.
2247 if hasattr(rel_obj_descriptor, "get_prefetch_queryset"):
2248 prefetcher = rel_obj_descriptor
2249 is_fetched = rel_obj_descriptor.is_cached
2250 else:
2251 # descriptor doesn't support prefetching, so we go ahead and get
2252 # the attribute on the instance rather than the class to
2253 # support many related managers
2254 rel_obj = getattr(instance, through_attr)
2255 if hasattr(rel_obj, "get_prefetch_queryset"):
2256 prefetcher = rel_obj
2257 if through_attr != to_attr:
2258 # Special case cached_property instances because hasattr
2259 # triggers attribute computation and assignment.
2260 if isinstance(
2261 getattr(instance.__class__, to_attr, None), cached_property
2262 ):
2264 def has_cached_property(instance):
2265 return to_attr in instance.__dict__
2267 is_fetched = has_cached_property
2268 else:
2270 def in_prefetched_cache(instance):
2271 return through_attr in instance._prefetched_objects_cache
2273 is_fetched = in_prefetched_cache
2274 return prefetcher, rel_obj_descriptor, attr_found, is_fetched
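# Illustrative note (added, not original): for a forward relation such as a
# hypothetical ``book.author``, the descriptor found on the class provides
# get_prefetch_queryset(), so a call might resolve along these lines:
#
#     get_prefetcher(book, "author", "author")
#     # -> (descriptor, descriptor, True, descriptor.is_cached)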
2277def prefetch_one_level(instances, prefetcher, lookup, level):
2278 """
2279 Helper function for prefetch_related_objects().
2281 Run prefetches on all instances using the prefetcher object,
2282 assigning results to relevant caches in instance.
2284 Return the prefetched objects along with any additional prefetches that
2285 must be done due to prefetch_related lookups found from default managers.
2286 """
2287 # prefetcher must have a method get_prefetch_queryset() which takes a list
2288 # of instances, and returns a tuple:
2290 # (queryset of instances of self.model that are related to passed in instances,
2291 # callable that gets value to be matched for returned instances,
2292 # callable that gets value to be matched for passed in instances,
2293 # boolean that is True for singly related objects,
2294 # cache or field name to assign to,
2295 # boolean that is True when the previous argument is a cache name vs a field name).
2297 # The 'values to be matched' must be hashable as they will be used
2298 # in a dictionary.
2300 (
2301 rel_qs,
2302 rel_obj_attr,
2303 instance_attr,
2304 single,
2305 cache_name,
2306 is_descriptor,
2307 ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
2308 # We have to handle the possibility that the QuerySet we just got back
2309 # contains some prefetch_related lookups. We don't want to trigger the
2310 # prefetch_related functionality by evaluating the query. Rather, we need
2311 # to merge in the prefetch_related lookups.
2312 # Copy the lookups in case it is a Prefetch object which could be reused
2313 # later (happens in nested prefetch_related).
2314 additional_lookups = [
2315 copy.copy(additional_lookup)
2316 for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
2317 ]
2318 if additional_lookups:
2319 # Don't need to clone because the manager should have given us a fresh
2320 # instance, so we access an internal instead of using public interface
2321 # for performance reasons.
2322 rel_qs._prefetch_related_lookups = ()
2324 all_related_objects = list(rel_qs)
2326 rel_obj_cache = {}
2327 for rel_obj in all_related_objects:
2328 rel_attr_val = rel_obj_attr(rel_obj)
2329 rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
2331 to_attr, as_attr = lookup.get_current_to_attr(level)
2332 # Make sure `to_attr` does not conflict with a field.
2333 if as_attr and instances:
2334 # We assume that objects retrieved are homogeneous (which is the premise
2335 # of prefetch_related), so what applies to the first object applies to all.
2336 model = instances[0].__class__
2337 try:
2338 model._meta.get_field(to_attr)
2339 except exceptions.FieldDoesNotExist:
2340 pass
2341 else:
2342 msg = "to_attr={} conflicts with a field on the {} model."
2343 raise ValueError(msg.format(to_attr, model.__name__))
2345 # Whether or not we're prefetching the last part of the lookup.
2346 leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level
2348 for obj in instances:
2349 instance_attr_val = instance_attr(obj)
2350 vals = rel_obj_cache.get(instance_attr_val, [])
2352 if single:
2353 val = vals[0] if vals else None
2354 if as_attr:
2355 # A to_attr has been given for the prefetch.
2356 setattr(obj, to_attr, val)
2357 elif is_descriptor:
2358 # cache_name points to a field name in obj.
2359 # This field is a descriptor for a related object.
2360 setattr(obj, cache_name, val)
2361 else:
2362 # No to_attr has been given for this prefetch operation and the
2363 # cache_name does not point to a descriptor. Store the value of
2364 # the field in the object's field cache.
2365 obj._state.fields_cache[cache_name] = val
2366 else:
2367 if as_attr:
2368 setattr(obj, to_attr, vals)
2369 else:
2370 manager = getattr(obj, to_attr)
2371 if leaf and lookup.queryset is not None:
2372 qs = manager._apply_rel_filters(lookup.queryset)
2373 else:
2374 qs = manager.get_queryset()
2375 qs._result_cache = vals
2376 # We don't want the individual qs doing prefetch_related now,
2377 # since we have merged this into the current work.
2378 qs._prefetch_done = True
2379 obj._prefetched_objects_cache[cache_name] = qs
2380 return all_related_objects, additional_lookups
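# Illustrative sketch (added for clarity) of the grouping idiom used above:
# related objects are bucketed by the join value, then matched back to each
# instance by that same value:
#
#     rel_obj_cache = {}
#     for rel_obj in all_related_objects:
#         rel_obj_cache.setdefault(rel_obj_attr(rel_obj), []).append(rel_obj)
#     vals = rel_obj_cache.get(instance_attr(obj), [])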
2383class RelatedPopulator:
2384 """
2385 RelatedPopulator is used for select_related() object instantiation.
2387 The idea is that each select_related() model will be populated by a
2388 different RelatedPopulator instance. The RelatedPopulator instances get
2389 klass_info and select (computed in SQLCompiler) plus the db to use as
2390 input for initialization. That data is used to compute which columns
2391 to use, how to instantiate the model, and how to populate the links
2392 between the objects.
2394 The actual creation of the objects is done in the populate() method. This
2395 method gets row and from_obj as input and populates the select_related()
2396 model instance.
2397 """
2399 def __init__(self, klass_info, select, db):
2400 self.db = db
2401 # Pre-compute needed attributes. The attributes are:
2402 # - model_cls: the possibly deferred model class to instantiate
2403 # - either:
2404 # - cols_start, cols_end: usually the columns in the row are
2405 # in the same order model_cls.__init__ expects them, so we
2406 # can instantiate by model_cls(*row[cols_start:cols_end])
2407 # - reorder_for_init: When select_related descends to a child
2408 # class, then we want to reuse the already selected parent
2409 # data. However, in this case the parent data isn't necessarily
2410 # in the same order that Model.__init__ expects it to be, so
2411 # we have to reorder the parent data. The reorder_for_init
2412 # attribute contains a function used to reorder the field data
2413 # in the order __init__ expects it.
2414 # - pk_idx: the index of the primary key field in the reordered
2415 # model data. Used to check if a related object exists at all.
2416 # - init_list: the field attnames fetched from the database. For
2417 # deferred models this isn't the same as all attnames of the
2418 # model's fields.
2419 # - related_populators: a list of RelatedPopulator instances if
2420 # select_related() descends to related models from this model.
2421 # - local_setter, remote_setter: Methods to set cached values on
2422 # the object being populated and on the remote object. Usually
2423 # these are Field.set_cached_value() methods.
2424 select_fields = klass_info["select_fields"]
2425 from_parent = klass_info["from_parent"]
2426 if not from_parent:
2427 self.cols_start = select_fields[0]
2428 self.cols_end = select_fields[-1] + 1
2429 self.init_list = [
2430 f[0].target.attname for f in select[self.cols_start : self.cols_end]
2431 ]
2432 self.reorder_for_init = None
2433 else:
2434 attname_indexes = {
2435 select[idx][0].target.attname: idx for idx in select_fields
2436 }
2437 model_init_attnames = (
2438 f.attname for f in klass_info["model"]._meta.concrete_fields
2439 )
2440 self.init_list = [
2441 attname for attname in model_init_attnames if attname in attname_indexes
2442 ]
2443 self.reorder_for_init = operator.itemgetter(
2444 *[attname_indexes[attname] for attname in self.init_list]
2445 )
2447 self.model_cls = klass_info["model"]
2448 self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
2449 self.related_populators = get_related_populators(klass_info, select, self.db)
2450 self.local_setter = klass_info["local_setter"]
2451 self.remote_setter = klass_info["remote_setter"]
2453 def populate(self, row, from_obj):
2454 if self.reorder_for_init:
2455 obj_data = self.reorder_for_init(row)
2456 else:
2457 obj_data = row[self.cols_start : self.cols_end]
2458 if obj_data[self.pk_idx] is None:
2459 obj = None
2460 else:
2461 obj = self.model_cls.from_db(self.db, self.init_list, obj_data)
2462 for rel_iter in self.related_populators:
2463 rel_iter.populate(row, obj)
2464 self.local_setter(from_obj, obj)
2465 if obj is not None:
2466 self.remote_setter(obj, from_obj)
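# Illustrative example (added, not original) of the reorder_for_init mechanism:
# operator.itemgetter picks the already-selected parent columns out of the row
# in the order the model's __init__ expects them:
#
#     reorder = operator.itemgetter(2, 0, 1)
#     reorder(("b", "c", "a"))   # -> ("a", "b", "c")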
2469def get_related_populators(klass_info, select, db):
2470 iterators = []
2471 related_klass_infos = klass_info.get("related_klass_infos", [])
2472 for rel_klass_info in related_klass_infos:
2473 rel_cls = RelatedPopulator(rel_klass_info, select, db)
2474 iterators.append(rel_cls)
2475 return iterators