Coverage for /Users/davegaeddert/Development/dropseed/plain/plain/plain/validators.py: 47%

294 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:03 -0500

1import ipaddress 

2import math 

3import re 

4from pathlib import Path 

5from urllib.parse import urlsplit, urlunsplit 

6 

7from plain.exceptions import ValidationError 

8from plain.utils.deconstruct import deconstructible 

9from plain.utils.encoding import punycode 

10from plain.utils.ipv6 import is_valid_ipv6_address 

11from plain.utils.regex_helper import _lazy_re_compile 

12from plain.utils.text import pluralize_lazy 

13 

# Values considered "empty": when one of these is given to validate(), the
# self.required check is triggered.
EMPTY_VALUES = (
    None,
    "",
    [],
    (),
    {},
)

16 

17 

@deconstructible
class RegexValidator:
    """
    Validate a value by searching its string form with a regular expression.

    The class attributes below are defaults; every constructor argument that
    is not None replaces the attribute of the same name on the instance.
    """

    regex = ""
    message = "Enter a valid value."
    code = "invalid"
    inverse_match = False
    flags = 0

    def __init__(
        self, regex=None, message=None, code=None, inverse_match=None, flags=None
    ):
        supplied = {
            "regex": regex,
            "message": message,
            "code": code,
            "inverse_match": inverse_match,
            "flags": flags,
        }
        for attr_name, attr_value in supplied.items():
            if attr_value is not None:
                setattr(self, attr_name, attr_value)

        # Flags are only accepted together with a pattern given as a string.
        if self.flags and not isinstance(self.regex, str):
            raise TypeError(
                "If the flags are set, regex must be a regular expression string."
            )

        self.regex = _lazy_re_compile(self.regex, self.flags)

    def __call__(self, value):
        """
        Raise ValidationError unless str(value) contains a match for the
        regular expression. When inverse_match is true the check is flipped:
        a match is what makes the value invalid.
        """
        found = self.regex.search(str(value))
        if self.inverse_match:
            is_invalid = found
        else:
            is_invalid = not found
        if is_invalid:
            raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other):
        """Equal when pattern, flags, message, code and inverse_match all match."""
        if not isinstance(other, RegexValidator):
            return False
        mine, theirs = self.regex, other.regex
        return (
            mine.pattern == theirs.pattern
            and mine.flags == theirs.flags
            and (self.message == other.message)
            and (self.code == other.code)
            and (self.inverse_match == other.inverse_match)
        )

65 

66 

