Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/backends/sqlite3/operations.py: 46%

211 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:04 -0500

1import datetime 

2import decimal 

3import uuid 

4from functools import lru_cache 

5 

6from plain import models 

7from plain.exceptions import FieldError 

8from plain.models.backends.base.operations import BaseDatabaseOperations 

9from plain.models.constants import OnConflict 

10from plain.models.db import DatabaseError, NotSupportedError 

11from plain.models.expressions import Col 

12from plain.utils import timezone 

13from plain.utils.dateparse import parse_date, parse_datetime, parse_time 

14from plain.utils.functional import cached_property 

15 

16 

17class DatabaseOperations(BaseDatabaseOperations): 

18 cast_char_field_without_max_length = "text" 

19 cast_data_types = { 

20 "DateField": "TEXT", 

21 "DateTimeField": "TEXT", 

22 } 

23 explain_prefix = "EXPLAIN QUERY PLAN" 

24 # List of datatypes that cannot be extracted with JSON_EXTRACT() on 

25 # SQLite. Use JSON_TYPE() instead. 

26 jsonfield_datatype_values = frozenset(["null", "false", "true"]) 

27 

def bulk_batch_size(self, fields, objs):
    """
    Return the number of objects to insert per bulk query.

    A single-field insert is capped at 500 rows
    (SQLITE_MAX_COMPOUND_SELECT). With several fields, the cap is the
    backend's ``max_query_params`` feature (SQLITE_LIMIT_VARIABLE_NUMBER,
    999 by default) divided by the number of fields. With no fields at
    all, every object goes into one batch.
    """
    field_count = len(fields)
    if field_count == 0:
        return len(objs)
    if field_count == 1:
        return 500
    return self.connection.features.max_query_params // field_count

42 

def check_expression_support(self, expression):
    """
    Raise NotSupportedError for expressions this backend cannot run.

    Two cases are rejected: Sum/Avg/Variance/StdDev over a date, datetime
    or time field, and any DISTINCT aggregate that takes more than one
    source expression.
    """
    temporal_fields = (models.DateField, models.DateTimeField, models.TimeField)
    numeric_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
    if isinstance(expression, numeric_aggregates):
        for source in expression.get_source_expressions():
            try:
                field = source.output_field
            except (AttributeError, FieldError):
                # Sub-expressions without an output_field are fine to skip.
                continue
            if isinstance(field, temporal_fields):
                raise NotSupportedError(
                    "You cannot use Sum, Avg, StdDev, and Variance "
                    "aggregations on date/time fields in sqlite3 "
                    "since date/time is saved as text."
                )
    is_multi_arg_distinct = (
        isinstance(expression, models.Aggregate)
        and expression.distinct
        and len(expression.source_expressions) > 1
    )
    if is_multi_arg_distinct:
        raise NotSupportedError(
            "SQLite doesn't support DISTINCT on aggregate functions "
            "accepting multiple arguments."
        )

70 

def date_extract_sql(self, lookup_type, sql, params):
    """
    Build an EXTRACT equivalent using the user-defined SQL function
    plain_date_extract().

    The lower-cased lookup type is passed as the first bound parameter,
    ahead of the caller's params. Returns a ``(sql, params)`` pair.
    """
    template = f"plain_date_extract(%s, {sql})"
    bound = (lookup_type.lower(), *params)
    return template, bound

78 

def fetch_returned_insert_rows(self, cursor):
    """
    Return every row produced by the INSERT...RETURNING statement that
    was just executed on ``cursor``.
    """
    returned_rows = cursor.fetchall()
    return returned_rows

85 

def format_for_duration_arithmetic(self, sql):
    """
    Return ``sql`` unchanged.

    No formatting is applied here; it is handled in the custom function.
    """
    return sql

89 

def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
    """
    Build a call to the plain_date_trunc() SQL function.

    Bound parameters are, in order: the lower-cased lookup type, the
    caller's params, and the two timezone values from
    ``_convert_tznames_to_sql(tzname)``. Returns ``(sql, params)``.
    """
    template = f"plain_date_trunc(%s, {sql}, %s, %s)"
    kind = lookup_type.lower()
    return template, (kind, *params, *self._convert_tznames_to_sql(tzname))

96 

