Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/expressions.py: 40%
983 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1import copy
2import datetime
3import functools
4import inspect
5from collections import defaultdict
6from decimal import Decimal
7from types import NoneType
8from uuid import UUID
10from plain.exceptions import EmptyResultSet, FieldError, FullResultSet
11from plain.models import fields
12from plain.models.constants import LOOKUP_SEP
13from plain.models.db import DatabaseError, NotSupportedError, connection
14from plain.models.query_utils import Q
15from plain.utils.deconstruct import deconstructible
16from plain.utils.functional import cached_property
17from plain.utils.hashable import make_hashable
class SQLiteNumericMixin:
    """
    Mixin that casts DecimalField-typed expressions to NUMERIC on SQLite.

    Some expressions with output_field=DecimalField() must be cast to
    numeric to be properly filtered.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            internal_type = self.output_field.get_internal_type()
        except FieldError:
            # The output type can't be resolved; emit the SQL unchanged.
            return sql, params
        if internal_type == "DecimalField":
            sql = f"CAST({sql} AS NUMERIC)"
        return sql, params
class Combinable:
    """
    Provide the ability to combine one or two objects with
    some connector. For example F('foo') + F('bar').

    Arithmetic operators build a CombinedExpression. The Python ``&``, ``|``
    and ``^`` operators only combine two operands that both report
    ``conditional`` (producing a Q object); bitwise SQL operations are
    exposed as .bitand(), .bitor(), .bitxor(), .bitleftshift() and
    .bitrightshift().
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The following is a quoted % operator - it is quoted because it can be
    # used in strings that also have parameter substitution.
    MOD = "%%"

    # Bitwise operators - note that these are generated by .bitand()
    # and .bitor(), the '&' and '|' are reserved for boolean operator
    # usage.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        # Every operand must be resolvable, so plain Python values get
        # wrapped in Value().
        operand = other if hasattr(other, "resolve_expression") else Value(other)
        lhs, rhs = (operand, self) if reversed else (self, operand)
        return CombinedExpression(lhs, connector, rhs)

    def _both_conditional(self, other):
        # Boolean combination is only allowed when both sides report a
        # truthy ``conditional`` attribute.
        return getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )

    @staticmethod
    def _reject_bitwise_operator():
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    #############
    # OPERATORS #
    #############

    def __neg__(self):
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # Boolean operators: only meaningful between two conditional operands.

    def __and__(self, other):
        if self._both_conditional(other):
            return Q(self) & Q(other)
        self._reject_bitwise_operator()

    def __or__(self, other):
        if self._both_conditional(other):
            return Q(self) | Q(other)
        self._reject_bitwise_operator()

    def __xor__(self, other):
        if self._both_conditional(other):
            return Q(self) ^ Q(other)
        self._reject_bitwise_operator()

    def __rand__(self, other):
        self._reject_bitwise_operator()

    def __ror__(self, other):
        self._reject_bitwise_operator()

    def __rxor__(self, other):
        self._reject_bitwise_operator()

    # Explicit bitwise SQL operations.

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    def __invert__(self):
        return NegatedExpression(self)
