Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/expressions.py: 33%

983 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:03 -0500

1import copy 

2import datetime 

3import functools 

4import inspect 

5from collections import defaultdict 

6from decimal import Decimal 

7from types import NoneType 

8from uuid import UUID 

9 

10from plain.exceptions import EmptyResultSet, FieldError, FullResultSet 

11from plain.models import fields 

12from plain.models.constants import LOOKUP_SEP 

13from plain.models.db import DatabaseError, NotSupportedError, connection 

14from plain.models.query_utils import Q 

15from plain.utils.deconstruct import deconstructible 

16from plain.utils.functional import cached_property 

17from plain.utils.hashable import make_hashable 

18 

19 

class SQLiteNumericMixin:
    """
    Mixin that casts DecimalField-typed expressions to NUMERIC on SQLite.

    When the expression's output field is a DecimalField, the SQL produced by
    as_sql() is wrapped in ``CAST(... AS NUMERIC)`` so the result can be
    filtered properly. If the output field cannot be resolved (FieldError),
    the SQL is returned unchanged.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            is_decimal = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            is_decimal = False
        if is_decimal:
            sql = "CAST(%s AS NUMERIC)" % sql
        return sql, params

34 

35 

class Combinable:
    """
    Operator support for expressions.

    Arithmetic operators and the bit*() methods build a CombinedExpression
    joining this object and the other operand with one of the connector
    strings below, e.g. F('foo') + F('bar'). Operands without a
    resolve_expression() method are wrapped in Value().

    The Python &, | and ^ operators are reserved for boolean logic: they
    combine two conditional operands through Q objects and raise
    NotImplementedError for anything else.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The % operator is written doubled ("%%") because connectors can end up
    # in strings that also go through %-style parameter substitution.
    MOD = "%%"

    # Bitwise connectors, produced by .bitand(), .bitor(), etc. The Python
    # & and | operators are reserved for boolean combination instead.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """Return ``self <connector> other``, or the mirror when ``reversed``."""
        # Every operand must be resolvable, so wrap plain Python values.
        operand = other if hasattr(other, "resolve_expression") else Value(other)
        if not reversed:
            return CombinedExpression(self, connector, operand)
        return CombinedExpression(operand, connector, self)

    @staticmethod
    def _bitwise_hint():
        # Error raised when &, | or ^ is applied to non-boolean operands.
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    # Arithmetic operators

    def __neg__(self):
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    # Reflected arithmetic operators (the expression is the right operand)

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # Boolean operators: only two conditional operands may be combined.

    def __and__(self, other):
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise self._bitwise_hint()
        return Q(self) & Q(other)

    def __or__(self, other):
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise self._bitwise_hint()
        return Q(self) | Q(other)

    def __xor__(self, other):
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise self._bitwise_hint()
        return Q(self) ^ Q(other)

    def __rand__(self, other):
        raise self._bitwise_hint()

    def __ror__(self, other):
        raise self._bitwise_hint()

    def __rxor__(self, other):
        raise self._bitwise_hint()

    def __invert__(self):
        return NegatedExpression(self)

    # Explicit bitwise operations

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

166 

167 