@deconstructible
class URLValidator(RegexValidator):
    """
    Validate that a string is a URL with an allowed scheme and a well-formed
    host (IPv4 literal, bracketed IPv6 literal, domain name or "localhost").

    Values that are not strings, are longer than ``max_length`` characters,
    or contain a tab, carriage return or line feed are rejected before any
    pattern matching is attempted.

    Pass ``schemes`` to the constructor to replace the default list of
    accepted schemes; remaining keyword arguments go to RegexValidator.
    """

    ul = "\u00a1-\uffff"  # Unicode letters range (must not be a raw string).

    # IP patterns
    ipv4_re = (
        r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)"
        r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}"
    )
    ipv6_re = r"\[[0-9a-f:.]+\]"  # (simple regex, validated later)

    # Host patterns
    hostname_re = (
        r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?"
    )
    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
    domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*"
    tld_re = (
        r"\."  # dot
        r"(?!-)"  # can't start with a dash
        r"(?:[a-z" + ul + "-]{2,63}"  # domain label
        r"|xn--[a-z0-9]{1,59})"  # or punycode label
        r"(?<!-)"  # can't end with a dash
        r"\.?"  # may have a trailing dot
    )
    host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)"

    regex = _lazy_re_compile(
        r"^(?:[a-z0-9.+-]*)://"  # scheme is validated separately
        r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?"  # user:pass authentication
        r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")"
        r"(?::[0-9]{1,5})?"  # port
        r"(?:[/?#][^\s]*)?"  # resource path
        r"\Z",
        re.IGNORECASE,
    )
    message = "Enter a valid URL."
    schemes = ["http", "https", "ftp", "ftps"]
    unsafe_chars = frozenset("\t\r\n")
    # Upper bound on the accepted URL length. Over-long input is rejected up
    # front so the host patterns above never run on arbitrarily large strings.
    max_length = 2048

    def __init__(self, schemes=None, **kwargs):
        super().__init__(**kwargs)
        if schemes is not None:
            self.schemes = schemes

    def __call__(self, value):
        """
        Raise ValidationError if ``value`` is not an acceptable URL.

        Checks run in this order: type and length, unsafe characters,
        scheme, urlsplit() parsing, the class regex (retried with a
        punycode-encoded netloc if the first attempt fails), IPv6 literal
        validation, and finally the host name length.
        """
        if not isinstance(value, str) or len(value) > self.max_length:
            raise ValidationError(self.message, code=self.code, params={"value": value})
        if self.unsafe_chars.intersection(value):
            raise ValidationError(self.message, code=self.code, params={"value": value})
        # Check if the scheme is valid.
        scheme = value.split("://")[0].lower()
        if scheme not in self.schemes:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Then check full URL
        try:
            splitted_url = urlsplit(value)
        except ValueError:
            raise ValidationError(self.message, code=self.code, params={"value": value})
        try:
            super().__call__(value)
        except ValidationError as e:
            # Trivial case failed. Try for possible IDN domain
            if value:
                scheme, netloc, path, query, fragment = splitted_url
                try:
                    netloc = punycode(netloc)  # IDN -> ACE
                except UnicodeError:  # invalid domain part
                    raise e
                url = urlunsplit((scheme, netloc, path, query, fragment))
                super().__call__(url)
            else:
                raise
        else:
            # Now verify IPv6 in the netloc part
            host_match = re.search(r"^\[(.+)\](?::[0-9]{1,5})?$", splitted_url.netloc)
            if host_match:
                potential_ip = host_match[1]
                try:
                    validate_ipv6_address(potential_ip)
                except ValidationError:
                    raise ValidationError(
                        self.message, code=self.code, params={"value": value}
                    )

        # The maximum length of a full host name is 253 characters per RFC 1034
        # section 3.1. It's defined to be 255 bytes or less, but this includes
        # one byte for the length of the name and one byte for the trailing dot
        # that's used to indicate absolute names in DNS.
        if splitted_url.hostname is None or len(splitted_url.hostname) > 253:
            raise ValidationError(self.message, code=self.code, params={"value": value})

159 

160 

# Shared validator instance: an optional leading minus sign followed by one
# or more digits, and nothing else.
integer_validator = RegexValidator(
    _lazy_re_compile(r"^-?\d+\Z"), message="Enter a valid integer.", code="invalid"
)

166 

167 

def validate_integer(value):
    """Run ``value`` through the module-level ``integer_validator``."""
    outcome = integer_validator(value)
    return outcome

170 

171 