class BaseExpression:
    """Base class for all query expressions."""

    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # An explicit output_field is stored on the instance and takes
        # precedence over the inferred ``output_field`` cached property.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # Drop the cached ``convert_value``: it may be a lambda (see
        # convert_value()), which pickle can't serialize. It is recomputed on
        # first access after unpickling.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters for values the database returns for this
        expression: this expression's own converter (omitted when it is the
        no-op) followed by the output field's converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions; the base implementation has none."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; the base implementation accepts none."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Coerce each argument to an expression: objects that have
        resolve_expression() are kept as-is, strings become F() references,
        and anything else is wrapped in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        # Note: for_save is not forwarded to the source expressions.
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        # True when the expression produces a boolean; may raise FieldError if
        # the output field can't be resolved.
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        return self.output_field

    @cached_property
    def output_field(self):
        """Return the output type of this expression."""
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            # Only the "no output type" case maps to None; other FieldErrors
            # (e.g. mixed source types) propagate.
            if not self._output_field_resolved_to_none:
                raise
        return None

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources = [
            source for source in self.get_source_fields() if source is not None
        ]
        if not sources:
            return None
        output_field, *others = sources
        for source in others:
            if not isinstance(output_field, source.__class__):
                raise FieldError(
                    f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
                    "set output_field."
                )
        return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """Return the union of the refs of all (non-None) source expressions."""
        refs = set()
        for expr in self.get_source_expressions():
            # Source expressions may be None (see resolve_expression() and
            # relabeled_clone()); they contribute no refs.
            if expr is not None:
                refs |= expr.get_refs()
        return refs

    def copy(self):
        return copy.copy(self)

    def prefix_references(self, prefix):
        """
        Return a clone in which every F() source is renamed to
        ``prefix + name``; other sources are prefixed recursively and None
        sources are kept as None.
        """
        clone = self.copy()
        clone.set_source_expressions(
            [
                None
                if expr is None
                else F(f"{prefix}{expr.name}")
                if isinstance(expr, F)
                else expr.prefix_references(prefix)
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return the expressions to GROUP BY: the expression itself when it
        contains no aggregate, otherwise the group-by columns of its sources.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            # None sources contribute no columns.
            if source is not None:
                cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the underlying field types used by this aggregate."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self):
        # Equality/hash key: the class followed by (name, value) pairs for
        # every constructor argument, with defaults filled in and values
        # normalized to something hashable.
        init_signature = inspect.signature(self.__init__)
        args, kwargs = self._constructor_args
        bound = init_signature.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        parts = [self.__class__]
        for arg_name, arg_value in bound.arguments.items():
            if not isinstance(arg_value, fields.Field):
                arg_value = make_hashable(arg_value)
            elif arg_value.name and arg_value.model:
                # A field bound to a model is identified by model label + name.
                arg_value = (arg_value.model._meta.label, arg_value.name)
            else:
                arg_value = type(arg_value)
            parts.append((arg_name, arg_value))
        return tuple(parts)

    def __eq__(self, other):
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# return NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each dict below maps a connector to a list of
# (lhs_field_class, rhs_field_class, result_field_class) triples. The triples
# are registered in this order, and _resolve_combined_type() returns the first
# one whose operand classes match via issubclass(), so order matters.
_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    # POW is not listed here, so mixed-type POW needs an explicit output_field.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # NoneType matches an operand whose output field resolved to None (for
    # example Value(None)); the result takes the other operand's type.
    {
        connector: [
            (field_type, NoneType, field_type),
            (NoneType, field_type, field_type),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]

# connector -> list of (lhs, rhs, result) field-class triples, appended to by
# register_combinable_fields() and searched in order by _resolve_combined_type().
_connector_combinators = defaultdict(list)
def register_combinable_fields(lhs, connector, rhs, result):
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    Registrations are appended, and _resolve_combined_type() returns the
    first matching registration (using issubclass()). A new entry therefore
    cannot override an earlier registration that already matches the same
    operand classes.
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type() memoizes its lookups, including misses (None).
    # Drop that cache so a combination registered after a lookup takes effect
    # instead of the stale cached result.
    try:
        _resolve_combined_type.cache_clear()
    except NameError:
        # The module-level registration loop runs before
        # _resolve_combined_type is defined, so there is no cache to clear yet.
        pass
# Register every built-in combination from _connector_combinations, preserving
# table order. Note that the loop variables (d, connector, field_types, lhs,
# rhs, result) remain bound as module-level names afterwards.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the field class produced by ``lhs_type <connector> rhs_type``.

    Registered combinations for ``connector`` are checked in registration
    order. The first one whose operand classes are superclasses of (or equal
    to) ``lhs_type`` and ``rhs_type`` wins. Return None when nothing matches.
    """
    candidates = _connector_combinators.get(connector, ())
    return next(
        (
            result_type
            for expected_lhs, expected_rhs, result_type in candidates
            if issubclass(lhs_type, expected_lhs)
            and issubclass(rhs_type, expected_rhs)
        ),
        None,
    )
class CombinedExpression(SQLiteNumericMixin, Expression):
    """A binary ``lhs <connector> rhs`` expression, e.g. F("a") + F("b")."""

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        # Deliberately not delegating to super() (see the comment in
        # BaseExpression._resolve_output_field()); the result type comes from
        # the registered connector combinations instead.
        lhs_field_type = type(self.lhs._output_field_or_none)
        rhs_field_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(
            self.connector, lhs_field_type, rhs_field_type
        )
        if combined_type is not None:
            return combined_type()
        raise FieldError(
            f"Cannot infer type of {self.connector!r} expression involving these "
            f"types: {self.lhs.output_field.__class__.__name__}, "
            f"{self.rhs.output_field.__class__.__name__}. You must set "
            f"output_field."
        )

    def as_sql(self, compiler, connection):
        operand_sqls = []
        operand_params = []
        for operand in (self.lhs, self.rhs):
            sql, params = compiler.compile(operand)
            operand_sqls.append(sql)
            operand_params.extend(params)
        combined = connection.ops.combine_expression(self.connector, operand_sqls)
        # Parenthesize so the combination binds as a unit inside larger
        # expressions.
        return f"({combined})", operand_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved_lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved_rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

        def internal_type_or_none(expr):
            try:
                return expr.output_field.get_internal_type()
            except (AttributeError, FieldError):
                return None

        if not isinstance(self, DurationExpression | TemporalSubtraction):
            lhs_type = internal_type_or_none(resolved_lhs)
            rhs_type = internal_type_or_none(resolved_rhs)
            resolve_args = (query, allow_joins, reuse, summarize, for_save)
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                # One side is a duration and the other isn't: hand off to
                # DurationExpression, which formats duration operands on
                # backends without a native duration type.
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(*resolve_args)
            if (
                self.connector == self.SUB
                and lhs_type in {"DateField", "DateTimeField", "TimeField"}
                and lhs_type == rhs_type
            ):
                # Subtracting two same-typed temporal values yields a duration.
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    *resolve_args
                )
        clone = self.copy()
        clone.is_summary = summarize
        clone.lhs = resolved_lhs
        clone.rhs = resolved_rhs
        return clone
class DurationExpression(CombinedExpression):
    """
    Arithmetic where one operand is a DurationField and the other is not.

    On backends without a native duration field, duration operands are
    formatted via connection.ops.format_for_duration_arithmetic() and the
    operands are combined with connection.ops.combine_duration_expression().
    """

    def compile(self, side, compiler, connection):
        try:
            output = side.output_field
        except FieldError:
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        compiled = [
            self.compile(side, compiler, connection) for side in (self.lhs, self.rhs)
        ]
        operand_sqls = [sql for sql, _ in compiled]
        operand_params = [param for _, params in compiled for param in params]
        combined = connection.ops.combine_duration_expression(
            self.connector, operand_sqls
        )
        # Parenthesize so the combination binds as a unit.
        return f"({combined})", operand_params

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            return sql, params
        # Multiplying/dividing a duration is only accepted with numeric or
        # duration operands on SQLite.
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if not {lhs_type, rhs_type} <= allowed_fields:
            raise DatabaseError(f"Invalid arguments for operator {self.connector}.")
        return sql, params
class TemporalSubtraction(CombinedExpression):
    """Subtraction of two date/time values; the result is always a duration."""

    # Class-level output field; no inference is needed.
    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        compiled_lhs, compiled_rhs = (
            compiler.compile(self.lhs),
            compiler.compile(self.rhs),
        )
        internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            internal_type, compiled_lhs, compiled_rhs
        )
@deconstructible(path="plain.models.F")
class F(Combinable):
    """A by-name reference to a field, resolved against the query it's used in."""

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name})"

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # for_save is accepted for interface compatibility but not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Equal only to an instance of exactly the same class with the same name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        return copy.copy(self)
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        # Compiling outside a subquery is an error.
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # Relabeling is a no-op: the same instance is returned.
        return self

    def get_group_by_cols(self):
        return []
class OuterRef(F):
    """A field reference that resolves to a ResolvedOuterRef (an outer-query reference)."""

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        target = self.name
        # A nested OuterRef(OuterRef(...)) resolves to the inner OuterRef,
        # presumably so it is resolved again against a further-out query.
        if isinstance(target, self.__class__):
            return target
        return ResolvedOuterRef(target)

    def relabeled_clone(self, relabels):
        return self
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} {noun} "
                f"({len(expressions)} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{self.__class__.__name__}({args})"
        rendered = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{self.__class__.__name__}({args}, {rendered})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        # copy() gave the clone its own list, so it is safe to update in place.
        for index, source in enumerate(clone.source_expressions):
            clone.source_expressions[index] = source.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return clone

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Substitute the argument's "empty" value when it defines one.
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(fallback))
            except FullResultSet:
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        joined_args = arg_joiner.join(sql_parts)
        data["expressions"] = joined_args
        data["field"] = joined_args
        return template % data, params

    def copy(self):
        # Named ``clone`` rather than ``copy`` to avoid shadowing the module.
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
986@deconstructible(path="plain.models.Value")
987class Value(SQLiteNumericMixin, Expression):
988 """Represent a wrapped value as a node within an expression."""
990 # Provide a default value for `for_save` in order to allow unresolved
991 # instances to be compiled until a decision is taken in #25425.
992 for_save = False
994 def __init__(self, value, output_field=None):
995 """
996 Arguments:
997 * value: the value this expression represents. The value will be
998 added into the sql parameter list and properly quoted.
1000 * output_field: an instance of the model field type that this
1001 expression will return, such as IntegerField() or CharField().
1002 """
1003 super().__init__(output_field=output_field)
1004 self.value = value
1006 def __repr__(self):
1007 return f"{self.__class__.__name__}({self.value!r})"
1009 def as_sql(self, compiler, connection):
1010 connection.ops.check_expression_support(self)
1011 val = self.value
1012 output_field = self._output_field_or_none
1013 if output_field is not None:
1014 if self.for_save:
1015 val = output_field.get_db_prep_save(val, connection=connection)
1016 else:
1017 val = output_field.get_db_prep_value(val, connection=connection)
1018 if hasattr(output_field, "get_placeholder"):
1019 return output_field.get_placeholder(val, compiler, connection), [val]
1020 if val is None:
1021 # cx_Oracle does not always convert None to the appropriate
1022 # NULL type (like in case expressions using numbers), so we
1023 # use a literal SQL NULL
1024 return "NULL", []
1025 return "%s", [val]
1027 def resolve_expression(
1028 self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
1029 ):
1030 c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
1031 c.for_save = for_save
1032 return c
1034 def get_group_by_cols(self):
1035 return []
def _resolve_output_field(self):
    """
    Infer a model field instance from the Python type of the wrapped value.

    Checks run in order, and the order matters: ``bool`` is tested before
    ``int`` because bool subclasses int, and ``datetime.datetime`` before
    ``datetime.date`` because datetime subclasses date. Returns None when no
    type matches.
    """
    type_to_field = (
        (str, fields.CharField),
        (bool, fields.BooleanField),
        (int, fields.IntegerField),
        (float, fields.FloatField),
        (datetime.datetime, fields.DateTimeField),
        (datetime.date, fields.DateField),
        (datetime.time, fields.TimeField),
        (datetime.timedelta, fields.DurationField),
        (Decimal, fields.DecimalField),
        (bytes, fields.BinaryField),
        (UUID, fields.UUIDField),
    )
    for python_type, field_class in type_to_field:
        if isinstance(self.value, python_type):
            return field_class()
    return None
