Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/aggregates.py: 42%
127 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-16 22:04 -0500
1"""
2Classes to represent the definitions of aggregate functions.
3"""
4from plain.exceptions import FieldError, FullResultSet
5from plain.models.expressions import Case, Func, Star, Value, When
6from plain.models.fields import IntegerField
7from plain.models.functions.comparison import Coalesce
8from plain.models.functions.mixins import (
9 FixDurationInputMixin,
10 NumericOutputFieldMixin,
11)
# Public names exported by this module: the Aggregate base class followed by
# the concrete aggregate functions, in alphabetical order.
__all__ = [
    "Aggregate",
    "Avg", "Count", "Max", "Min", "StdDev", "Sum", "Variance",
]
class Aggregate(Func):
    """
    Base class for SQL aggregate functions.

    Concrete subclasses set ``function`` (the SQL function name) and ``name``
    (used for default aliases and error messages). On top of ``Func`` this
    class adds three keyword arguments:

    * ``distinct`` -- render ``DISTINCT`` inside the call. Only accepted when
      the subclass sets ``allow_distinct = True``.
    * ``filter`` -- a condition limiting which rows are aggregated. Compiled
      as ``FILTER (WHERE ...)`` when the backend reports
      ``supports_aggregate_filter_clause``; otherwise emulated by wrapping the
      first argument in ``CASE WHEN <filter> THEN <argument> END``.
    * ``default`` -- a fallback value; the resolved aggregate is wrapped in
      ``Coalesce(aggregate, default)``. Rejected when the subclass defines a
      non-None ``empty_result_set_value``.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    name = None
    # Wraps the normal template. The doubled ``%%`` survives the first
    # interpolation in as_sql() so that ``%(filter)s`` is filled afterwards
    # from the ``filter`` value handed to Func.as_sql().
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    # Subclasses that set a non-None value here do not accept ``default``.
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        """
        Store the aggregate options and forward the expressions to ``Func``.

        Raises:
            TypeError: if ``distinct`` is requested but the subclass does not
                set ``allow_distinct``, or if ``default`` is given for a
                subclass with a non-None ``empty_result_set_value``.
        """
        cls_name = self.__class__.__name__
        if distinct and not self.allow_distinct:
            raise TypeError(f"{cls_name} does not allow distinct.")
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{cls_name} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        """Return the output fields of the aggregated arguments only."""
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        """Return the ``Func`` source expressions, plus ``filter`` last if set."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs):
        """
        Inverse of get_source_expressions().

        When a filter is attached it is expected as the last item of
        ``exprs``; it is split off and stored on ``self.filter``.
        """
        if self.filter:
            # Copy before popping so the caller's list is not mutated.
            exprs = list(exprs)
            self.filter = exprs.pop()
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the arguments and the filter against ``query``.

        Returns the resolved expression, or, when ``default`` was given, a
        ``Coalesce`` wrapping it with the default as the fallback.

        Raises:
            FieldError: when ``summarize`` is false and one of the arguments
                itself contains an aggregate.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(
            query, allow_joins, reuse, summarize
        )
        if not summarize:
            # Use Func.get_source_expressions() (bypassing Aggregate's
            # override) so that c.filter is not included in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        """
        Alias ``<argument name>__<aggregate name, lowercased>``.

        Only available for a single argument that has a ``name``. An attached
        filter counts as an extra source expression, so filtered aggregates
        also raise TypeError here.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return f"{expressions[0].name}__{self.name.lower()}"
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        """Aggregates contribute no GROUP BY columns."""
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """
        Compile the aggregate, applying ``distinct`` and ``filter``.

        Returns ``(sql, params)``. With a filter, FILTER (WHERE ...) is used
        when the backend supports it, with the filter's params appended after
        the aggregate's. Otherwise the filter is emulated with CASE/WHEN.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # The filter compiles to nothing; fall through and emit
                    # the aggregate without a FILTER clause.
                    pass
                else:
                    template = self.filter_template % extra_context.get(
                        "template", self.template
                    )
                    # Merge into one mapping: passing template=/filter= next to
                    # **extra_context raises "got multiple values for keyword
                    # argument 'template'" whenever the caller supplied a
                    # template (already folded into ``template`` above).
                    context = {
                        **extra_context,
                        "template": template,
                        "filter": filter_sql,
                    }
                    sql, params = super().as_sql(compiler, connection, **context)
                    return sql, (*params, *filter_params)
            else:
                # No FILTER clause support: on a copy, replace the first
                # argument with CASE WHEN <filter> THEN <argument> END.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                # Go straight to Func.as_sql for the copy; "distinct" is
                # already present in extra_context.
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        """Add ``distinct`` and ``filter`` to repr() output when they are set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """SQL ``AVG(...)`` aggregate; accepts ``distinct=True``."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
class Count(Aggregate):
    """
    SQL ``COUNT(...)`` aggregate with an integer result.

    The string ``"*"`` is accepted as shorthand for ``Star()``. Because
    ``empty_result_set_value`` is non-None, Aggregate rejects ``default``.
    """

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    output_field = IntegerField()
    empty_result_set_value = 0

    def __init__(self, expression, filter=None, **extra):
        """
        Raises:
            ValueError: if ``*`` / ``Star()`` is combined with a ``filter``.
        """
        target = Star() if expression == "*" else expression
        if filter is not None and isinstance(target, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(target, filter=filter, **extra)
class Max(Aggregate):
    """SQL ``MAX(...)`` aggregate."""

    name = "Max"
    function = "MAX"
class Min(Aggregate):
    """SQL ``MIN(...)`` aggregate."""

    name = "Min"
    function = "MIN"
class StdDev(NumericOutputFieldMixin, Aggregate):
    """
    Standard deviation aggregate.

    Compiles to ``STDDEV_SAMP`` when ``sample`` is true, else ``STDDEV_POP``.
    """

    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        # The SQL function is chosen per instance rather than per class.
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Report ``sample`` in repr(), derived from the chosen SQL function.
        return dict(
            super()._get_repr_options(),
            sample=self.function == "STDDEV_SAMP",
        )
class Sum(FixDurationInputMixin, Aggregate):
    """SQL ``SUM(...)`` aggregate; accepts ``distinct=True``."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
class Variance(NumericOutputFieldMixin, Aggregate):
    """
    Variance aggregate.

    Compiles to ``VAR_SAMP`` when ``sample`` is true, else ``VAR_POP``.
    """

    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        # The SQL function is chosen per instance rather than per class.
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Report ``sample`` in repr(), derived from the chosen SQL function.
        return dict(
            super()._get_repr_options(),
            sample=self.function == "VAR_SAMP",
        )