def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
    """
    Build a call to the plain_time_trunc() SQL function.

    Bound parameters are, in order: the lower-cased lookup type, the
    caller's params, and the two timezone values from
    ``_convert_tznames_to_sql(tzname)``. Returns ``(sql, params)``.
    """
    template = f"plain_time_trunc(%s, {sql}, %s, %s)"
    kind = lookup_type.lower()
    return template, (kind, *params, *self._convert_tznames_to_sql(tzname))

103 

104 def _convert_tznames_to_sql(self, tzname): 

105 if tzname: 

106 return tzname, self.connection.timezone_name 

107 return None, None 

108 

def datetime_cast_date_sql(self, sql, params, tzname):
    """
    Build a call to the plain_datetime_cast_date() SQL function.

    The caller's params come first, followed by the two timezone values
    from ``_convert_tznames_to_sql(tzname)``. Returns ``(sql, params)``.
    """
    template = f"plain_datetime_cast_date({sql}, %s, %s)"
    return template, (*params, *self._convert_tznames_to_sql(tzname))

114 

def datetime_cast_time_sql(self, sql, params, tzname):
    """
    Build a call to the plain_datetime_cast_time() SQL function.

    The caller's params come first, followed by the two timezone values
    from ``_convert_tznames_to_sql(tzname)``. Returns ``(sql, params)``.
    """
    template = f"plain_datetime_cast_time({sql}, %s, %s)"
    return template, (*params, *self._convert_tznames_to_sql(tzname))

120 

def datetime_extract_sql(self, lookup_type, sql, params, tzname):
    """
    Build a call to the plain_datetime_extract() SQL function.

    Bound parameters are, in order: the lower-cased lookup type, the
    caller's params, and the two timezone values from
    ``_convert_tznames_to_sql(tzname)``. Returns ``(sql, params)``.
    """
    template = f"plain_datetime_extract(%s, {sql}, %s, %s)"
    kind = lookup_type.lower()
    return template, (kind, *params, *self._convert_tznames_to_sql(tzname))

127 

def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
    """
    Build a call to the plain_datetime_trunc() SQL function.

    Bound parameters are, in order: the lower-cased lookup type, the
    caller's params, and the two timezone values from
    ``_convert_tznames_to_sql(tzname)``. Returns ``(sql, params)``.
    """
    template = f"plain_datetime_trunc(%s, {sql}, %s, %s)"
    kind = lookup_type.lower()
    return template, (kind, *params, *self._convert_tznames_to_sql(tzname))

134 

def time_extract_sql(self, lookup_type, sql, params):
    """
    Build a call to the plain_time_extract() SQL function.

    The lower-cased lookup type is bound first, followed by the caller's
    params. Returns ``(sql, params)``.
    """
    template = f"plain_time_extract(%s, {sql})"
    bound = (lookup_type.lower(), *params)
    return template, bound

137 

def pk_default_value(self):
    """Return the SQL literal used as the default primary key value."""
    default_literal = "NULL"
    return default_literal

140 

141 def _quote_params_for_last_executed_query(self, params): 

142 """ 

143 Only for last_executed_query! Don't use this to execute SQL queries! 

144 """ 

145 # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the 

146 # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the 

147 # number of return values, default = 2000). Since Python's sqlite3 

148 # module doesn't expose the get_limit() C API, assume the default 

149 # limits are in effect and split the work in batches if needed. 

150 BATCH_SIZE = 999 

151 if len(params) > BATCH_SIZE: 

152 results = () 

153 for index in range(0, len(params), BATCH_SIZE): 

154 chunk = params[index : index + BATCH_SIZE] 

155 results += self._quote_params_for_last_executed_query(chunk) 

156 return results 

157 

158 sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params)) 

159 # Bypass Plain's wrappers and use the underlying sqlite3 connection 

160 # to avoid logging this query - it would trigger infinite recursion. 

161 cursor = self.connection.connection.cursor() 

162 # Native sqlite3 cursors cannot be used as context managers. 

163 try: 

164 return cursor.execute(sql, params).fetchone() 

165 finally: 

166 cursor.close() 

167 

168 def last_executed_query(self, cursor, sql, params): 

169 # Python substitutes parameters in Modules/_sqlite/cursor.c with: 