@property
def empty_result_set_value(self):
    """Value reported for this expression on an empty result set: the constant itself."""
    wrapped = self.value
    return wrapped
class RawSQL(Expression):
    """
    A fragment of raw SQL with its own parameter list.

    ``sql`` is emitted verbatim inside parentheses and ``params`` are passed
    through unchanged, so the caller is responsible for the SQL text being
    safe. ``output_field`` defaults to a generic ``fields.Field()``.
    """

    def __init__(self, sql, params, output_field=None):
        if output_field is None:
            output_field = fields.Field()
        self.sql, self.params = sql, params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(self, compiler, connection):
        return f"({self.sql})", self.params

    def get_group_by_cols(self):
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Resolve parent-model fields whose column name appears in the raw SQL
        # (a case-insensitive substring match). Presumably this makes the
        # query join the parent table the SQL refers to.
        # `query` defaults to None, so check for it before reading `query.model`
        # instead of failing with AttributeError.
        if query is not None and query.model:
            sql_lower = self.sql.lower()
            for parent in query.model._meta.get_parent_list():
                for parent_field in parent._meta.local_fields:
                    _, column_name = parent_field.get_attname_column()
                    if column_name.lower() in sql_lower:
                        query.resolve_ref(
                            parent_field.name, allow_joins, reuse, summarize
                        )
                        # One match per parent model is enough.
                        break
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
class Star(Expression):
    """The SQL ``*`` wildcard as an expression; it takes no parameters."""

    def as_sql(self, compiler, connection):
        no_params = []
        return "*", no_params

    def __repr__(self):
        return "'*'"
