Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/db.py: 70%
160 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1import pkgutil
2from importlib import import_module
4from plain import signals
5from plain.exceptions import ImproperlyConfigured
6from plain.runtime import settings
7from plain.utils.connection import BaseConnectionHandler, ConnectionProxy
8from plain.utils.functional import cached_property
9from plain.utils.module_loading import import_string
# Alias of the connection that must always be configured; routing falls back
# to it when no router (and no instance hint) selects a database.
DEFAULT_DB_ALIAS: str = "default"
# Key name for storing the Plain version alongside pickled data
# (not referenced in this module — presumably used by model pickling; verify).
PLAIN_VERSION_PICKLE_KEY: str = "_plain_version"
class Error(Exception):
    """Root of Plain's database exceptions (PEP 249 ``Error``)."""


class InterfaceError(Error):
    """Plain counterpart of the PEP 249 ``InterfaceError``."""


class DatabaseError(Error):
    """Plain counterpart of the PEP 249 ``DatabaseError``; base of the data-level errors."""


class DataError(DatabaseError):
    """Plain counterpart of the PEP 249 ``DataError``."""


class OperationalError(DatabaseError):
    """Plain counterpart of the PEP 249 ``OperationalError``."""


class IntegrityError(DatabaseError):
    """Plain counterpart of the PEP 249 ``IntegrityError``."""


class InternalError(DatabaseError):
    """Plain counterpart of the PEP 249 ``InternalError``."""


class ProgrammingError(DatabaseError):
    """Plain counterpart of the PEP 249 ``ProgrammingError``."""


class NotSupportedError(DatabaseError):
    """Plain counterpart of the PEP 249 ``NotSupportedError``."""