170 # bind_parameters(state, self->statement, parameters); 

171 # Unfortunately there is no way to reach self->statement from Python, 

172 # so we quote and substitute parameters manually. 

173 if params: 

174 if isinstance(params, list | tuple): 

175 params = self._quote_params_for_last_executed_query(params) 

176 else: 

177 values = tuple(params.values()) 

178 values = self._quote_params_for_last_executed_query(values) 

179 params = dict(zip(params, values)) 

180 return sql % params 

181 # For consistency with SQLiteCursorWrapper.execute(), just return sql 

182 # when there are no parameters. See #13648 and #17158. 

183 else: 

184 return sql 

185 

def quote_name(self, name):
    """
    Return ``name`` wrapped in double quotes.

    A name that already starts and ends with a double quote is returned
    unchanged - quoting once is enough.
    """
    already_quoted = name.startswith('"') and name.endswith('"')
    return name if already_quoted else f'"{name}"'

190 

def no_limit_value(self):
    """Return the LIMIT value this backend uses to mean "no limit"."""
    unlimited = -1
    return unlimited

193 

def __references_graph(self, table_name):
    """
    Return the names produced by a self-referencing CTE over sqlite_master.

    The CTE is seeded with ``table_name``. It then adds every
    sqlite_master entry whose ``sql`` column matches a REFERENCES clause
    (case-insensitive, optionally quoted) naming a table already in the
    set. The result is a list of names that includes ``table_name``.
    """
    query = """
    WITH tables AS (
        SELECT %s name
        UNION
        SELECT sqlite_master.name
        FROM sqlite_master
        JOIN tables ON (sql REGEXP %s || tables.name || %s)
    ) SELECT name FROM tables;
    """
    # The two regex fragments are concatenated around tables.name in SQL:
    # "<ws>references<ws>[quote]" + name + "[quote]<ws>(".
    params = (
        table_name,
        r'(?i)\s+references\s+("|\')?',
        r'("|\')?\s*\(',
    )
    with self.connection.cursor() as cursor:
        results = cursor.execute(query, params)
        return [row[0] for row in results.fetchall()]

212 

@cached_property
def _references_graph(self):
    """
    Return an LRU-cached wrapper around ``__references_graph``.

    The wrapper is built once per instance through ``cached_property``.
    """
    # 512 is large enough to fit the ~330 tables (as of this writing) in
    # Plain's test suite.
    memoize = lru_cache(maxsize=512)
    return memoize(self.__references_graph)

218 

def sequence_reset_by_name_sql(self, style, sequences):
    """
    Return SQL that sets ``seq`` to 0 in sqlite_sequence for the given
    sequences.

    ``sequences`` is an iterable of mappings with a ``"table"`` key. The
    result is a single-statement list, or an empty list when there are
    no sequences.
    """
    if not sequences:
        return []
    quote = self.quote_name
    update_kw = style.SQL_KEYWORD("UPDATE")
    table = style.SQL_TABLE(quote("sqlite_sequence"))
    set_kw = style.SQL_KEYWORD("SET")
    seq_column = style.SQL_FIELD(quote("seq"))
    where_kw = style.SQL_KEYWORD("WHERE")
    name_column = style.SQL_FIELD(quote("name"))
    in_kw = style.SQL_KEYWORD("IN")
    table_names = ", ".join(f"'{info['table']}'" for info in sequences)
    statement = (
        f"{update_kw} {table} {set_kw} {seq_column} = 0 "
        f"{where_kw} {name_column} {in_kw} ({table_names});"
    )
    return [statement]

236 

def adapt_datetimefield_value(self, value):
    """
    Convert a datetime into the string stored by this backend.

    ``None`` and expression objects (anything with a
    ``resolve_expression`` attribute) are returned untouched. Aware
    datetimes are first made naive in the connection's timezone.
    """
    # Expression values are adapted by the database.
    if value is None or hasattr(value, "resolve_expression"):
        return value

    # SQLite doesn't support tz-aware datetimes
    naive = (
        timezone.make_naive(value, self.connection.timezone)
        if timezone.is_aware(value)
        else value
    )
    return str(naive)

250 

