Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/functions/datetime.py: 53%

195 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:04 -0500

1from datetime import datetime 

2 

3from plain.models.expressions import Func 

4from plain.models.fields import ( 

5 DateField, 

6 DateTimeField, 

7 DurationField, 

8 Field, 

9 IntegerField, 

10 TimeField, 

11) 

12from plain.models.lookups import ( 

13 Transform, 

14 YearExact, 

15 YearGt, 

16 YearGte, 

17 YearLt, 

18 YearLte, 

19) 

20from plain.utils import timezone 

21 

22 

class TimezoneMixin:
    """Mixin that resolves the timezone name a datetime function should use."""

    # Explicit timezone for the expression; ``None`` selects the current
    # timezone (see ``get_tzname``).
    tzinfo = None

    def get_tzname(self):
        """Return the name of the timezone to apply to the input datetime.

        Uses ``self.tzinfo`` when one was given, otherwise the current
        timezone as reported by ``plain.utils.timezone``.
        """
        # Timezone conversions must happen to the input datetime *before*
        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
        # database as 2016-01-01 01:00:00 +00:00. Any results should be
        # based on the input datetime not the stored datetime.
        if self.tzinfo is not None:
            return timezone._get_timezone_name(self.tzinfo)
        return timezone.get_current_timezone_name()

35 

36 

class Extract(TimezoneMixin, Transform):
    """Transform that extracts one named component from a date/time value.

    The component is identified by ``lookup_name``. Subclasses pin it as a
    class attribute; otherwise it must be passed to the constructor. The
    result is exposed through an ``IntegerField`` output field.
    """

    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        """Store the component name and optional timezone.

        Raises:
            ValueError: if neither the class nor the caller supplies a
                ``lookup_name``.
        """
        # A lookup_name already set on the class wins over the argument.
        if self.lookup_name is None:
            if lookup_name is None:
                raise ValueError("lookup_name must be provided")
            self.lookup_name = lookup_name
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        """Compile the extraction, dispatching on the source field type.

        Raises:
            ValueError: if ``tzinfo`` is set for a non-DateTimeField source,
                if a DurationField source is used on a backend without
                native duration support, or if the source type is
                unsupported.
        """
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        source_field = self.lhs.output_field

        # DateTimeField must be tested first: it is also a DateField.
        if isinstance(source_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_extract_sql(
                self.lookup_name, lhs_sql, tuple(lhs_params), tzname
            )
            return sql, params

        if self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")

        if isinstance(source_field, DateField):
            sql, params = connection.ops.date_extract_sql(
                self.lookup_name, lhs_sql, tuple(lhs_params)
            )
            return sql, params

        if isinstance(source_field, TimeField):
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, lhs_sql, tuple(lhs_params)
            )
            return sql, params

        if isinstance(source_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, lhs_sql, tuple(lhs_params)
            )
            return sql, params

        # resolve_expression has already validated the output_field, so this
        # point should never be reached.
        raise ValueError("Tried to Extract from an invalid type.")

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the expression and validate the source field/component pair.

        Raises:
            ValueError: if the source is not a date, datetime, time or
                duration field, or if the requested component cannot exist
                on that kind of field.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        source_field = getattr(resolved.lhs, "output_field", None)
        if source_field is None:
            return resolved

        supported_types = (DateField, DateTimeField, TimeField, DurationField)
        if not isinstance(source_field, supported_types):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )

        time_components = ("hour", "minute", "second")
        calendar_components = (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        )

        # Passing dates to functions expecting datetimes is most likely a mistake.
        # Exact type match on purpose: DateTimeField subclasses DateField and
        # must not be rejected here.
        if type(source_field) is DateField and resolved.lookup_name in time_components:
            raise ValueError(
                "Cannot extract time component '{}' from DateField '{}'.".format(
                    resolved.lookup_name, source_field.name
                )
            )
        if (
            isinstance(source_field, DurationField)
            and resolved.lookup_name in calendar_components
        ):
            raise ValueError(
                "Cannot extract component '{}' from DurationField '{}'.".format(
                    resolved.lookup_name, source_field.name
                )
            )
        return resolved

121 

122 

class ExtractYear(Extract):
    """Extract with the component fixed to ``"year"``."""

    lookup_name = "year"

125 

126 

class ExtractIsoYear(Extract):
    """Extract the ISO-8601 week-numbering year (component ``"iso_year"``)."""

    lookup_name = "iso_year"

131 

132 

class ExtractMonth(Extract):
    """Extract with the component fixed to ``"month"``."""

    lookup_name = "month"

135 

136 

class ExtractDay(Extract):
    """Extract with the component fixed to ``"day"``."""

    lookup_name = "day"

139 

140 