class DatabaseErrorWrapper:
    """
    Context manager and decorator that translates backend-specific database
    exceptions into Plain's common exception classes.

    The backend's exception classes are looked up by name on
    ``wrapper.Database``. Candidates are tried with subclasses before their
    bases (DatabaseError, then InterfaceError, then Error), so the first match
    is the most specific one. The Plain exception is raised chained
    (``from``) to the original, carrying the original traceback.
    """

    def __init__(self, wrapper):
        """
        wrapper is a database wrapper.

        It must have a Database attribute defining PEP-249 exceptions. Its
        ``errors_occurred`` attribute is set to True when a translated error
        may have made the connection unusable.
        """
        self.wrapper = wrapper

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            return
        candidates = (
            DataError,
            OperationalError,
            IntegrityError,
            InternalError,
            ProgrammingError,
            NotSupportedError,
            DatabaseError,
            InterfaceError,
            Error,
        )
        # Lazily evaluated: backend classes are fetched one by one, in order,
        # and the search stops at the first backend class exc_type derives from.
        matched = next(
            (
                candidate
                for candidate in candidates
                if issubclass(
                    exc_type, getattr(self.wrapper.Database, candidate.__name__)
                )
            ),
            None,
        )
        if matched is None:
            # Not a database error: let the original exception propagate.
            return
        translated = matched(*exc_value.args)
        # Data and integrity errors leave the connection usable; every other
        # translated error flags it.
        if matched not in (DataError, IntegrityError):
            self.wrapper.errors_occurred = True
        raise translated.with_traceback(traceback) from exc_value

    def __call__(self, func):
        # @wraps is deliberately not applied, for performance reasons.
        # Refs #21109.
        def inner(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return inner
def load_backend(backend_name):
    """
    Import and return the ``base`` submodule of the database backend package
    named by the fully qualified ``backend_name``.

    If the import fails and ``backend_name`` is not one of the built-in
    ``plain.models.backends.*`` packages (excluding ``base`` and ``dummy``),
    raise ImproperlyConfigured listing the built-in backends, chained to the
    ImportError. If it is a built-in backend, re-raise the original ImportError.
    """
    try:
        return import_module(f"{backend_name}.base")
    except ImportError as e_user:
        # Imported only on failure, to enumerate the built-in backends for
        # the error message.
        import plain.models.backends

        builtin_backends = [
            name
            for _, name, ispkg in pkgutil.iter_modules(plain.models.backends.__path__)
            if ispkg and name not in {"base", "dummy"}
        ]
        builtin_paths = [f"plain.models.backends.{b}" for b in builtin_backends]
        if backend_name in builtin_paths:
            # A built-in backend failed to import; surface the original
            # ImportError rather than masking it.
            raise
        backend_reprs = map(repr, sorted(builtin_backends))
        raise ImproperlyConfigured(
            "{!r} isn't an available database backend or couldn't be "
            "imported. Check the above exception. To use one of the "
            "built-in backends, use 'plain.models.backends.XXX', where XXX "
            "is one of:\n"
            " {}".format(backend_name, ", ".join(backend_reprs))
        ) from e_user
class ConnectionHandler(BaseConnectionHandler):
    """
    Connection handler for the DATABASES setting.

    Normalizes each alias's configuration and builds connections through the
    backend named by the alias's ENGINE.
    """

    settings_name = "DATABASES"

    def configure_settings(self, databases):
        """
        Normalize the mapping produced by the base class and return it.

        An empty mapping gets a DEFAULT_DB_ALIAS entry using the dummy backend.
        A non-empty mapping without DEFAULT_DB_ALIAS raises ImproperlyConfigured.
        Every alias's dict is then filled in place with default keys, including
        a TEST sub-dict.
        """
        databases = super().configure_settings(databases)
        dummy_engine = "plain.models.backends.dummy"

        if databases == {}:
            databases[DEFAULT_DB_ALIAS] = {"ENGINE": dummy_engine}
        elif DEFAULT_DB_ALIAS not in databases:
            raise ImproperlyConfigured(
                f"You must define a '{DEFAULT_DB_ALIAS}' database."
            )
        elif databases[DEFAULT_DB_ALIAS] == {}:
            databases[DEFAULT_DB_ALIAS]["ENGINE"] = dummy_engine

        for conn in databases.values():
            conn.setdefault("AUTOCOMMIT", True)
            conn.setdefault("ENGINE", dummy_engine)
            # A blank engine, or the bare package prefix, also means "dummy".
            if not conn["ENGINE"] or conn["ENGINE"] == "plain.models.backends.":
                conn["ENGINE"] = dummy_engine

            # Rebuilt per alias so every connection gets its own OPTIONS dict.
            conn_defaults = (
                ("CONN_MAX_AGE", 0),
                ("CONN_HEALTH_CHECKS", False),
                ("OPTIONS", {}),
                ("TIME_ZONE", None),
                ("NAME", ""),
                ("USER", ""),
                ("PASSWORD", ""),
                ("HOST", ""),
                ("PORT", ""),
            )
            for key, default in conn_defaults:
                conn.setdefault(key, default)

            test_config = conn.setdefault("TEST", {})
            for key, default in (
                ("CHARSET", None),
                ("COLLATION", None),
                ("MIGRATE", True),
                ("MIRROR", None),
                ("NAME", None),
            ):
                test_config.setdefault(key, default)
        return databases

    @property
    def databases(self):
        # Kept for backward compatibility: some third-party packages used this
        # private API. Plain itself no longer reads it.
        return self.settings

    def create_connection(self, alias):
        """Load the backend named by the alias's ENGINE and wrap its config."""
        config = self.settings[alias]
        return load_backend(config["ENGINE"]).DatabaseWrapper(config, alias)
class ConnectionRouter:
    """
    Consult the configured database routers to choose read/write databases
    and to decide whether relations and migrations are allowed.
    """

    def __init__(self, routers=None):
        """
        If routers is not specified, default to settings.DATABASE_ROUTERS.

        Entries may be router objects or dotted import paths to router classes.
        """
        self._routers = routers

    @cached_property
    def routers(self):
        if self._routers is None:
            self._routers = settings.DATABASE_ROUTERS
        # Dotted paths are imported and instantiated; other entries are used as-is.
        return [
            import_string(entry)() if isinstance(entry, str) else entry
            for entry in self._routers
        ]

    def _router_func(action):
        # Builds db_for_read / db_for_write. The first router whose method
        # returns a truthy alias wins. Otherwise use the "instance" hint's
        # database, if set, and finally DEFAULT_DB_ALIAS.
        def _route_db(self, model, **hints):
            for router in self.routers:
                try:
                    method = getattr(router, action)
                except AttributeError:
                    # This router doesn't implement the action.
                    continue
                chosen_db = method(model, **hints)
                if chosen_db:
                    return chosen_db
            instance = hints.get("instance")
            if instance is not None and instance._state.db:
                return instance._state.db
            return DEFAULT_DB_ALIAS

        return _route_db

    db_for_read = _router_func("db_for_read")
    db_for_write = _router_func("db_for_write")

    def allow_relation(self, obj1, obj2, **hints):
        """
        Return the first non-None verdict from the routers' allow_relation.
        Without one, allow only objects stored in the same database.
        """
        for router in self.routers:
            try:
                method = router.allow_relation
            except AttributeError:
                continue
            verdict = method(obj1, obj2, **hints)
            if verdict is not None:
                return verdict
        return obj1._state.db == obj2._state.db

    def allow_migrate(self, db, package_label, **hints):
        """Return the first non-None verdict from the routers' allow_migrate, else True."""
        for router in self.routers:
            try:
                method = router.allow_migrate
            except AttributeError:
                continue
            verdict = method(db, package_label, **hints)
            if verdict is not None:
                return verdict
        return True

    def allow_migrate_model(self, db, model):
        opts = model._meta
        return self.allow_migrate(
            db, opts.package_label, model_name=opts.model_name, model=model
        )

    def get_migratable_models(self, package_config, db, include_auto_created=False):
        """Return app models allowed to be migrated on provided db."""
        candidates = package_config.get_models(
            include_auto_created=include_auto_created
        )
        return [
            candidate
            for candidate in candidates
            if self.allow_migrate_model(db, candidate)
        ]
# Module-level connection handler; per-alias connections are built through
# ConnectionHandler.create_connection.
connections = ConnectionHandler()

# Module-level router; constructed without explicit routers, so it reads
# settings.DATABASE_ROUTERS the first time its routers are needed.
router = ConnectionRouter()

# For backwards compatibility. Prefer connections['default'] instead.
# NOTE(review): ConnectionProxy is defined in plain.utils.connection; presumably
# it forwards attribute access to connections[DEFAULT_DB_ALIAS] — confirm.
# Must be created after ``connections`` is bound.
connection = ConnectionProxy(connections, DEFAULT_DB_ALIAS)
# Clear the saved query logs whenever a Plain request starts.
def reset_queries(**kwargs):
    """Signal receiver: empty ``queries_log`` on every already-initialized connection."""
    live_connections = connections.all(initialized_only=True)
    for db_conn in live_connections:
        db_conn.queries_log.clear()


signals.request_started.connect(reset_queries)
# At the start and end of every request, let connections that are unusable or
# past their lifetime close themselves.
def close_old_connections(**kwargs):
    """
    Signal receiver: call ``close_if_unusable_or_obsolete()`` on every
    already-initialized connection.
    """
    live_connections = connections.all(initialized_only=True)
    for db_conn in live_connections:
        db_conn.close_if_unusable_or_obsolete()


signals.request_started.connect(close_old_connections)
signals.request_finished.connect(close_old_connections)