Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/backends/base/base.py: 30%
361 statements

import _thread
import copy
import datetime
import logging
import threading
import time
import warnings
import zoneinfo
from collections import deque
from contextlib import contextmanager

from plain.models.backends import utils
from plain.models.backends.base.validation import BaseDatabaseValidation
from plain.models.backends.signals import connection_created
from plain.models.backends.utils import debug_transaction
from plain.models.db import (
    DEFAULT_DB_ALIAS,
    DatabaseError,
    DatabaseErrorWrapper,
    NotSupportedError,
)
from plain.models.transaction import TransactionManagementError
from plain.runtime import settings
from plain.utils.functional import cached_property

NO_DB_ALIAS = "__no_db__"
RAN_DB_VERSION_CHECK = set()

logger = logging.getLogger("plain.models.backends.base")


class BaseDatabaseWrapper:
    """Represent a database connection."""

    # Mapping of Field objects to their column types.
    data_types = {}
    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
    data_types_suffix = {}
    # Mapping of Field objects to their SQL for CHECK constraints.
    data_type_check_constraints = {}
    ops = None
    vendor = "unknown"
    display_name = "unknown"
    SchemaEditorClass = None
    # Classes instantiated in __init__().
    client_class = None
    creation_class = None
    features_class = None
    introspection_class = None
    ops_class = None
    validation_class = BaseDatabaseValidation

    queries_limit = 9000
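
    # Illustrative sketch (not part of this module; all names below are
    # hypothetical): a concrete backend is expected to fill in the class-level
    # hooks above and implement the connection/cursor methods further down.
    #
    #   class MyDatabaseWrapper(BaseDatabaseWrapper):
    #       vendor = "mydb"
    #       display_name = "MyDB"
    #       data_types = {"CharField": "varchar(%(max_length)s)"}
    #       SchemaEditorClass = MyDatabaseSchemaEditor
    #       client_class = MyDatabaseClient
    #       creation_class = MyDatabaseCreation
    #       features_class = MyDatabaseFeatures
    #       introspection_class = MyDatabaseIntrospection
    #       ops_class = MyDatabaseOperations
    #
    #       def get_connection_params(self): ...
    #       def get_new_connection(self, conn_params): ...
    #       def create_cursor(self, name=None): ...
    #       def _set_autocommit(self, autocommit): ...
    #       def is_usable(self): ...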

    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
        # Connection related attributes.
        # The underlying database connection.
        self.connection = None
        # `settings_dict` should be a dictionary containing keys such as
        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Plain settings modules.
        self.settings_dict = settings_dict
        self.alias = alias
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log = deque(maxlen=self.queries_limit)
        self.force_debug_cursor = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks = []
        # Tracks if the outermost 'atomic' block should commit on exit,
        # i.e. if autocommit was active on entry.
        self.commit_on_exit = True
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback = False
        self.rollback_exc = None

        # Connection termination related attributes.
        self.close_at = None
        self.closed_in_transaction = False
        self.errors_occurred = False
        self.health_check_enabled = False
        self.health_check_done = False

        # Thread-safety related attributes.
        self._thread_sharing_lock = threading.Lock()
        self._thread_sharing_count = 0
        self._thread_ident = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers = []

        self.client = self.client_class(self)
        self.creation = self.creation_class(self)
        self.features = self.features_class(self)
        self.introspection = self.introspection_class(self)
        self.ops = self.ops_class(self)
        self.validation = self.validation_class(self)

    def __repr__(self):
        return (
            f"<{self.__class__.__qualname__} "
            f"vendor={self.vendor!r} alias={self.alias!r}>"
        )

    def ensure_timezone(self):
        """
        Ensure the connection's timezone is set to `self.timezone_name` and
        return whether it changed or not.
        """
        return False

    @cached_property
    def timezone(self):
        """
        Return a tzinfo of the database connection time zone.

        This is only used when time zone support is enabled. When a datetime is
        read from the database, it is always returned in this time zone.

        When the database backend supports time zones, it doesn't matter which
        time zone Plain uses, as long as aware datetimes are used everywhere.
        Other users connecting to the database can choose their own time zone.

        When the database backend doesn't support time zones, the time zone
        Plain uses may be constrained by the requirements of other users of
        the database.
        """
        if self.settings_dict["TIME_ZONE"] is None:
            return datetime.timezone.utc
        else:
            return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])

    @cached_property
    def timezone_name(self):
        """
        Name of the time zone of the database connection.
        """
        if self.settings_dict["TIME_ZONE"] is None:
            return "UTC"
        else:
            return self.settings_dict["TIME_ZONE"]
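
    # Illustrative example (values hypothetical): with
    # settings_dict["TIME_ZONE"] == "America/Chicago", `timezone` resolves to
    # zoneinfo.ZoneInfo("America/Chicago") and `timezone_name` to
    # "America/Chicago"; with TIME_ZONE set to None they fall back to
    # datetime.timezone.utc and "UTC".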

    @property
    def queries_logged(self):
        return self.force_debug_cursor or settings.DEBUG

    @property
    def queries(self):
        if len(self.queries_log) == self.queries_log.maxlen:
            warnings.warn(
                "Limit for query logging exceeded, only the last {} queries "
                "will be returned.".format(self.queries_log.maxlen)
            )
        return list(self.queries_log)

    def get_database_version(self):
        """Return a tuple of the database's version."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_database_version() "
            "method."
        )

    def check_database_version_supported(self):
        """
        Raise an error if the database version isn't supported by this
        version of Plain.
        """
        if (
            self.features.minimum_database_version is not None
            and self.get_database_version() < self.features.minimum_database_version
        ):
            db_version = ".".join(map(str, self.get_database_version()))
            min_db_version = ".".join(map(str, self.features.minimum_database_version))
            raise NotSupportedError(
                f"{self.display_name} {min_db_version} or later is required "
                f"(found {db_version})."
            )
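
    # Illustrative example (values hypothetical): if a backend reports
    # get_database_version() == (12, 4) and features.minimum_database_version
    # is (13,), the tuple comparison (12, 4) < (13,) is true and a
    # NotSupportedError such as "<display_name> 13 or later is required
    # (found 12.4)." is raised.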

    # ##### Backend-specific methods for creating connections and cursors #####

    def get_connection_params(self):
        """Return a dict of parameters suitable for get_new_connection."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
            "method"
        )

    def get_new_connection(self, conn_params):
        """Open a connection to the database."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
            "method"
        )

    def init_connection_state(self):
        """Initialize the database connection settings."""
        global RAN_DB_VERSION_CHECK
        if self.alias not in RAN_DB_VERSION_CHECK:
            self.check_database_version_supported()
            RAN_DB_VERSION_CHECK.add(self.alias)

    def create_cursor(self, name=None):
        """Create a cursor. Assume that a connection is established."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
        )

    # ##### Backend-specific methods for creating connections #####

    def connect(self):
        """Connect to the database. Assume that the connection is closed."""
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
        self.init_connection_state()
        connection_created.send(sender=self.__class__, connection=self)

        self.run_on_commit = []

    def ensure_connection(self):
        """Guarantee that a connection to the database is established."""
        if self.connection is None:
            with self.wrap_database_errors:
                self.connect()

    # ##### Backend-specific wrappers for PEP-249 connection methods #####

    def _prepare_cursor(self, cursor):
        """
        Validate the connection is usable and perform database cursor wrapping.
        """
        self.validate_thread_sharing()
        if self.queries_logged:
            wrapped_cursor = self.make_debug_cursor(cursor)
        else:
            wrapped_cursor = self.make_cursor(cursor)
        return wrapped_cursor

    def _cursor(self, name=None):
        self.close_if_health_check_failed()
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor(name))

    def _commit(self):
        if self.connection is not None:
            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
                return self.connection.commit()

    def _rollback(self):
        if self.connection is not None:
            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
                return self.connection.rollback()

    def _close(self):
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.close()

    # ##### Generic wrappers for PEP-249 connection methods #####

    def cursor(self):
        """Create a cursor, opening a connection if necessary."""
        return self._cursor()

    def commit(self):
        """Commit a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        self.run_commit_hooks_on_set_autocommit_on = True

    def rollback(self):
        """Roll back a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        self.run_on_commit = []

    def close(self):
        """Close the connection to the database."""
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None

    # ##### Backend-specific savepoint management methods #####

    def _savepoint(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_create_sql(sid))

    def _savepoint_rollback(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_rollback_sql(sid))

    def _savepoint_commit(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_commit_sql(sid))

    def _savepoint_allowed(self):
        # Savepoints cannot be created outside a transaction
        return self.features.uses_savepoints and not self.get_autocommit()

    # ##### Generic savepoint management methods #####

    def savepoint(self):
        """
        Create a savepoint inside the current transaction. Return an
        identifier for the savepoint that will be used for the subsequent
        rollback or commit. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        thread_ident = _thread.get_ident()
        tid = str(thread_ident).replace("-", "")

        self.savepoint_state += 1
        sid = "s%s_x%d" % (tid, self.savepoint_state)

        self.validate_thread_sharing()
        self._savepoint(sid)

        return sid

    def savepoint_rollback(self, sid):
        """
        Roll back to a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_rollback(sid)

        # Remove any callbacks registered while this savepoint was active.
        self.run_on_commit = [
            (sids, func, robust)
            for (sids, func, robust) in self.run_on_commit
            if sid not in sids
        ]

    def savepoint_commit(self, sid):
        """
        Release a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_commit(sid)
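
    # Usage sketch (illustrative; assumes an open transaction, e.g. inside an
    # 'atomic' block, and a backend whose features.uses_savepoints is True):
    #
    #   sid = connection.savepoint()
    #   try:
    #       ...  # statements that might fail
    #   except DatabaseError:
    #       connection.savepoint_rollback(sid)
    #   else:
    #       connection.savepoint_commit(sid)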

    def clean_savepoints(self):
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    # ##### Backend-specific transaction management methods #####

    def _set_autocommit(self, autocommit):
        """
        Backend-specific implementation to enable or disable autocommit.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
        )

    # ##### Generic transaction management methods #####

    def get_autocommit(self):
        """Get the autocommit state."""
        self.ensure_connection()
        return self.autocommit

    def set_autocommit(
        self, autocommit, force_begin_transaction_with_broken_autocommit=False
    ):
        """
        Enable or disable autocommit.

        The usual way to start a transaction is to turn autocommit off.
        SQLite does not properly start a transaction when disabling
        autocommit. To avoid this buggy behavior and to actually enter a new
        transaction, an explicit BEGIN is required. Using
        force_begin_transaction_with_broken_autocommit=True will issue an
        explicit BEGIN with SQLite. This option will be ignored for other
        backends.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        start_transaction_under_autocommit = (
            force_begin_transaction_with_broken_autocommit
            and not autocommit
            and hasattr(self, "_start_transaction_under_autocommit")
        )

        if start_transaction_under_autocommit:
            self._start_transaction_under_autocommit()
        elif autocommit:
            self._set_autocommit(autocommit)
        else:
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False
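
    # Usage sketch (illustrative): manual transaction management outside of
    # 'atomic' typically pairs these calls, e.g.
    #
    #   connection.set_autocommit(False)  # implicitly begins a transaction
    #   try:
    #       ...  # run queries
    #       connection.commit()
    #   except DatabaseError:
    #       connection.rollback()
    #   finally:
    #       connection.set_autocommit(True)  # restore autocommit mode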

    def get_rollback(self):
        """Get the "needs rollback" flag -- for *advanced use* only."""
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        return self.needs_rollback

    def set_rollback(self, rollback):
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        self.needs_rollback = rollback

    def validate_no_atomic_block(self):
        """Raise an error if an atomic block is active."""
        if self.in_atomic_block:
            raise TransactionManagementError(
                "This is forbidden when an 'atomic' block is active."
            )

    def validate_no_broken_transaction(self):
        if self.needs_rollback:
            raise TransactionManagementError(
                "An error occurred in the current transaction. You can't "
                "execute queries until the end of the 'atomic' block."
            ) from self.rollback_exc

    # ##### Foreign key constraints checks handling #####

    @contextmanager
    def constraint_checks_disabled(self):
        """
        Disable foreign key constraint checking.
        """
        disabled = self.disable_constraint_checking()
        try:
            yield
        finally:
            if disabled:
                self.enable_constraint_checking()
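
    # Usage sketch (illustrative; e.g. when loading fixture data whose rows
    # reference each other out of order):
    #
    #   with connection.constraint_checks_disabled():
    #       ...  # perform inserts that would otherwise violate FK ordering
    #   connection.check_constraints()  # optionally verify afterwards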

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key
        constraint checking. Should return True if the constraints were
        disabled and will need to be reenabled.
        """
        return False

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint
        checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint
        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
        IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    # ##### Connection termination handling #####

    def is_usable(self):
        """
        Test if the database connection is usable.

        This method may assume that self.connection is not None.

        Actual implementations should take care not to raise exceptions
        as that may prevent Plain from recycling unusable connections.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require an is_usable() method"
        )

    def close_if_health_check_failed(self):
        """Close existing connection if it fails a health check."""
        if (
            self.connection is None
            or not self.health_check_enabled
            or self.health_check_done
        ):
            return

        if not self.is_usable():
            self.close()
        self.health_check_done = True

    def close_if_unusable_or_obsolete(self):
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            self.health_check_done = False
            # If the application didn't restore the original autocommit setting,
            # don't take chances, drop the connection.
            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                    self.health_check_done = True
                else:
                    self.close()
                    return

            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return

    # ##### Thread safety handling #####

    @property
    def allow_thread_sharing(self):
        with self._thread_sharing_lock:
            return self._thread_sharing_count > 0

    def inc_thread_sharing(self):
        with self._thread_sharing_lock:
            self._thread_sharing_count += 1

    def dec_thread_sharing(self):
        with self._thread_sharing_lock:
            if self._thread_sharing_count <= 0:
                raise RuntimeError(
                    "Cannot decrement the thread sharing count below zero."
                )
            self._thread_sharing_count -= 1

    def validate_thread_sharing(self):
        """
        Validate that the connection isn't accessed by a thread other than the
        one which originally created it, unless the connection was explicitly
        authorized to be shared between threads (via the `inc_thread_sharing()`
        method). Raise an exception if the validation fails.
        """
        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
            raise DatabaseError(
                "DatabaseWrapper objects created in a "
                "thread can only be used in that same thread. The object "
                "with alias '{}' was created in thread id {} and this is "
                "thread id {}.".format(
                    self.alias, self._thread_ident, _thread.get_ident()
                )
            )
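
    # Usage sketch (illustrative; `worker` is a hypothetical callable that
    # uses the connection): to hand the connection to another thread, bracket
    # the work with the sharing counters so validate_thread_sharing() passes.
    #
    #   connection.inc_thread_sharing()
    #   try:
    #       thread = threading.Thread(target=worker)
    #       thread.start()
    #       thread.join()
    #   finally:
    #       connection.dec_thread_sharing()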

    # ##### Miscellaneous #####

    def prepare_database(self):
        """
        Hook to do any database check or preparation, generally called before
        migrating a project or an app.
        """
        pass

    @cached_property
    def wrap_database_errors(self):
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Plain's common wrappers.
        """
        return DatabaseErrorWrapper(self)

    def chunked_cursor(self):
        """
        Return a cursor that tries to avoid caching in the database (if
        supported by the database), otherwise return a regular cursor.
        """
        return self.cursor()

    def make_debug_cursor(self, cursor):
        """Create a cursor that logs all queries in self.queries_log."""
        return utils.CursorDebugWrapper(cursor, self)

    def make_cursor(self, cursor):
        """Create a cursor without debug logging."""
        return utils.CursorWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        """
        Context manager that ensures that a connection is established, and
        if it opened one, closes it to avoid leaving a dangling connection.
        This is useful for operations outside of the request-response cycle.

        Provide a cursor: with self.temporary_connection() as cursor: ...
        """
        must_close = self.connection is None
        try:
            with self.cursor() as cursor:
                yield cursor
        finally:
            if must_close:
                self.close()

    @contextmanager
    def _nodb_cursor(self):
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        conn = self.__class__({**self.settings_dict, "NAME": None}, alias=NO_DB_ALIAS)
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()

    def schema_editor(self, *args, **kwargs):
        """
        Return a new instance of this backend's SchemaEditor.
        """
        if self.SchemaEditorClass is None:
            raise NotImplementedError(
                "The SchemaEditorClass attribute of this database wrapper is still None"
            )
        return self.SchemaEditorClass(self, *args, **kwargs)
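
    # Usage sketch (illustrative; assumes the backend's schema editor is a
    # context manager and `MyModel` is a model class defined elsewhere):
    #
    #   with connection.schema_editor() as editor:
    #       editor.create_model(MyModel)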

    def on_commit(self, func, robust=False):
        if not callable(func):
            raise TypeError("on_commit()'s callback must be a callable.")
        if self.in_atomic_block:
            # Transaction in progress; save for execution on commit.
            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
        elif not self.get_autocommit():
            raise TransactionManagementError(
                "on_commit() cannot be used in manual transaction management"
            )
        else:
            # No transaction in progress and in autocommit mode; execute
            # immediately.
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        f"Error calling {func.__qualname__} in on_commit() (%s).",
                        e,
                        exc_info=True,
                    )
            else:
                func()
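
    # Usage sketch (illustrative; `notify_user` is a hypothetical no-argument
    # callable): inside an 'atomic' block the callback is queued and only runs
    # after the outermost block commits; with robust=True its exceptions are
    # logged instead of propagating.
    #
    #   connection.on_commit(notify_user)
    #   connection.on_commit(notify_user, robust=True)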

    def run_and_clear_commit_hooks(self):
        self.validate_no_atomic_block()
        current_run_on_commit = self.run_on_commit
        self.run_on_commit = []
        while current_run_on_commit:
            _, func, robust = current_run_on_commit.pop(0)
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        f"Error calling {func.__qualname__} in on_commit() during "
                        f"transaction (%s).",
                        e,
                        exc_info=True,
                    )
            else:
                func()

    @contextmanager
    def execute_wrapper(self, wrapper):
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()
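
    # Usage sketch (illustrative): a wrapper receives (execute, sql, params,
    # many, context) and must call execute() itself, which lets it observe or
    # veto every query while the context manager is active.
    #
    #   def blocker(execute, sql, params, many, context):
    #       if sql.lstrip().upper().startswith("DELETE"):
    #           raise DatabaseError("DELETE statements are blocked here.")
    #       return execute(sql, params, many, context)
    #
    #   with connection.execute_wrapper(blocker):
    #       ...  # run queries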

    def copy(self, alias=None):
        """
        Return a copy of this connection.

        For tests that require two connections to the same database.
        """
        settings_dict = copy.deepcopy(self.settings_dict)
        if alias is None:
            alias = self.alias
        return type(self)(settings_dict, alias)
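
    # Usage sketch (illustrative): tests that need two independent connections
    # to the same database can do
    #
    #   other = connection.copy(alias="other")
    #   try:
    #       ...  # use `other` alongside `connection`
    #   finally:
    #       other.close()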