class BaseExpression:
    """Base class for all query expressions."""

    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # Only an explicit output_field is stored on the instance; otherwise
        # the output_field cached property infers it on first access.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # The cached convert_value may be a lambda (see convert_value), which
        # pickle cannot serialize; drop it so it is recomputed when needed.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters to apply to values read from the database:
        this expression's convert_value (omitted when it is the no-op),
        followed by the output field's own converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions (none for a leaf expression)."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; a leaf expression accepts none."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Coerce constructor arguments into expressions: objects that already
        have resolve_expression() are kept, strings become F() references and
        any other value is wrapped in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """Whether any (truthy) source expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        """Whether any (truthy) source expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        """Whether any (truthy) source expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression is about to be used in a save or
           update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        # Note: for_save is not forwarded to the source expressions.
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        """Whether the expression has a boolean output field."""
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        return self.output_field

    @cached_property
    def output_field(self):
        """Return the output type of this expression."""
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # Both loops share one generator: the outer loop takes the first
        # source, the inner loop checks every remaining source against it,
        # and the outer loop then returns after a single pass.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        "Expression contains mixed types: {}, {}. You must "
                        "set output_field.".format(
                            output_field.__class__.__name__,
                            source.__class__.__name__,
                        )
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """Return the union of get_refs() over all source expressions."""
        refs = set()
        for expr in self.get_source_expressions():
            # Source lists may hold None placeholders, which relabeled_clone()
            # and flatten() already skip; do the same here.
            if expr is not None:
                refs |= expr.get_refs()
        return refs

    def copy(self):
        return copy.copy(self)

    def prefix_references(self, prefix):
        """Return a clone whose F() references are prefixed with ``prefix``."""

        def prefixed(expr):
            if expr is None:
                # Keep None placeholders in place.
                return None
            if isinstance(expr, F):
                return F(f"{prefix}{expr.name}")
            return expr.prefix_references(prefix)

        clone = self.copy()
        clone.set_source_expressions(
            [prefixed(expr) for expr in self.get_source_expressions()]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return the expressions to GROUP BY: the expression itself when it has
        no aggregate, otherwise the group-by columns of its source expressions.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            if source is not None:
                cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """
        Return the output fields of the source expressions (None for a source
        whose type is unknown or for a None placeholder).
        """
        return [
            None if e is None else e._output_field_or_none
            for e in self.get_source_expressions()
        ]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params

469 

470 

@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self):
        """
        Return a hashable tuple identifying this expression.

        The tuple starts with the class and is followed by a ``(name, value)``
        pair for every constructor argument, with defaults applied. Model
        fields attached to a model are represented by ``(model label, field
        name)``, unattached fields by their type, and any other value through
        make_hashable().
        """
        signature = inspect.signature(self.__init__)
        args, kwargs = self._constructor_args
        bound = signature.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        return (self.__class__,) + tuple(
            (name, self._identity_value(value))
            for name, value in bound.arguments.items()
        )

    @staticmethod
    def _identity_value(value):
        # Normalize one constructor argument for use in identity.
        if not isinstance(value, fields.Field):
            return make_hashable(value)
        if value.name and value.model:
            return (value.model._meta.label, value.name)
        return type(value)

    def __eq__(self, other):
        # Non-expressions are left to the other operand's __eq__.
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

501 

502 

# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each dict maps a connector to a list of (lhs type, rhs type, result type)
# triples. _resolve_combined_type() returns the first registered triple that
# matches (via issubclass), so the order of entries is significant.

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    {
        # Build one list per connector holding the pairs for *every* numeric
        # type. Iterating over the field types in the dict comprehension
        # itself would rebind the same connector key on each pass, keeping
        # only the last type (FloatField) and leaving IntegerField/DecimalField
        # <op> NULL without an entry.
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]

606 

_connector_combinators = defaultdict(list)


def register_combinable_fields(lhs, connector, rhs, result):
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    Combinations are consulted in registration order, so when several match
    the earliest registration wins.
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type() memoizes its answers (misses included), so drop
    # the cache to make the new combination visible. While this module is
    # still registering its built-in combinations the resolver isn't defined
    # yet, hence the globals() lookup.
    resolver = globals().get("_resolve_combined_type")
    if resolver is not None:
        resolver.cache_clear()

620 

621 

def _register_builtin_combinations():
    """
    Register every (lhs, rhs, result) triple from _connector_combinations.

    Done inside a function so the loop variables don't leak into the module
    namespace.
    """
    for combinations in _connector_combinations:
        for connector, field_types in combinations.items():
            for lhs, rhs, result in field_types:
                register_combinable_fields(lhs, connector, rhs, result)


_register_builtin_combinations()

626 

627 

@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the result field class for ``lhs_type <connector> rhs_type``.

    The first registered combination whose lhs/rhs classes the given types
    are subclasses of (issubclass) wins; None is returned when nothing
    matches. Answers are memoized per (connector, lhs_type, rhs_type).
    """
    candidates = _connector_combinators.get(connector, ())
    return next(
        (
            combined_type
            for lhs_base, rhs_base, combined_type in candidates
            if issubclass(lhs_type, lhs_base) and issubclass(rhs_type, rhs_base)
        ),
        None,
    )

636 

637 

