Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/functions/datetime.py: 53%

195 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1from datetime import datetime 

2 

3from plain.models.expressions import Func 

4from plain.models.fields import ( 

5 DateField, 

6 DateTimeField, 

7 DurationField, 

8 Field, 

9 IntegerField, 

10 TimeField, 

11) 

12from plain.models.lookups import ( 

13 Transform, 

14 YearExact, 

15 YearGt, 

16 YearGte, 

17 YearLt, 

18 YearLte, 

19) 

20from plain.utils import timezone 

21 

22 

class TimezoneMixin:
    """
    Supply the timezone name used for database-side datetime conversions.

    ``tzinfo`` is an explicit timezone chosen by the caller; when it is left
    as ``None`` the currently active timezone is used instead.
    """

    tzinfo = None

    def get_tzname(self):
        """Return the name of the timezone conversions should happen in."""
        # The input datetime has to be converted into this timezone *before*
        # any function is applied: 2015-12-31 23:00:00 -02:00 is stored as
        # 2016-01-01 01:00:00 +00:00, and results must reflect the former.
        if self.tzinfo is not None:
            return timezone._get_timezone_name(self.tzinfo)
        return timezone.get_current_timezone_name()

35 

36 

class Extract(TimezoneMixin, Transform):
    """
    Extract one component of a date, datetime, time, or duration as an integer.

    Concrete subclasses fix ``lookup_name`` (e.g. "year", "hour") on the
    class; the generic ``Extract`` takes it as a constructor argument.
    ``tzinfo`` is only accepted when the input is a DateTimeField.
    """

    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        # A subclass's class-level lookup_name takes precedence over the
        # argument; the argument is only consulted when the class has none.
        if self.lookup_name is None:
            if lookup_name is None:
                raise ValueError("lookup_name must be provided")
            self.lookup_name = lookup_name
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile to the backend's extract SQL and return ``(sql, params)``.

        Raises ValueError when ``tzinfo`` is given for a non-datetime input,
        when a duration is extracted on a backend without native duration
        support, or when the input type is not supported.
        """
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        source = self.lhs.output_field
        component = self.lookup_name

        # DateTimeField subclasses DateField, so it has to be tested first. It
        # is also the only input type where a timezone conversion applies.
        if isinstance(source, DateTimeField):
            tzname = self.get_tzname()
            lhs_sql, lhs_params = connection.ops.datetime_extract_sql(
                component, lhs_sql, tuple(lhs_params), tzname
            )
            return lhs_sql, lhs_params

        if self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")

        if isinstance(source, DateField):
            extractor = connection.ops.date_extract_sql
        elif isinstance(source, TimeField):
            extractor = connection.ops.time_extract_sql
        elif isinstance(source, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            # Durations are extracted with the backend's time-extract SQL.
            extractor = connection.ops.time_extract_sql
        else:
            # resolve_expression() already validated the input type, so this
            # branch should never be reached.
            raise ValueError("Tried to Extract from an invalid type.")

        lhs_sql, lhs_params = extractor(component, lhs_sql, tuple(lhs_params))
        return lhs_sql, lhs_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression, then validate the input type against the
        requested component. Raises ValueError on an incompatible pairing.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        source = getattr(resolved.lhs, "output_field", None)
        if source is None:
            # The input type is not known, so there is nothing to validate.
            return resolved

        if not isinstance(source, DateField | DateTimeField | TimeField | DurationField):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )

        component = resolved.lookup_name

        # An exact-type check: a plain DateField (not its DateTimeField
        # subclass) carries no time of day, so asking for one is a mistake.
        if type(source) is DateField and component in ("hour", "minute", "second"):
            raise ValueError(
                f"Cannot extract time component '{component}' from DateField '{source.name}'."
            )

        # A duration has no calendar position, so calendar parts are rejected.
        calendar_parts = (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        )
        if isinstance(source, DurationField) and component in calendar_parts:
            raise ValueError(
                f"Cannot extract component '{component}' from DurationField '{source.name}'."
            )

        return resolved

117 

118 

class ExtractYear(Extract):
    """Extract the year component (``lookup_name="year"``)."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Return the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the month component (``lookup_name="month"``)."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the day component (``lookup_name="day"``)."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Return the ISO-8601 week number, 1-52 or 53. Weeks start on Monday.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Return the day of the week numbered Sunday=1 through Saturday=7.

    Python equivalent: ``(mydatetime.isoweekday() % 7) + 1``.
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Return the ISO-8601 day of the week, Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the quarter of the year (``lookup_name="quarter"``)."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the hour component (``lookup_name="hour"``)."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the minute component (``lookup_name="minute"``)."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the second component (``lookup_name="second"``)."""

    lookup_name = "second"

177 

# Calendar components become transforms on DateField. DateTimeField subclasses
# DateField, so it presumably inherits these as well.
for _transform in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_transform)

# Time-of-day components are registered on TimeField and DateTimeField only,
# never on a plain DateField.
for _field in (TimeField, DateTimeField):
    for _transform in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_transform)