@deconstructible
class EmailValidator:
    """
    Validate that a value looks like an email address.

    The part before the last "@" must match ``user_regex``. The part after
    it must be in ``domain_allowlist``, match ``domain_regex``, or be a
    bracketed IPv4/IPv6 literal. If the domain fails as given, it is
    punycode-encoded and checked once more.
    """

    message = "Enter a valid email address."
    code = "invalid"
    user_regex = _lazy_re_compile(
        # dot-atom
        r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
        # quoted-string
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])'
        r'*"\Z)',
        re.IGNORECASE,
    )
    domain_regex = _lazy_re_compile(
        # max length for domain name labels is 63 characters per RFC 1034
        r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z",
        re.IGNORECASE,
    )
    literal_regex = _lazy_re_compile(
        # literal form, ipv4 or ipv6 address (SMTP 4.1.3)
        r"\[([A-F0-9:.]+)\]\Z",
        re.IGNORECASE,
    )
    domain_allowlist = ["localhost"]
    # The maximum length of an email is 320 characters per RFC 3696 section 3.
    # Longer input is rejected before any of the patterns above are applied.
    max_length = 320

    def __init__(self, message=None, code=None, allowlist=None):
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code
        if allowlist is not None:
            self.domain_allowlist = allowlist

    def __call__(self, value):
        """Raise ValidationError if ``value`` is not an acceptable address."""
        if not value or "@" not in value or len(value) > self.max_length:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Split on the last "@" so the user part keeps any earlier ones.
        user_part, domain_part = value.rsplit("@", 1)

        if not self.user_regex.match(user_part):
            raise ValidationError(self.message, code=self.code, params={"value": value})

        if domain_part not in self.domain_allowlist and not self.validate_domain_part(
            domain_part
        ):
            # Try for possible IDN domain-part
            try:
                domain_part = punycode(domain_part)
            except UnicodeError:
                pass
            else:
                if self.validate_domain_part(domain_part):
                    return
            raise ValidationError(self.message, code=self.code, params={"value": value})

    def validate_domain_part(self, domain_part):
        """
        Return True if ``domain_part`` matches ``domain_regex`` or is a
        bracketed literal holding a valid IPv4/IPv6 address, else False.
        """
        if self.domain_regex.match(domain_part):
            return True

        literal_match = self.literal_regex.match(domain_part)
        if literal_match:
            ip_address = literal_match[1]
            try:
                validate_ipv46_address(ip_address)
                return True
            except ValidationError:
                pass
        return False

    def __eq__(self, other):
        """Equal when the allowlist, message and code all match."""
        return (
            isinstance(other, EmailValidator)
            and (self.domain_allowlist == other.domain_allowlist)
            and (self.message == other.message)
            and (self.code == other.code)
        )

247 

248 

# Ready-to-use EmailValidator instance built with the class defaults.
validate_email = EmailValidator()

250 

# ASCII slug: one or more of a-z, A-Z, 0-9, underscore or hyphen.
slug_re = _lazy_re_compile(r"^[-a-zA-Z0-9_]+\Z")

# Arguments are positional: regex, message, code.
# "letters" in the message means latin letters: a-z and A-Z.
validate_slug = RegexValidator(
    slug_re,
    "Enter a valid “slug” consisting of letters, numbers, underscores or hyphens.",
    "invalid",
)

258 

# Unicode slug: one or more word characters (\w) or hyphens.
slug_unicode_re = _lazy_re_compile(r"^[-\w]+\Z")

# Arguments are positional: regex, message, code. The comma after the message
# is required; without it the two adjacent string literals are concatenated
# into a single message ending in "hyphens.invalid" and no code is passed.
validate_unicode_slug = RegexValidator(
    slug_unicode_re,
    "Enter a valid “slug” consisting of Unicode letters, numbers, underscores, or hyphens.",
    "invalid",
)

265 

266 

def validate_ipv4_address(value):
    """Raise ValidationError unless ipaddress.IPv4Address accepts ``value``."""
    try:
        ipaddress.IPv4Address(value)
    except ValueError:
        # Translate the stdlib error into the validation error callers expect.
        error_params = {"value": value}
        raise ValidationError(
            "Enter a valid IPv4 address.", code="invalid", params=error_params
        )

274 

275 

def validate_ipv6_address(value):
    """Raise ValidationError unless is_valid_ipv6_address() accepts ``value``."""
    if is_valid_ipv6_address(value):
        return
    raise ValidationError(
        "Enter a valid IPv6 address.", code="invalid", params={"value": value}
    )

281 

282 