class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    Two expressions joined by a connector (one of the Combinable connector
    strings), e.g. ``F("a") + F("b")``. The SQL comes from the backend's
    ``combine_expression()`` and is wrapped in parentheses.
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        """
        Infer the output field from the registered connector combinations.

        Raises FieldError when no combination matches the operand types.
        """
        # We avoid using super() here for reasons given in
        # BaseExpression._resolve_output_field().
        lhs_field = self.lhs._output_field_or_none
        rhs_field = self.rhs._output_field_or_none
        combined_type = _resolve_combined_type(
            self.connector, type(lhs_field), type(rhs_field)
        )
        if combined_type is None:
            # Report the types resolved above. Re-reading .output_field on an
            # operand whose type resolved to None would instead raise a
            # generic "Cannot resolve expression type" FieldError and hide
            # this message.
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {type(lhs_field).__name__}, "
                f"{type(rhs_field).__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(self, compiler, connection):
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # order of precedence
        expression_wrapper = "(%s)"
        sql = connection.ops.combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve both operands; rewrite as a specialized expression if needed.

        A DurationField operand combined with an operand of a different type
        becomes a DurationExpression, and subtracting two values of the same
        date/time type becomes a TemporalSubtraction. Those subclasses skip
        this check. The replacement is built from the unresolved operands and
        resolved afresh.
        """
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(self, DurationExpression | TemporalSubtraction):
            try:
                lhs_type = lhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                lhs_type = None
            try:
                rhs_type = rhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                rhs_type = None
            # NOTE(review): an explicit output_field on self is not passed to
            # the replacement expressions below — confirm this is intended.
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c

734 

735 

class DurationExpression(CombinedExpression):
    """
    Combined expression involving a DurationField operand.

    CombinedExpression.resolve_expression() produces it when a DurationField
    is combined with an operand of a different type. On backends without a
    native duration field, DurationField operands are passed through
    ``format_for_duration_arithmetic()`` and the operands are joined with
    ``combine_duration_expression()``.
    """

    def compile(self, side, compiler, connection):
        """Compile one operand, formatting DurationField operands for arithmetic."""
        try:
            output = side.output_field
        except FieldError:
            # Unknown type: compile as-is.
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        compiled = [
            self.compile(side, compiler, connection) for side in (self.lhs, self.rhs)
        ]
        sql = connection.ops.combine_duration_expression(
            self.connector, [side_sql for side_sql, _ in compiled]
        )
        params = [param for _, side_params in compiled for param in side_params]
        # Parenthesize to keep the operator precedence intact.
        return "(%s)" % sql, params

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            return sql, params
        # Multiplication/division is only accepted between numeric and
        # duration operands.
        permitted = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if not {lhs_type, rhs_type} <= permitted:
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params

785 

786 

class TemporalSubtraction(CombinedExpression):
    """
    Subtraction of two date/time values, producing a duration.

    CombinedExpression.resolve_expression() creates it when both operands
    share the same date/time field type. The SQL is delegated to the
    backend's ``subtract_temporals()``.
    """

    # Fixed result type; this class attribute replaces the inferred one.
    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        lhs_compiled = compiler.compile(self.lhs)
        rhs_compiled = compiler.compile(self.rhs)
        lhs_internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            lhs_internal_type, lhs_compiled, rhs_compiled
        )

800 

801 

@deconstructible(path="plain.models.F")
class F(Combinable):
    """
    Reference to a field by name, resolved against the query it is used in.
    """

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name})"

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # The query does the lookup; for_save is accepted but not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Equal only to an instance of exactly the same class and name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        return copy.copy(self)

838 

839 

class ResolvedOuterRef(F):
    """
    Reference to a field of an outer query, produced by resolving an OuterRef.

    The reference exists because the inner query is being used as a
    subquery; compiling it on its own raises ValueError.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # Outer references are not affected by relabeling the inner query.
        return self

    def get_group_by_cols(self):
        return []

875 

876 

class OuterRef(F):
    """
    Reference to a field of an enclosing query, for use inside a subquery.

    Resolving it yields a ResolvedOuterRef. When the wrapped name is itself
    an OuterRef, resolution unwraps one level and returns that inner OuterRef.
    """

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        name = self.name
        return name if isinstance(name, self.__class__) else ResolvedOuterRef(name)

    def relabeled_clone(self, relabels):
        return self

887 

888 

@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    Rendered by filling ``template`` with ``function`` and the compiled
    argument SQL joined by ``arg_joiner``. Extra keyword arguments given to
    the constructor are available to the template and may override
    ``function``, ``template`` and ``arg_joiner``.
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        given = len(expressions)
        if self.arity is not None and given != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                "'{}' takes exactly {} {} ({} given)".format(
                    self.__class__.__name__, self.arity, noun, given
                )
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        class_name = self.__class__.__name__
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{class_name}({args})"
        rendered = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{class_name}({args}, {rendered})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        c = self.copy()
        c.is_summary = summarize
        # copy() gave c its own list, so it can be updated in place.
        c.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in c.source_expressions
        ]
        return c

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Fall back to the argument's declared empty-result value.
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(fallback))
            except FullResultSet:
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Precedence for each setting: the argument to this method, then a
        # value from __init__()'s **extra or extra_context (already in data),
        # then the class attribute.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        joined = arg_joiner.join(sql_parts)
        data["expressions"] = joined
        data["field"] = joined
        return template % data, params

    def copy(self):
        clone = super().copy()
        # Give the clone its own argument list and extra dict.
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone

