Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/constraints.py: 23%

233 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:04 -0500

1import warnings 

2from enum import Enum 

3from types import NoneType 

4 

5from plain.exceptions import FieldError, ValidationError 

6from plain.models.db import DEFAULT_DB_ALIAS, connections 

7from plain.models.expressions import Exists, ExpressionList, F, OrderBy 

8from plain.models.indexes import IndexExpression 

9from plain.models.lookups import Exact 

10from plain.models.query_utils import Q 

11from plain.models.sql.query import Query 

12from plain.utils.deprecation import RemovedInDjango60Warning 

13 

# Names exported by ``from <this module> import *``.
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]

15 

16 

class BaseConstraint:
    """Common base for model constraints.

    Stores the constraint name plus an optional violation error code and
    message, and declares the SQL/validation hooks that subclasses must
    implement.
    """

    # Message template; "%(name)s" is filled in with the constraint name by
    # get_violation_error_message().
    default_violation_error_message = "Constraint “%(name)s” is violated."
    violation_error_code = None
    violation_error_message = None

    # RemovedInDjango60Warning: When the deprecation ends, replace with:
    # def __init__(
    #     self, *, name, violation_error_code=None, violation_error_message=None
    # ):
    def __init__(
        self, *args, name=None, violation_error_code=None, violation_error_message=None
    ):
        """Initialise the constraint.

        ``name`` is required unless positional arguments are supplied.
        Positional arguments are deprecated; when present they are mapped, in
        order, onto ``name`` and ``violation_error_message``.

        Raises:
            TypeError: if neither ``name`` nor any positional argument is given.
        """
        # RemovedInDjango60Warning.
        if name is None and not args:
            raise TypeError(
                f"{self.__class__.__name__}.__init__() missing 1 required keyword-only "
                f"argument: 'name'"
            )
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        if violation_error_message is not None:
            self.violation_error_message = violation_error_message
        else:
            self.violation_error_message = self.default_violation_error_message
        # RemovedInDjango60Warning.
        if args:
            warnings.warn(
                f"Passing positional arguments to {self.__class__.__name__} is "
                f"deprecated.",
                RemovedInDjango60Warning,
                stacklevel=2,
            )
            # Only truthy positional values override what was set above.
            for arg, attr in zip(args, ["name", "violation_error_message"]):
                if arg:
                    setattr(self, attr, arg)

    @property
    def contains_expressions(self):
        """Return False; subclasses override this when they hold expressions."""
        return False

    def constraint_sql(self, model, schema_editor):
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        """Return the violation message with ``%(name)s`` interpolated."""
        return self.violation_error_message % {"name": self.name}

    def deconstruct(self):
        """Return ``(path, args, kwargs)`` describing this constraint.

        ``path`` is the dotted class path with "plain.models.constraints"
        shortened to "plain.models". ``args`` is always empty here. ``kwargs``
        always has ``name``; the error message is included only when it
        differs from the default, and the error code only when it is set.
        """
        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
        path = path.replace("plain.models.constraints", "plain.models")
        kwargs = {"name": self.name}
        if (
            self.violation_error_message is not None
            and self.violation_error_message != self.default_violation_error_message
        ):
            kwargs["violation_error_message"] = self.violation_error_message
        if self.violation_error_code is not None:
            kwargs["violation_error_code"] = self.violation_error_code
        return (path, (), kwargs)

    def clone(self):
        """Return a new instance built from this constraint's deconstruct() output."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)

89 

90 

class CheckConstraint(BaseConstraint):
    """Constraint backed by a conditional expression (``check``)."""

    def __init__(
        self, *, check, name, violation_error_code=None, violation_error_message=None
    ):
        """Store ``check`` and delegate the remaining arguments to the base class.

        Raises:
            TypeError: if ``check`` has no truthy ``conditional`` attribute.
        """
        self.check = check
        if not getattr(check, "conditional", False):
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(self, model, schema_editor):
        """Compile ``self.check`` to SQL with its parameters quoted inline."""
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.check)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        # Parameters are interpolated into the SQL text after being quoted by
        # the schema editor.
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def constraint_sql(self, model, schema_editor):
        """Return the schema editor's check SQL for this constraint."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.name, check)

    def create_sql(self, model, schema_editor):
        """Return the schema editor's statement that creates this check."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.name, check)

    def remove_sql(self, model, schema_editor):
        """Return the schema editor's statement that deletes this check."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Evaluate the check against the instance's field values.

        Raises:
            ValidationError: if the check evaluates to a falsy result.
        """
        against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
        try:
            if not Q(self.check).check(against, using=using):
                raise ValidationError(
                    self.get_violation_error_message(), code=self.violation_error_code
                )
        except FieldError:
            # Deliberately ignored. NOTE(review): presumably raised when the
            # check references a field missing from ``against`` — confirm.
            pass

    def __repr__(self):
        # f-strings with !r are used instead of "%r" % value so that a tuple
        # value is rendered as a single object rather than unpacked as
        # multiple format arguments.
        code = (
            ""
            if self.violation_error_code is None
            else f" violation_error_code={self.violation_error_code!r}"
        )
        message = (
            ""
            if self.violation_error_message is None
            or self.violation_error_message == self.default_violation_error_message
            else f" violation_error_message={self.violation_error_message!r}"
        )
        return (
            f"<{self.__class__.__qualname__}: check={self.check} "
            f"name={self.name!r}{code}{message}>"
        )

    def __eq__(self, other):
        """Compare name, check, error code and error message with another CheckConstraint."""
        if isinstance(other, CheckConstraint):
            return (
                self.name == other.name
                and self.check == other.check
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Extend the base deconstruction with the ``check`` keyword argument."""
        path, args, kwargs = super().deconstruct()
        kwargs["check"] = self.check
        return path, args, kwargs