class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number: 1-52, or 53. Under ISO-8601, Monday
    is the first day of the week.
    """

    lookup_name = "week"

148 

149 

class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    Python equivalent: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"

158 

159 

class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week: Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"

164 

165 

class ExtractQuarter(Extract):
    """Extract with the component fixed to ``"quarter"``."""

    lookup_name = "quarter"

168 

169 

class ExtractHour(Extract):
    """Extract with the component fixed to ``"hour"``."""

    lookup_name = "hour"

172 

173 

class ExtractMinute(Extract):
    """Extract with the component fixed to ``"minute"``."""

    lookup_name = "minute"

176 

177 

class ExtractSecond(Extract):
    """Extract with the component fixed to ``"second"``."""

    lookup_name = "second"

180 

181 

# Register the calendar-component extracts on DateField.
for _extract_cls in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_extract_cls)

# Register the clock-component extracts on TimeField, then on DateTimeField.
for _field_cls in (TimeField, DateTimeField):
    for _extract_cls in (ExtractHour, ExtractMinute, ExtractSecond):
        _field_cls.register_lookup(_extract_cls)

# Register the year comparison lookups on both year extracts.
for _year_cls in (ExtractYear, ExtractIsoYear):
    for _lookup_cls in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _year_cls.register_lookup(_lookup_cls)

# Keep the loop variables out of the module namespace.
del _extract_cls, _field_cls, _year_cls, _lookup_cls

210 

211 

class Now(Func):
    """Expression that renders the database's current timestamp.

    The default template is ``CURRENT_TIMESTAMP``; the ``as_<vendor>``
    methods substitute a vendor-specific template.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        """Render with PostgreSQL's ``STATEMENT_TIMESTAMP()``."""
        # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        postgres_template = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=postgres_template, **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        """Render with ``CURRENT_TIMESTAMP(6)`` on MySQL."""
        # NOTE(review): the 6 is presumably the fractional-seconds precision
        # (microseconds) — confirm against the MySQL documentation.
        mysql_template = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with a ``STRFTIME(..., 'NOW')`` call on SQLite."""
        # NOTE(review): the quadrupled percent signs presumably survive two
        # rounds of %-interpolation before reaching SQLite — confirm.
        sqlite_template = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler,
            connection,
            template=sqlite_template,
            **extra_context,
        )

236 

237 

class TruncBase(TimezoneMixin, Transform):
    """Base transform that truncates a date/time value to ``kind`` precision.

    Subclasses set ``kind`` as a class attribute (``Trunc`` sets it per
    instance). The SQL is produced by the backend's ``*_trunc_sql``
    operation selected by the output field type.
    """

    kind = None
    tzinfo = None

    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        """Store the optional timezone and forward the rest to the parent."""
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """Compile the truncation, dispatching on the output field type.

        Raises:
            ValueError: if ``tzinfo`` is set for a non-DateTimeField source,
                or if the output field is not a date, time or datetime field.
        """
        sql, params = compiler.compile(self.lhs)
        # A timezone name is only resolved when the source is a DateTimeField.
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField is tested before DateField because it subclasses it.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the expression and validate source/output/kind compatibility.

        Raises:
            TypeError: if the source field is not a date, time or datetime
                field.
            ValueError: if the output field is not a date, time or datetime
                field, or if the source field cannot be truncated to the
                requested kind or output type.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                "%r isn't a DateField, TimeField, or DateTimeField." % field.name
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Label used in both error messages below; computed once.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact type match on purpose: DateTimeField subclasses DateField and
        # must not be treated as a plain date here.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                "Cannot truncate DateField '{}' to {}.".format(
                    field.name, target_name
                )
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                "Cannot truncate TimeField '{}' to {}.".format(
                    field.name, target_name
                )
            )
        return copy

    def convert_value(self, value, expression, connection):
        """Convert the value returned by the database to the output field's type.

        For a DateTimeField output, a non-None value has its tzinfo replaced
        via ``timezone.make_aware(value, self.tzinfo)``. For other outputs, a
        ``datetime`` value is narrowed to a ``date`` or ``time``. Any other
        value (including ``None``) is returned unchanged.

        Raises:
            ValueError: if a DateTimeField output receives ``None`` and the
                backend reports no zoneinfo database.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            # value is a datetime here, so it cannot be None; the former
            # ``if value is None: pass`` guard was unreachable and is gone.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value

350 

351 

class Trunc(TruncBase):
    """Truncation transform whose ``kind`` is supplied per instance."""

    def __init__(self, expression, kind, output_field=None, tzinfo=None, **extra):
        """Record ``kind`` on the instance, then defer to ``TruncBase``."""
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )

363 

364 

class TruncYear(TruncBase):
    """Truncation with ``kind`` fixed to ``"year"``."""

    kind = "year"

367 

368 

class TruncQuarter(TruncBase):
    """Truncation with ``kind`` fixed to ``"quarter"``."""

    kind = "quarter"

371 

372 

class TruncMonth(TruncBase):
    """Truncation with ``kind`` fixed to ``"month"``."""

    kind = "month"

375 

376 

class TruncWeek(TruncBase):
    """Truncate to the start of the week: midnight on that week's Monday."""

    kind = "week"

381 

382 

class TruncDay(TruncBase):
    """Truncation with ``kind`` fixed to ``"day"``."""

    kind = "day"

385 

386 

class TruncDate(TruncBase):
    """Cast a value to a date; the output field is a ``DateField``."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        """Compile via the backend's ``datetime_cast_date_sql`` operation."""
        # Cast to date rather than truncate to date.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        return connection.ops.datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), tzname
        )

397 

398 

class TruncTime(TruncBase):
    """Cast a value to a time; the output field is a ``TimeField``."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        """Compile via the backend's ``datetime_cast_time_sql`` operation."""
        # Cast to time rather than truncate to time.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        return connection.ops.datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), tzname
        )

409 

410 

class TruncHour(TruncBase):
    """Truncation with ``kind`` fixed to ``"hour"``."""

    kind = "hour"

413 

414 

class TruncMinute(TruncBase):
    """Truncation with ``kind`` fixed to ``"minute"``."""

    kind = "minute"

417 

418 

class TruncSecond(TruncBase):
    """Truncation with ``kind`` fixed to ``"second"``."""

    kind = "second"

421 

422 

# Register the date and time casts as lookups on DateTimeField.
for _cast_cls in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_cls)

# Keep the loop variable out of the module namespace.
del _cast_cls