987 

988 

@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. It is passed to the
           database as a query parameter (``None`` is rendered as ``NULL``).

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return f"{type(self).__name__}({self.value!r})"

    def as_sql(self, compiler, connection):
        """Render the value as a ``%s`` placeholder plus its parameter.

        When an output field is known, the value is first prepared with
        ``get_db_prep_save`` (if ``for_save``) or ``get_db_prep_value``. A
        field that defines ``get_placeholder`` supplies its own placeholder.
        """
        connection.ops.check_expression_support(self)
        value = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            value = prepare(value, connection=connection)
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(value, compiler, connection), [value]
        if value is None:
            # cx_Oracle does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return "NULL", []
        return "%s", [value]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Remembered so as_sql() can choose between the save/value prep paths.
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self):
        # A wrapped literal contributes no GROUP BY columns.
        return []

    def _resolve_output_field(self):
        """Infer a field from the Python type of ``self.value`` (None if unknown)."""
        # Order matters: bool subclasses int and datetime.datetime subclasses
        # datetime.date, so the more specific type must be tested first.
        for python_type, field_class in (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        ):
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self):
        # Used in place of this expression when compiling it raises
        # EmptyResultSet (see Func.as_sql).
        return self.value

1067 

1068 

class RawSQL(Expression):
    """An expression made of a raw SQL fragment and its parameters.

    ``sql`` is emitted verbatim inside parentheses and ``params`` are passed
    through unchanged. Values should therefore go in ``params`` rather than
    being formatted into ``sql``. The output field defaults to a plain
    ``fields.Field()``.
    """

    def __init__(self, sql, params, output_field=None):
        if output_field is None:
            output_field = fields.Field()
        self.sql, self.params = sql, params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(self, compiler, connection):
        return "(%s)" % self.sql, self.params

    def get_group_by_cols(self):
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve parent-model fields mentioned in the raw SQL, then resolve self.

        Any inherited (parent) model's local field whose column name appears
        (case-insensitively) in ``sql`` is resolved on ``query``. Presumably
        this sets up the join to that parent's table. Only the first match
        per parent is resolved. ``query`` may be None (its default), in which
        case nothing is resolved; previously this raised AttributeError.
        """
        model = getattr(query, "model", None)
        if model:
            # Loop-invariant: lower-case the SQL once, not once per field.
            sql_lower = self.sql.lower()
            for parent in model._meta.get_parent_list():
                for parent_field in parent._meta.local_fields:
                    _, column_name = parent_field.get_attname_column()
                    if column_name.lower() in sql_lower:
                        query.resolve_ref(
                            parent_field.name, allow_joins, reuse, summarize
                        )
                        break
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

1101 

1102 

class Star(Expression):
    """Render the bare SQL ``*`` wildcard."""

    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        # Emitted literally; there are never any parameters.
        sql, params = "*", []
        return sql, params

1109 

1110 

class Col(Expression):
    """A reference to ``target.column``, optionally qualified by a table alias.

    When no explicit ``output_field`` is given, the target field itself is
    used as the output field.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias, self.target = alias, target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "{}({})".format(self.__class__.__name__, ", ".join(parts))

    def as_sql(self, compiler, connection):
        column = self.target.column
        names = [self.alias, column] if self.alias else [column]
        quoted = [compiler.quote_name_unless_alias(name) for name in names]
        return ".".join(quoted), []

    def relabeled_clone(self, relabels):
        """Return a Col with its alias remapped through *relabels*.

        An unqualified column (alias None) is returned unchanged.
        """
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        # When output_field differs from the target, both fields' converters
        # are applied: output_field's first, then the target's.
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if not same_field:
            converters = converters + self.target.get_db_converters(connection)
        return converters

1148 

1149 

class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return f"{type(self).__name__}({self.refs}, {self.source})"

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        # Exactly one expression is expected; unpacking enforces that.
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # The sub-expression `source` has already been resolved, as this is
        # just a reference to the name of `source`.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        # A reference to a select alias is unaffected by table relabeling.
        return self

    def as_sql(self, compiler, connection):
        return connection.ops.quote_name(self.refs), []

    def get_group_by_cols(self):
        return [self]