class Col(Expression):
    """
    A reference to the column of ``target``, optionally qualified by a table alias.

    ``output_field`` defaults to ``target`` itself.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        resolved_field = target if output_field is None else output_field
        super().__init__(output_field=resolved_field)
        self.alias = alias
        self.target = target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(self, compiler, connection):
        names = [self.target.column]
        if self.alias:
            names.insert(0, self.alias)
        quoted = (compiler.quote_name_unless_alias(name) for name in names)
        return ".".join(quoted), []

    def relabeled_clone(self, relabels):
        # An unqualified column has no alias to rename.
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        # The output field's converters always apply. The target field's
        # converters are appended only when it differs from the output field.
        converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return converters
        return converters + self.target.get_db_converters(connection)
class Ref(Expression):
    """
    Reference to a column alias of the query, e.g. ``Ref('sum_cost')`` for
    ``qs.annotate(sum_cost=Sum('cost'))``. It compiles to the quoted alias name.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return f"{type(self).__name__}({self.refs}, {self.source})"

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Nothing to resolve: `source` was already resolved, and this node only
        # names it.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        # Returned unchanged; relabeling does not affect this node.
        return self

    def as_sql(self, compiler, connection):
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self):
        return [self]
class ExpressionList(Func):
    """
    Several expressions rendered side by side, joined by ``arg_joiner``.

    Used to pass a list of expressions as a single argument to another
    expression, e.g. the PARTITION BY part of a window.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if not expressions:
            cls_name = self.__class__.__name__
            raise ValueError(f"{cls_name} requires at least one expression.")
        super().__init__(*expressions, **extra)

    def __str__(self):
        rendered = map(str, self.source_expressions)
        return self.arg_joiner.join(rendered)

    def as_sqlite(self, compiler, connection, **extra_context):
        # Compile plainly: casting a list of expressions to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from expressions and/or field-name strings.

    A string starting with ``-`` (e.g. ``"-created"``) becomes a descending
    ``OrderBy`` on the named field. Any other argument is passed through
    unchanged. With no expressions, the clause compiles to an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        expressions = (
            (
                OrderBy(F(expr[1:]), descending=True)
                # startswith() instead of expr[0]: an empty string must not
                # raise IndexError here.
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions, **extra)

    def as_sql(self, *args, **kwargs):
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)

    def get_group_by_cols(self):
        # Collect the GROUP BY columns contributed by every ordering term.
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression to give it extra context, such as an explicit
    ``output_field``. It compiles to exactly the inner expression's SQL.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression})"

    def get_source_expressions(self):
        return [self.expression]

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # For non-expressions, e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)
class NegatedExpression(ExpressionWrapper):
    """Logical NOT of a conditional expression; inverting it again returns a copy of the original."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Double negation: hand back a copy of the wrapped expression.
        return self.expression.copy()

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        is_conditional = getattr(resolved.expression, "conditional", False)
        if not is_conditional:
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The inner expression matches nothing, so its negation is always true.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        # Some backends (e.g. Oracle) don't allow EXISTS() and filters to be
        # compared to another expression unless they're wrapped in CASE WHEN.
        supported_in_where = (
            compiler.connection.ops.conditional_expression_supported_in_where_clause
        )
        if supported_in_where(self.expression):
            return f"NOT {sql}", params
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def select_format(self, compiler, sql, params):
        # Backends without boolean expressions in SELECT/GROUP BY (e.g. Oracle)
        # need a CASE WHEN wrapper. Skip the wrapper when as_sql() already
        # produced a CASE WHEN, to avoid double wrapping.
        conn = compiler.connection
        needs_wrapping = (
            not conn.features.supports_boolean_expr_in_select_clause
            and conn.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if needs_wrapping:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a ``Case`` expression.

    The condition may be a Q object, a conditional expression, or field lookups
    passed as keyword arguments. Lookups are wrapped in a Q, combined with
    ``condition`` when one is also given.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # Not a complete conditional expression; it must be used inside Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        if lookups and condition is None:
            condition, lookups = Q(**lookups), None
        elif lookups and getattr(condition, "conditional", False):
            condition, lookups = Q(condition, **lookups), None
        has_usable_condition = (
            condition is not None
            and getattr(condition, "conditional", False)
            and not lookups
        )
        if not has_usable_condition:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # Only the result's output field matters when inferring the Case type.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, "resolve_expression"):
            # The condition is always resolved with for_save=False.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        template_params = {
            **extra_context,
            "condition": condition_sql,
            "result": result_sql,
        }
        sql = (template or self.template) % template_params
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # This is not a complete expression and cannot be used in GROUP BY;
        # contribute the columns of its parts instead.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be When() objects and are emitted in order.
    ``default`` is parsed into an expression and supplies the ELSE branch.
    Extra keyword arguments are kept in ``self.extra`` and merged into the
    template parameters in as_sql().
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        # All When branches followed by the default, which is always last.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        # Inverse of get_source_expressions(): the last item is the default.
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self):
        c = super().copy()
        # Give the copy its own list so replacing entries (as resolve_expression
        # does) doesn't alter the original's cases.
        c.cases = c.cases[:]
        return c

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        connection.ops.check_expression_support(self)
        # With no branches the expression is just its default.
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch's condition can never match; drop it.
                continue
            except FullResultSet:
                # This branch's condition matches everything: its result
                # replaces the default and later branches are unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        # Every branch was dropped (or the first one always matched), so no
        # CASE is needed.
        if not case_parts:
            return default_sql, default_params
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # Parameter order follows the SQL: branch params first, then the ELSE.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            # The backend's unification_cast_sql template wraps the whole CASE;
            # presumably a cast to the output field's type (backend-defined).
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self):
        # Without branches this expression is just the default.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query, which are resolved when the subquery is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        # Accept either a QuerySet (use its .query) or a bare sql.Query.
        inner_query = getattr(queryset, "query", queryset)
        self.query = inner_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        duplicate = super().copy()
        # The copy gets its own clone of the query rather than sharing it.
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        template_params = {**self.extra, **extra_context}
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the compiled query (presumably
        # its enclosing parentheses); the template supplies its own.
        template_params["subquery"] = subquery_sql[1:-1]
        chosen_template = template or template_params.get("template", self.template)
        return chosen_template % template_params, sql_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
class Exists(Subquery):
    """An ``EXISTS(<subquery>)`` boolean expression; it is False on an empty result set."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        # Replace the inner query with the one returned by its exists() method.
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Backends that don't support boolean expressions in the SELECT or
        # GROUP BY list (e.g. Oracle) get a CASE WHEN wrapper.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    Order by ``expression``, ascending or descending, optionally forcing NULLs
    first or last.

    ``nulls_first`` and ``nulls_last`` accept only True or None, and at most
    one of them may be True. ``expression`` must be an expression object
    (anything with ``resolve_expression``).
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            # Native support: append NULLS LAST / NULLS FIRST.
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            # No native modifier: emulate it by first sorting on
            # "<expr> IS [NOT] NULL". The emulation is skipped when the
            # order_by_nulls_first feature and the sort direction already give
            # the requested NULL placement (presumably the backend's default
            # ordering handles it).
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression SQL may now appear more than once in the template, so
        # repeat its params once per occurrence.
        params *= template.count("%(expression)s")
        # rstrip() drops trailing whitespace from the rendered template.
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self):
        # Mutates self in place: flip the direction and swap the NULL
        # placement; returns self.
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self):
        # In-place; returns None.
        self.descending = False

    def desc(self):
        # In-place; returns None.
        self.descending = True
class Window(SQLiteNumericMixin, Expression):
    """
    A window expression: ``<expression> OVER (<partition> <ordering> <frame>)``.

    ``partition_by`` may be a single item or a tuple/list and is wrapped in an
    ExpressionList. ``order_by`` may be a string, an expression, or a
    tuple/list of them and is wrapped in an OrderByList. ``frame`` is compiled
    as-is.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if partition_by is not None:
            partition_items = (
                partition_by
                if isinstance(partition_by, tuple | list)
                else (partition_by,)
            )
            self.partition_by = ExpressionList(*partition_items)

        if order_by is not None:
            if isinstance(order_by, list | tuple):
                ordering_items = order_by
            elif isinstance(order_by, BaseExpression | str):
                ordering_items = (order_by,)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
            self.order_by = OrderByList(*ordering_items)

        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        clauses = []
        clause_params = []

        if self.partition_by is not None:
            partition_sql, partition_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            clauses.append(partition_sql)
            clause_params.extend(partition_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            clauses.append(order_sql)
            clause_params.extend(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            clauses.append(frame_sql)
            clause_params.extend(frame_params)

        window_sql = " ".join(clauses).strip()
        chosen_template = template or self.template
        sql = chosen_template % {"expression": expr_sql, "window": window_sql}
        return sql, (*params, *clause_params)

    def as_sqlite(self, compiler, connection):
        if not isinstance(self.output_field, fields.DecimalField):
            return self.as_sql(compiler, connection)
        # Casting to numeric must be outside of the window expression: compile
        # a copy whose inner expression is typed as a float, and let
        # SQLiteNumericMixin.as_sqlite apply the cast to the whole result.
        clone = self.copy()
        sources = clone.get_source_expressions()
        sources[0].output_field = fields.FloatField()
        clone.set_source_expressions(sources)
        return super(Window, clone).as_sqlite(compiler, connection)

    def __str__(self):
        source = str(self.source_expression)
        partition = "PARTITION BY " + str(self.partition_by) if self.partition_by else ""
        ordering = str(self.order_by or "")
        frame = str(self.frame or "")
        return f"{source} OVER ({partition}{ordering}{frame})"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        cols = []
        if self.partition_by:
            cols += self.partition_by.get_group_by_cols()
        if self.order_by is not None:
            cols += self.order_by.get_group_by_cols()
        return cols
class WindowFrame(Expression):
    """
    Base class for the frame clause of a window expression.

    Subclasses set ``frame_type`` and implement ``window_frame_start_end``,
    which turns the start/end values into SQL for a given connection. Both
    bounds are optional. In __str__, a missing start renders as UNBOUNDED
    PRECEDING and a missing end as UNBOUNDED FOLLOWING. Validation is not
    meant to be complete.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start_sql, end_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        rendered = self.template % {
            "frame_type": self.frame_type,
            "start": start_sql,
            "end": end_sql,
        }
        return rendered, []

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        return []

    def __str__(self):
        # No connection is passed in here, so this uses the module-level
        # `connection` for the bound keywords.
        ops = connection.ops
        lower = self.start.value
        upper = self.end.value

        if lower is not None and lower < 0:
            start = "%d %s" % (abs(lower), ops.PRECEDING)
        elif lower is not None and lower == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if upper is not None and upper > 0:
            end = "%d %s" % (upper, ops.FOLLOWING)
        elif upper is not None and upper == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
class RowRange(WindowFrame):
    """A ``ROWS`` window frame; the backend's rows helper renders the bounds."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
class ValueRange(WindowFrame):
    """A ``RANGE`` window frame; the backend's range helper renders the bounds."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)