def validate_ipv46_address(value):
    """
    Accept ``value`` if it is a valid IPv4 address or, failing that, a valid
    IPv6 address; otherwise raise ValidationError.
    """
    try:
        validate_ipv4_address(value)
    except ValidationError:
        # Not IPv4; fall back to the IPv6 check before giving up.
        try:
            validate_ipv6_address(value)
        except ValidationError:
            error_params = {"value": value}
            raise ValidationError(
                "Enter a valid IPv4 or IPv6 address.",
                code="invalid",
                params=error_params,
            )

295 

296 

# Protocol name -> (list of validator callables, error message).
ip_address_validator_map = {
    "both": (
        [validate_ipv46_address],
        "Enter a valid IPv4 or IPv6 address.",
    ),
    "ipv4": (
        [validate_ipv4_address],
        "Enter a valid IPv4 address.",
    ),
    "ipv6": (
        [validate_ipv6_address],
        "Enter a valid IPv6 address.",
    ),
}

302 

303 

def ip_address_validators(protocol, unpack_ipv4):
    """
    Look up the ``ip_address_validator_map`` entry for ``protocol``
    (the lookup is done on the lower-cased name).

    Raise ValueError when ``unpack_ipv4`` is requested for a protocol other
    than "both", or when the protocol name is not in the map.
    """
    if protocol != "both" and unpack_ipv4:
        raise ValueError(
            "You can only use `unpack_ipv4` if `protocol` is set to 'both'"
        )
    lookup_key = protocol.lower()
    try:
        return ip_address_validator_map[lookup_key]
    except KeyError:
        supported = list(ip_address_validator_map)
        raise ValueError(
            "The protocol '{}' is unknown. Supported: {}".format(protocol, supported)
        )

321 

322 

def int_list_validator(sep=",", message=None, code="invalid", allow_negative=False):
    """
    Build a RegexValidator that accepts one or more runs of digits joined by
    ``sep``. With ``allow_negative`` each run may be preceded by a minus sign.
    """
    sign_pattern = "(-)?" if allow_negative else ""
    # The separator is escaped so it is always matched literally.
    list_pattern = r"^{neg}\d+(?:{sep}{neg}\d+)*\Z".format(
        neg=sign_pattern,
        sep=re.escape(sep),
    )
    regexp = _lazy_re_compile(list_pattern)
    return RegexValidator(regexp, message=message, code=code)

331 

332 

# int_list_validator with its default separator (",") and no negative numbers.
validate_comma_separated_integer_list = int_list_validator(
    message="Enter only digits separated by commas."
)

336 

337 

@deconstructible
class BaseValidator:
    """
    Base for validators that compare a value against a limit.

    Subclasses override ``compare`` (return true when the value is INVALID)
    and optionally ``clean`` (transform the value before comparing).
    ``limit_value`` may be a callable, in which case it is called on every
    validation to obtain the actual limit.
    """

    message = "Ensure this value is %(limit_value)s (it is %(show_value)s)."
    code = "limit_value"

    def __init__(self, limit_value, message=None):
        self.limit_value = limit_value
        # Any falsy message (None, "") keeps the class-level default.
        if message:
            self.message = message

    def __call__(self, value):
        """Raise ValidationError when ``compare`` flags the cleaned value."""
        cleaned = self.clean(value)
        if callable(self.limit_value):
            limit = self.limit_value()
        else:
            limit = self.limit_value
        if self.compare(cleaned, limit):
            raise ValidationError(
                self.message,
                code=self.code,
                params={"limit_value": limit, "show_value": cleaned, "value": value},
            )

    def __eq__(self, other):
        """Equal to same-class validators with the same limit, message and code."""
        if isinstance(other, self.__class__):
            return (
                self.limit_value == other.limit_value
                and self.message == other.message
                and self.code == other.code
            )
        return NotImplemented

    def compare(self, a, b):
        """Return true when ``a`` (cleaned value) fails against ``b`` (limit)."""
        return a is not b

    def clean(self, x):
        """Return the value to compare; the base implementation returns it as is."""
        return x