1187 

1188 

class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.

    Raises ValueError when constructed with no expressions.
    """

    # Only the joined arguments are rendered, with no function name around them.
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if not expressions:
            raise ValueError(
                f"{type(self).__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(self, compiler, connection, **extra_context):
        # Casting to numeric is unnecessary: render plain as_sql() output
        # rather than any inherited SQLite-specific variant.
        return self.as_sql(compiler, connection, **extra_context)

1211 

1212 

class OrderByList(Func):
    """An ``ORDER BY`` clause built from ordering expressions.

    A string argument starting with ``-`` becomes
    ``OrderBy(F(name), descending=True)``. Every other argument is passed to
    ``Func`` unchanged.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        # str.startswith() instead of expr[0] so an empty string no longer
        # raises IndexError here.
        expressions = (
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions, **extra)

    def as_sql(self, *args, **kwargs):
        # With no orderings, render nothing rather than a bare "ORDER BY".
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)

    def get_group_by_cols(self):
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]

1237 

1238 

@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.

    Compiling the wrapper simply compiles the wrapped expression.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group by a copy carrying the wrapper's output_field, leaving the
        # wrapped expression itself untouched.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression})"

1270 

1271 

class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Inverting a negation yields a copy of the original expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The inner condition matches nothing, so its negation is
            # always true.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        ops = compiler.connection.ops
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the wrapper; raise TypeError if the inner expression isn't conditional."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        is_conditional = getattr(resolved.expression, "conditional", False)
        if not is_conditional:
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved

    def select_format(self, compiler, sql, params):
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list.
        conn = compiler.connection
        supported_in_where = conn.ops.conditional_expression_supported_in_where_clause
        needs_case_wrapper = (
            not conn.features.supports_boolean_expr_in_select_clause
            # Avoid double wrapping.
            and supported_in_where(self.expression)
        )
        if needs_case_wrapper:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

1321 

1322 

@deconstructible(path="plain.models.When")
class When(Expression):
    """A single ``WHEN <condition> THEN <result>`` branch of a ``Case``.

    The condition may be a Q object, a conditional expression, or keyword
    lookups, which are folded into a Q together with any given condition.
    Raises TypeError for any other condition and ValueError for an empty Q().
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        if lookups:
            if condition is None:
                condition = Q(**lookups)
                lookups = None
            elif getattr(condition, "conditional", False):
                condition = Q(condition, **lookups)
                lookups = None
        is_conditional = getattr(condition, "conditional", False)
        if lookups or condition is None or not is_conditional:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = self.copy()
        resolved.is_summary = summarize
        # The condition is never resolved "for save"; only the result is.
        if hasattr(resolved.condition, "resolve_expression"):
            resolved.condition = resolved.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        resolved.result = resolved.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        template_params = {
            **extra_context,
            "condition": condition_sql,
            "result": result_sql,
        }
        sql = (template or self.template) % template_params
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # A When can't be grouped by on its own, so the GROUP BY columns of
        # its condition and result are returned instead.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

1397 

1398 

@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        whens = ", ".join(map(str, self.cases))
        return f"CASE {whens}, ELSE {self.default!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs):
        # The last expression is the default; everything before it is a When.
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.cases = [
            case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for case in resolved.cases
        ]
        resolved.default = resolved.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def copy(self):
        clone = super().copy()
        # Give the clone its own list of branches.
        clone.cases = clone.cases[:]
        return clone

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        """Render the CASE, pruning branches that can never or must always match.

        A branch raising EmptyResultSet is dropped. The first branch raising
        FullResultSet turns its result into the ELSE value, and later branches
        are not emitted. If no branch survives, only the default is compiled.
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        when_sqls = []
        params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                when_sql, when_params = compiler.compile(case)
            except EmptyResultSet:
                continue
            except FullResultSet:
                default_sql, default_params = compiler.compile(case.result)
                break
            when_sqls.append(when_sql)
            params.extend(when_params)
        if not when_sqls:
            return default_sql, default_params
        template_params["cases"] = (case_joiner or self.case_joiner).join(when_sqls)
        template_params["default"] = default_sql
        params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, params

    def get_group_by_cols(self):
        # Without any WHEN branches the CASE compiles to just its default.
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()

1494 

1495 