def adapt_timefield_value(self, value):
    """
    Convert a time into the string stored by this backend.

    ``None`` and expression objects (anything with a
    ``resolve_expression`` attribute) are returned untouched.

    Raises ValueError for timezone-aware values.
    """
    # Expression values are adapted by the database.
    if value is None or hasattr(value, "resolve_expression"):
        return value

    # SQLite doesn't support tz-aware datetimes
    if timezone.is_aware(value):
        raise ValueError("SQLite backend does not support timezone-aware times.")

    text = str(value)
    return text

264 

def get_db_converters(self, expression):
    """
    Return the base converters plus, at most, one backend-specific
    converter chosen by the internal type of the expression's output
    field.
    """
    converters = super().get_db_converters(expression)
    field_kind = expression.output_field.get_internal_type()

    # Decimal converters are built per expression, so handle them apart.
    if field_kind == "DecimalField":
        converters.append(self.get_decimalfield_converter(expression))
        return converters

    converter_names = {
        "DateTimeField": "convert_datetimefield_value",
        "DateField": "convert_datefield_value",
        "TimeField": "convert_timefield_value",
        "UUIDField": "convert_uuidfield_value",
        "BooleanField": "convert_booleanfield_value",
    }
    method_name = converter_names.get(field_kind)
    if method_name is not None:
        converters.append(getattr(self, method_name))
    return converters

281 

def convert_datetimefield_value(self, value, expression, connection):
    """
    Convert a database value into an aware datetime.

    ``None`` passes through. Non-datetime values are parsed with
    ``parse_datetime``. Naive results are made aware in this
    connection's timezone.
    """
    if value is None:
        return value
    if not isinstance(value, datetime.datetime):
        value = parse_datetime(value)
    if timezone.is_aware(value):
        return value
    return timezone.make_aware(value, self.connection.timezone)

289 

def convert_datefield_value(self, value, expression, connection):
    """
    Convert a database value into a date.

    ``None`` and existing ``datetime.date`` instances pass through;
    anything else is parsed with ``parse_date``.
    """
    if value is None or isinstance(value, datetime.date):
        return value
    return parse_date(value)

295 

def convert_timefield_value(self, value, expression, connection):
    """
    Convert a database value into a time.

    ``None`` and existing ``datetime.time`` instances pass through;
    anything else is parsed with ``parse_time``.
    """
    if value is None or isinstance(value, datetime.time):
        return value
    return parse_time(value)

301 

def get_decimalfield_converter(self, expression):
    """
    Return a converter that turns float database values into Decimals.

    For a ``Col`` expression, the result is also quantized to the
    field's ``decimal_places`` using the field's decimal context. Both
    converters return ``None`` for ``None`` input.
    """
    # SQLite stores only 15 significant digits. Digits coming from
    # float inaccuracy must be removed.
    from_float = decimal.Context(prec=15).create_decimal_from_float

    if not isinstance(expression, Col):

        def converter(value, expression, connection):
            if value is None:
                return None
            return from_float(value)

        return converter

    quantum = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)

    def converter(value, expression, connection):
        if value is None:
            return None
        return from_float(value).quantize(
            quantum, context=expression.output_field.context
        )

    return converter

324 

def convert_uuidfield_value(self, value, expression, connection):
    """Convert a database value into a ``uuid.UUID``; ``None`` passes through."""
    if value is None:
        return value
    return uuid.UUID(value)

329 

def convert_booleanfield_value(self, value, expression, connection):
    """
    Convert a value equal to 1 or 0 into a bool; any other value
    (including ``None``) is returned unchanged.
    """
    if value in (1, 0):
        return bool(value)
    return value

332 

def bulk_insert_sql(self, fields, placeholder_rows):
    """
    Return the ``VALUES (...), (...)`` clause for a multi-row insert.

    Each row of placeholders becomes one parenthesized, comma-separated
    group.
    """
    row_groups = [f"({', '.join(row)})" for row in placeholder_rows]
    return "VALUES " + ", ".join(row_groups)

337 

def combine_expression(self, connector, sub_expressions):
    """
    Combine sub-expressions with ``connector``.

    ``^`` and ``#`` are rewritten as POWER(...) and BITXOR(...) function
    calls; every other connector is delegated to the base class.
    """
    # SQLite doesn't have a ^ operator, so use the user-defined POWER
    # function that's registered in connect().
    function_for = {"^": "POWER", "#": "BITXOR"}
    function_name = function_for.get(connector)
    if function_name is None:
        return super().combine_expression(connector, sub_expressions)
    return f"{function_name}({','.join(sub_expressions)})"