371 

372 

@deconstructible
class MaxValueValidator(BaseValidator):
    """Reject values that are greater than the limit."""

    message = "Ensure this value is less than or equal to %(limit_value)s."
    code = "max_value"

    def compare(self, a, b):
        """Invalid when the value ``a`` exceeds the limit ``b``."""
        exceeds_limit = a > b
        return exceeds_limit

380 

381 

@deconstructible
class MinValueValidator(BaseValidator):
    """Reject values that are less than the limit."""

    message = "Ensure this value is greater than or equal to %(limit_value)s."
    code = "min_value"

    def compare(self, a, b):
        """Invalid when the value ``a`` is below the limit ``b``."""
        below_limit = a < b
        return below_limit

389 

390 

@deconstructible
class StepValueValidator(BaseValidator):
    """Reject values that are not a multiple of the step size (the limit)."""

    message = "Ensure this value is a multiple of step size %(limit_value)s."
    code = "step_size"

    def compare(self, a, b):
        """
        Invalid when the IEEE remainder of ``a`` divided by ``b`` is not
        within 1e-9 of zero.
        """
        leftover = math.remainder(a, b)
        on_step = math.isclose(leftover, 0, abs_tol=1e-9)
        return not on_step

398 

399 

@deconstructible
class MinLengthValidator(BaseValidator):
    """Reject values whose len() is smaller than the limit."""

    # Singular and plural forms, selected by the "limit_value" parameter.
    message = pluralize_lazy(
        "Ensure this value has at least %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at least %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "min_length"

    def compare(self, a, b):
        """Invalid when the length ``a`` is below the limit ``b``."""
        too_short = a < b
        return too_short

    def clean(self, x):
        """Compare lengths rather than the values themselves."""
        length = len(x)
        return length

416 

417 

@deconstructible
class MaxLengthValidator(BaseValidator):
    """Reject values whose len() is larger than the limit."""

    # Singular and plural forms, selected by the "limit_value" parameter.
    message = pluralize_lazy(
        "Ensure this value has at most %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at most %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "max_length"

    def compare(self, a, b):
        """Invalid when the length ``a`` exceeds the limit ``b``."""
        too_long = a > b
        return too_long

    def clean(self, x):
        """Compare lengths rather than the values themselves."""
        length = len(x)
        return length

434 

435 

@deconstructible
class DecimalValidator:
    """
    Check a decimal value against limits on its total digits, its decimal
    places and its digits before the decimal point; raise ValidationError
    when a limit is exceeded. A limit set to None is not enforced.
    """

    messages = {
        "invalid": "Enter a number.",
        "max_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit in total.",
            "Ensure that there are no more than %(max)s digits in total.",
            "max",
        ),
        "max_decimal_places": pluralize_lazy(
            "Ensure that there are no more than %(max)s decimal place.",
            "Ensure that there are no more than %(max)s decimal places.",
            "max",
        ),
        "max_whole_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit before the decimal "
            "point.",
            "Ensure that there are no more than %(max)s digits before the decimal "
            "point.",
            "max",
        ),
    }

    def __init__(self, max_digits, decimal_places):
        self.max_digits = max_digits
        self.decimal_places = decimal_places

    def __call__(self, value):
        """Validate ``value`` using the digits and exponent from as_tuple()."""
        mantissa, exponent = value.as_tuple()[1:]
        # Non-numeric exponents ("F", "n", "N") mark infinities and NaNs.
        if exponent in {"F", "n", "N"}:
            raise ValidationError(
                self.messages["invalid"], code="invalid", params={"value": value}
            )

        mantissa_len = len(mantissa)
        if exponent >= 0:
            decimals = 0
            # A positive exponent adds that many trailing zeros, except for
            # a plain zero, which stays a single digit.
            digits = mantissa_len if mantissa == (0,) else mantissa_len + exponent
        else:
            decimals = abs(exponent)
            # When there are more decimal places than stored digits, the
            # missing ones are leading zeros after the decimal point, so the
            # digit count equals the number of decimal places.
            digits = decimals if decimals > mantissa_len else mantissa_len
        whole_digits = digits - decimals

        max_digits = self.max_digits
        decimal_places = self.decimal_places

        if max_digits is not None and digits > max_digits:
            raise ValidationError(
                self.messages["max_digits"],
                code="max_digits",
                params={"max": self.max_digits, "value": value},
            )
        if decimal_places is not None and decimals > decimal_places:
            raise ValidationError(
                self.messages["max_decimal_places"],
                code="max_decimal_places",
                params={"max": self.decimal_places, "value": value},
            )
        if max_digits is not None and decimal_places is not None:
            whole_limit = max_digits - decimal_places
            if whole_digits > whole_limit:
                raise ValidationError(
                    self.messages["max_whole_digits"],
                    code="max_whole_digits",
                    params={"max": whole_limit, "value": value},
                )

    def __eq__(self, other):
        """Equal to same-class validators with the same two limits."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.max_digits == other.max_digits
            and self.decimal_places == other.decimal_places
        )

522 

523 

@deconstructible
class FileExtensionValidator:
    """
    Check that the extension of ``value.name`` is one of the allowed
    extensions. Comparison is case-insensitive; with ``allowed_extensions``
    left as None every extension is accepted.
    """

    message = "File extension “%(extension)s” is not allowed. Allowed extensions are: %(allowed_extensions)s."
    code = "invalid_extension"

    def __init__(self, allowed_extensions=None, message=None, code=None):
        # Store a lower-cased copy so lookups in __call__ are case-insensitive.
        self.allowed_extensions = (
            None
            if allowed_extensions is None
            else [entry.lower() for entry in allowed_extensions]
        )
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value):
        """Raise ValidationError if the file's extension is not allowed."""
        # Suffix without its leading dot, lower-cased ("" when there is none).
        extension = Path(value.name).suffix[1:].lower()
        permitted = self.allowed_extensions
        if permitted is None or extension in permitted:
            return
        raise ValidationError(
            self.message,
            code=self.code,
            params={
                "extension": extension,
                "allowed_extensions": ", ".join(self.allowed_extensions),
                "value": value,
            },
        )

    def __eq__(self, other):
        """Equal to same-class validators with the same extensions, message and code."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.allowed_extensions == other.allowed_extensions
            and self.message == other.message
            and self.code == other.code
        )

563 

564 

def get_available_image_extensions():
    """
    Return the lower-cased keys of PIL's ``Image.EXTENSION`` with their first
    character (the leading dot) removed, or an empty list when PIL cannot be
    imported.
    """
    try:
        from PIL import Image
    except ImportError:
        return []

    Image.init()
    return [suffix.lower()[1:] for suffix in Image.EXTENSION]

573 

574 

def validate_image_file_extension(value):
    """
    Validate ``value`` with a FileExtensionValidator limited to the
    extensions reported by get_available_image_extensions().
    """
    image_extensions = get_available_image_extensions()
    validator = FileExtensionValidator(allowed_extensions=image_extensions)
    return validator(value)

579 

580 

@deconstructible
class ProhibitNullCharactersValidator:
    """Validate that the string doesn't contain the null character."""

    message = "Null characters are not allowed."
    code = "null_characters_not_allowed"

    def __init__(self, message=None, code=None):
        # Only arguments that are not None override the class defaults.
        for attr_name, supplied in (("message", message), ("code", code)):
            if supplied is not None:
                setattr(self, attr_name, supplied)

    def __call__(self, value):
        """Raise ValidationError if str(value) contains a "\\x00" character."""
        text = str(value)
        if "\x00" in text:
            raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other):
        """Equal to same-class validators with the same message and code."""
        if not isinstance(other, self.__class__):
            return False
        return self.message == other.message and self.code == other.code

603 )