class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        # Allow the usage of both QuerySet and sql.Query objects.
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        clone = super().copy()
        # Each copy owns its own query so later changes don't leak across.
        clone.query = clone.query.clone()
        return clone

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        template_params = {
            **self.extra,
            **extra_context,
            # The first and last characters of the query SQL are dropped
            # (presumably its own enclosing parentheses); the template
            # supplies the wrapping instead.
            "subquery": subquery_sql[1:-1],
        }
        template = template or template_params.get("template", self.template)
        return template % template_params, sql_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)

1546 

1547 

class Exists(Subquery):
    """An ``EXISTS(<subquery>)`` boolean expression.

    The wrapped query is replaced by ``query.exists()``. The expression falls
    back to False when its subquery is an empty result set.
    """

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params

1564 

1565 

@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """Order by ``expression`` ascending or descending, with optional NULLs placement.

    ``nulls_first`` / ``nulls_last`` must be True or None and are mutually
    exclusive. On backends without ``NULLS FIRST/LAST`` support, the
    placement is emulated with an extra ``IS [NOT] NULL`` sort key.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.expression}, "
            f"descending={self.descending})"
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        template = template or self.template
        features = connection.features
        if features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        elif self.nulls_last and (
            not self.descending or not features.order_by_nulls_first
        ):
            # Emulate NULLS LAST: "x IS NULL" sorts non-NULL rows first.
            template = f"%(expression)s IS NULL, {template}"
        elif self.nulls_first and (
            self.descending or not features.order_by_nulls_first
        ):
            # Emulate NULLS FIRST: "x IS NOT NULL" sorts NULL rows first.
            template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The template may reference the expression more than once (the
        # emulation above does), so its params are repeated to match.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip direction and NULLs placement in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_first, self.nulls_last = None, True
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True

1641 

1642 

class Window(SQLiteNumericMixin, Expression):
    """Apply ``expression`` over a window: ``<expression> OVER (<window>)``.

    ``partition_by`` takes a single value or a list/tuple and is wrapped in an
    ``ExpressionList``. ``order_by`` takes an expression, a string, or a
    list/tuple of them and is wrapped in an ``OrderByList``. ``frame`` is an
    optional frame-clause expression (e.g. ``RowRange``/``ValueRange``).

    Raises ValueError if ``expression`` is not window-compatible or
    ``order_by`` has an unsupported type.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses."
                % expression.__class__.__name__
            )

        if self.partition_by is not None:
            if not isinstance(self.partition_by, tuple | list):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        """Render ``<expr> OVER (PARTITION BY ... ORDER BY ... <frame>)``.

        Raises NotSupportedError on backends without OVER-clause support.
        Parts that are not set are omitted from the window.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(self, compiler, connection):
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression:
            # the wrapped expression is compiled with a FloatField, and
            # SQLiteNumericMixin casts the result of the whole window.
            window = self.copy()
            source_expressions = window.get_source_expressions()
            # copy() isn't overridden here, so the clone may share
            # self.source_expression. Copy the wrapped expression before
            # overriding its output_field, so that compiling doesn't mutate
            # this Window's own expression tree.
            inner = source_expressions[0].copy()
            inner.output_field = fields.FloatField()
            source_expressions[0] = inner
            window.set_source_expressions(source_expressions)
            return super(Window, window).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self):
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        # Only the partition and ordering expressions are grouped by; the
        # windowed expression itself never introduces a GROUP BY.
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols

1757 

1758 

class WindowFrame(Expression):
    """
    Model the frame clause in window expressions. There are two types of frame
    clauses which are subclasses, however, all processing and validation (by no
    means intended to be complete) is done here. Thus, providing an end for a
    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
    row in the frame).

    Subclasses provide ``frame_type`` and ``window_frame_start_end()``.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        context = {"frame_type": self.frame_type, "start": start, "end": end}
        return self.template % context, []

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        return []

    def __str__(self):
        # Unlike as_sql(), this uses the module-level `connection`, since no
        # connection is available here.
        ops = connection.ops

        start_offset = self.start.value
        if start_offset is None:
            start = ops.UNBOUNDED_PRECEDING
        elif start_offset < 0:
            start = "%d %s" % (abs(start_offset), ops.PRECEDING)
        elif start_offset == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        end_offset = self.end.value
        if end_offset is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif end_offset > 0:
            end = "%d %s" % (end_offset, ops.FOLLOWING)
        elif end_offset == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")

1823 

1824 

class RowRange(WindowFrame):
    """A ``ROWS BETWEEN ... AND ...`` frame; bounds are rendered by the backend ops."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)

1830 

1831 

class ValueRange(WindowFrame):
    """A ``RANGE BETWEEN ... AND ...`` frame; bounds are rendered by the backend ops."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)