166 

167 

class Deferrable(Enum):
    """Deferral modes accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self):
        """Return ``ClassName.MEMBER`` instead of the default Enum repr."""
        return f"{self.__class__.__qualname__}.{self._name_}"

175 

176 

class UniqueConstraint(BaseConstraint):
    """Uniqueness constraint over either named fields or expressions."""

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        violation_error_code=None,
        violation_error_message=None,
    ):
        """Validate the argument combination and store the configuration.

        String positional expressions are wrapped in ``F()``. ``fields`` and
        ``include`` are stored as tuples; ``opclasses`` is stored as given.

        Raises:
            ValueError: if ``name`` is missing, if neither or both of
                ``fields``/``expressions`` are given, if an argument has the
                wrong type, if ``deferrable`` is combined with ``condition``,
                ``include``, ``opclasses`` or expressions, if ``opclasses`` is
                combined with expressions, or if ``opclasses`` and ``fields``
                differ in length.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self):
        """Return True when the constraint was defined with expressions."""
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        """Compile ``self.condition`` to SQL with quoted parameters, or None."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        """Wrap each expression in IndexExpression and resolve them, or return None."""
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(self, model, schema_editor):
        """Return the schema editor's unique SQL for this constraint."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def create_sql(self, model, schema_editor):
        """Return the schema editor's statement that creates this constraint."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def remove_sql(self, model, schema_editor):
        """Return the schema editor's statement that deletes this constraint."""
        condition = self._get_condition_sql(model, schema_editor)
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def __repr__(self):
        # f-strings with !r are used instead of "%r" % value so that a tuple
        # value is rendered as a single object rather than unpacked as
        # multiple format arguments. Unset/empty options are omitted.
        parts = [
            "" if not self.fields else f" fields={self.fields!r}",
            "" if not self.expressions else f" expressions={self.expressions!r}",
            f" name={self.name!r}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={self.include!r}",
            "" if not self.opclasses else f" opclasses={self.opclasses!r}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        ]
        return f"<{self.__class__.__qualname__}:{''.join(parts)}>"

    def __eq__(self, other):
        """Compare every configuration attribute with another UniqueConstraint."""
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Return ``(path, expressions, kwargs)``; only truthy options are included."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Check ``instance`` against existing rows for a uniqueness conflict.

        Returns without checking when a constrained field (or an ``F``
        referenced by an expression) is in ``exclude``, or when a constrained
        field value is None (or an empty string on backends that treat empty
        strings as NULL).

        Raises:
            ValidationError: if a conflicting row exists.
        """
        queryset = model._default_manager.using(using)
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None or (
                    lookup_value == ""
                    and connections[using].features.interprets_empty_strings_as_nulls
                ):
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Map each F(field) reference to the instance's current value so
            # the expressions can be compared against the stored rows.
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_pk = instance._get_pk_val(model._meta)
        # An already-saved instance must not conflict with its own row.
        if not instance._state.adding and model_class_pk is not None:
            queryset = queryset.exclude(pk=model_class_pk)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                # The loop variable is named distinctly so the ``model``
                # parameter is not rebound.
                for constraint_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(
                                    constraint_model, self.fields
                                ),
                            )
        else:
            against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against, using=using
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                # Deliberately ignored. NOTE(review): presumably raised when
                # the condition references a field missing from ``against`` —
                # confirm.
                pass