# Both year extractors get the dedicated Year* comparison lookups
# (presumably optimized year-range comparisons).
for _transform in (ExtractYear, ExtractIsoYear):
    for _lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _transform.register_lookup(_lookup)

# Remove the loop variables so the module namespace matches the explicit form.
del _field, _transform, _lookup

206 

207 

class Now(Func):
    """
    Current timestamp as computed by the database server.

    The default SQL is ``CURRENT_TIMESTAMP``. Backend-specific methods
    substitute an equivalent expression so the result is comparable
    across databases.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # On PostgreSQL, CURRENT_TIMESTAMP is the start time of the current
        # transaction. STATEMENT_TIMESTAMP() matches the behavior of the
        # other databases.
        postgres_template = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=postgres_template, **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # The (6) asks MySQL for microsecond fractional-second precision.
        mysql_template = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # Each '%' is quadrupled so it survives two rounds of %-escaping,
        # presumably template rendering and then DB-API parameter
        # substitution. NOTE(review): confirm against Func.as_sql.
        sqlite_template = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler, connection, template=sqlite_template, **extra_context
        )

232 

233 

class TruncBase(TimezoneMixin, Transform):
    """
    Base class for functions that truncate a date, datetime, or time
    expression to the granularity named by ``kind`` (e.g. "year", "hour").

    Subclasses set ``kind`` as a class attribute; ``Trunc`` takes it as a
    constructor argument. ``tzinfo`` chooses the timezone a DateTimeField
    input is truncated in, and is rejected for any other input type.
    """

    kind = None
    # Also declared on TimezoneMixin; __init__ always sets an instance value.
    tzinfo = None

    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile to the backend's trunc SQL and return ``(sql, params)``.

        Raises ValueError if ``tzinfo`` is given for a non-DateTimeField input,
        or if the output field is not a date, datetime, or time field.
        """
        sql, params = compiler.compile(self.lhs)
        # Only datetime inputs are converted to a timezone before truncation.
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField subclasses DateField, so it must be tested first.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression, then validate the input and output field types
        against ``kind``.

        Raises TypeError if the input is not a date, datetime, or time field,
        and ValueError for an invalid output field or an incompatible
        input/kind combination.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake. A class-level output_field (as on TruncDate and
        # TruncTime) takes precedence for these checks.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        # Decides whether error messages name the actual output type or fall
        # back to "DateTimeField".
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Exact-type check: a plain DateField, not its DateTimeField subclass.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                "Cannot truncate DateField '{}' to {}.".format(
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                "Cannot truncate TimeField '{}' to {}.".format(
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        return copy

    def convert_value(self, value, expression, connection):
        """
        Normalize a database result to the Python type of ``output_field``.

        For datetime output, any tzinfo on the value is discarded and
        ``self.tzinfo`` is attached; presumably the current timezone is used
        when ``self.tzinfo`` is None (see ``timezone.make_aware``). A NULL
        result on a backend without a timezone database raises ValueError.
        For date and time output, a returned ``datetime`` is narrowed with
        ``.date()`` or ``.time()``.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            # Presumably some backends return a datetime even for date or time
            # truncation. The former `value is None` branch here could never
            # run, because a datetime instance is never None.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value

346 

347 

class Trunc(TruncBase):
    """
    Truncation whose granularity is chosen per instance.

    ``kind`` names the granularity (for example "month" or "hour") and is
    stored before the base initializer runs.
    """

    def __init__(
        self,
        expression,
        kind,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        # Set kind first so it is available to anything the base class runs.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )

359 

360 

class TruncYear(TruncBase):
    """Truncate to year granularity. Not allowed on a TimeField input."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to quarter granularity. Not allowed on a TimeField input."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to month granularity. Not allowed on a TimeField input."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to day granularity. Not allowed on a TimeField input."""

    kind = "day"

382 

class TruncDate(TruncBase):
    """
    The date part of a datetime, taken in the active (or given) timezone.

    Exposed as the ``date`` lookup, with a fixed DateField output.
    """

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        # Use a cast rather than a truncation: the result is a date, not a
        # datetime at midnight.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tz_name = self.get_tzname()
        cast_date = connection.ops.datetime_cast_date_sql
        return cast_date(lhs_sql, tuple(lhs_params), tz_name)

393 

394 

class TruncTime(TruncBase):
    """
    The time-of-day part of a datetime, taken in the active (or given)
    timezone.

    Exposed as the ``time`` lookup, with a fixed TimeField output.
    """

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        # Use a cast rather than a truncation: only the time of day is kept.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tz_name = self.get_tzname()
        cast_time = connection.ops.datetime_cast_time_sql
        return cast_time(lhs_sql, tuple(lhs_params), tz_name)

405 

406 

class TruncHour(TruncBase):
    """Truncate to hour granularity. Not allowed on a plain DateField input."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to minute granularity. Not allowed on a plain DateField input."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to second granularity. Not allowed on a plain DateField input."""

    kind = "second"

417 

418 

# Expose the date and time casts as ``__date`` / ``__time`` on DateTimeField.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
# Remove the loop variable so the module namespace matches the explicit form.
del _cast_transform