346 

def combine_duration_expression(self, connector, sub_expressions):
    """
    Build a plain_format_dtdelta() call for timedelta arithmetic.

    ``connector`` must be one of ``+``, ``-``, ``*`` or ``/`` and is
    passed as a quoted literal first argument. ``sub_expressions`` may be
    any iterable of SQL fragments (list, tuple, generator) with at most
    two items.

    Raises DatabaseError for an unsupported connector and ValueError when
    more than two sub-expressions are supplied.
    """
    if connector not in ("+", "-", "*", "/"):
        raise DatabaseError("Invalid connector for timedelta: %s." % connector)
    # Unpack rather than concatenate, so tuples and other iterables work
    # as well as lists.
    fn_params = ["'%s'" % connector, *sub_expressions]
    if len(fn_params) > 3:
        raise ValueError("Too many params for timedelta operations.")
    return "plain_format_dtdelta(%s)" % ", ".join(fn_params)

354 

def integer_field_range(self, internal_type):
    """
    Return the ``(min, max)`` range for an integer field type.

    Positive field types start at 0; all others span the full signed
    64-bit range.
    """
    # SQLite doesn't enforce any integer constraints, but sqlite3 supports
    # integers up to 64 bits.
    upper = 9223372036854775807
    positive_types = (
        "PositiveBigIntegerField",
        "PositiveIntegerField",
        "PositiveSmallIntegerField",
    )
    lower = 0 if internal_type in positive_types else -9223372036854775808
    return (lower, upper)

365 

def subtract_temporals(self, internal_type, lhs, rhs):
    """
    Build SQL for the difference between two temporal values.

    ``lhs`` and ``rhs`` are ``(sql, params)`` pairs. TimeField uses
    plain_time_diff(); every other type uses plain_timestamp_diff().
    Returns ``(sql, params)`` with the left-hand params first.
    """
    lhs_sql, lhs_params = lhs
    rhs_sql, rhs_params = rhs
    combined_params = (*lhs_params, *rhs_params)
    diff_function = (
        "plain_time_diff" if internal_type == "TimeField" else "plain_timestamp_diff"
    )
    return f"{diff_function}({lhs_sql}, {rhs_sql})", combined_params

373 

def insert_statement(self, on_conflict=None):
    """
    Return the INSERT prefix for the given conflict policy.

    ``OnConflict.IGNORE`` maps to ``INSERT OR IGNORE INTO``; every other
    policy is delegated to the base class.
    """
    if on_conflict != OnConflict.IGNORE:
        return super().insert_statement(on_conflict=on_conflict)
    return "INSERT OR IGNORE INTO"

378 

def return_insert_columns(self, fields):
    """
    Return the ``RETURNING`` clause and its (empty) params for ``fields``.

    Each column is emitted as a quoted ``"table"."column"`` pair. With no
    fields, the clause is an empty string.
    """
    # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
    if not fields:
        return "", ()
    quote = self.quote_name
    qualified_columns = ", ".join(
        f"{quote(field.model._meta.db_table)}.{quote(field.column)}"
        for field in fields
    )
    return f"RETURNING {qualified_columns}", ()

391 

def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
    """
    Return the ``ON CONFLICT`` suffix for an INSERT.

    For ``OnConflict.UPDATE``, when the backend supports update conflicts
    with a target, emit ``ON CONFLICT(<unique>) DO UPDATE SET
    <col> = EXCLUDED.<col>, ...`` with quoted names. Every other case is
    delegated to the base class.
    """
    wants_upsert = (
        on_conflict == OnConflict.UPDATE
        and self.connection.features.supports_update_conflicts_with_target
    )
    if not wants_upsert:
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )
    conflict_target = ", ".join(map(self.quote_name, unique_fields))
    assignments = ", ".join(
        [
            f"{column} = EXCLUDED.{column}"
            for column in map(self.quote_name, update_fields)
        ]
    )
    return f"ON CONFLICT({conflict_target}) DO UPDATE SET {assignments}"