Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/backends/utils.py: 25%

157 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1import datetime 

2import decimal 

3import functools 

4import logging 

5import time 

6from contextlib import contextmanager 

7from hashlib import md5 

8 

9from plain.models.db import NotSupportedError 

10from plain.utils.dateparse import parse_time 

11 

# Module logger; debug_sql() and debug_transaction() emit executed SQL with
# timing at DEBUG level through it.
logger = logging.getLogger("plain.models.backends")

13 

14 

class CursorWrapper:
    """
    Proxy around a DB-API cursor belonging to the connection ``db``.

    ``execute``, ``executemany`` and ``callproc`` first check for a broken
    transaction and run inside ``db.wrap_database_errors``. ``execute`` and
    ``executemany`` are additionally routed through every callable in
    ``db.execute_wrappers``. Any other attribute is forwarded to the wrapped
    cursor.
    """

    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    # Cursor methods that get wrapped with db.wrap_database_errors when they
    # are looked up through __getattr__.
    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])

    def __getattr__(self, attr):
        # Only reached for names the wrapper itself doesn't define.
        cursor_attr = getattr(self.cursor, attr)
        if attr not in CursorWrapper.WRAP_ERROR_ATTRS:
            return cursor_attr
        return self.db.wrap_database_errors(cursor_attr)

    def __iter__(self):
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close instead of passing through to avoid backend-specific behavior
        # (#17671). Driver errors raised while closing are swallowed because
        # a failure in cleanup code isn't useful to the caller.
        try:
            self.close()
        except self.db.Database.Error:
            pass

    # The following methods cannot be implemented in __getattr__, because the
    # code must run when the method is invoked, not just when it is accessed.

    def callproc(self, procname, params=None, kparams=None):
        # Keyword parameters for callproc aren't supported in PEP 249, but the
        # database driver may support them (e.g. cx_Oracle).
        if kparams is not None and not self.db.features.supports_callproc_kwargs:
            raise NotSupportedError(
                "Keyword parameters for callproc are not supported on this "
                "database backend."
            )
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if kparams is not None:
                # Drivers expect a positional params sequence alongside kwargs.
                return self.cursor.callproc(procname, params or (), kparams)
            if params is None:
                return self.cursor.callproc(procname)
            return self.cursor.callproc(procname, params)

    def execute(self, sql, params=None):
        return self._execute_with_wrappers(
            sql, params, many=False, executor=self._execute
        )

    def executemany(self, sql, param_list):
        return self._execute_with_wrappers(
            sql, param_list, many=True, executor=self._executemany
        )

    def _execute_with_wrappers(self, sql, params, many, executor):
        context = {"connection": self.db, "cursor": self}
        # Build the chain from the inside out so that the first registered
        # wrapper ends up outermost (i.e. is invoked first).
        call = executor
        for outer in reversed(self.db.execute_wrappers):
            call = functools.partial(outer, call)
        return call(sql, params, many, context)

    def _execute(self, sql, params, *ignored_wrapper_args):
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            # params default might be backend specific, so leave it out
            # entirely when none were given.
            args = (sql,) if params is None else (sql, params)
            return self.cursor.execute(*args)

    def _executemany(self, sql, param_list, *ignored_wrapper_args):
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)

95 

96 

class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times each execute()/executemany() call, appends it to
    ``db.queries_log`` and emits it on the module logger at DEBUG level.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        with self.debug_sql(sql, params, use_last_executed_query=True):
            result = super().execute(sql, params)
        return result

    def executemany(self, sql, param_list):
        with self.debug_sql(sql, param_list, many=True):
            result = super().executemany(sql, param_list)
        return result

    @contextmanager
    def debug_sql(
        self, sql=None, params=None, use_last_executed_query=False, many=False
    ):
        """
        Time the enclosed block and record the query afterwards, even if the
        block raised.

        With ``use_last_executed_query`` the recorded SQL comes from
        ``db.ops.last_executed_query``. With ``many`` the queries_log entry
        is prefixed with the number of parameter sets ("?" when ``params``
        has no length).
        """
        began = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - began
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                logged_sql = f"{times} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append(
                {
                    "sql": logged_sql,
                    "time": f"{elapsed:.3f}",
                }
            )
            logger.debug(
                "(%.3f) %s; args=%s; alias=%s",
                elapsed,
                sql,
                params,
                self.db.alias,
                extra={
                    "duration": elapsed,
                    "sql": sql,
                    "params": params,
                    "alias": self.db.alias,
                },
            )

144 

145 

@contextmanager
def debug_transaction(connection, sql):
    """
    Time the enclosed block and, if ``connection.queries_logged`` is true,
    record ``sql`` in ``connection.queries_log`` and log it at DEBUG level.

    The record is written in a ``finally`` clause, so it is also written when
    the block raises. Transaction statements carry no parameters. ``params``
    is still supplied in the log record's ``extra`` (as None) so its shape
    matches the records emitted for regular queries, and handlers or
    formatters that reference ``params`` work for both.
    """
    start = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            duration = time.monotonic() - start
            connection.queries_log.append(
                {
                    "sql": f"{sql}",
                    "time": f"{duration:.3f}",
                }
            )
            logger.debug(
                "(%.3f) %s; args=%s; alias=%s",
                duration,
                sql,
                None,
                connection.alias,
                extra={
                    "duration": duration,
                    "sql": sql,
                    # Keep the same extra keys as query records (see args=None above).
                    "params": None,
                    "alias": connection.alias,
                },
            )

173 

174 

def split_tzname_delta(tzname):
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).

    "+" is tried before "-"; the split happens at the last occurrence of the
    sign, and only counts if the remainder is non-empty and accepted by
    ``parse_time``. Otherwise ``(tzname, None, None)`` is returned.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, offset = tzname.rsplit(sign, 1)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None

185 

186 

187############################################### 

188# Converters from database (string) to Python # 

189############################################### 

190 

191 

def typecast_date(s):
    """Parse a "YYYY-MM-DD" string into a datetime.date; empty/None gives None."""
    if not s:
        return None  # return None if s is null
    parts = [int(piece) for piece in s.split("-")]
    return datetime.date(*parts)

196 

197 

def typecast_time(s):  # does NOT store time zone information
    """
    Parse an "HH:MM:SS[.fraction]" string into a naive datetime.time.

    The fraction is right-padded with zeros and truncated to six digits
    (microseconds). Empty/None input returns None.
    """
    if not s:
        return None
    hh, mm, ss = s.split(":")
    if "." in ss:  # seconds carry a fractional part
        ss, frac = ss.split(".")
    else:
        frac = "0"
    return datetime.time(int(hh), int(mm), int(ss), int(frac.ljust(6, "0")[:6]))

209 

210 

def typecast_timestamp(s):  # does NOT store time zone information
    """
    Parse a "YYYY-MM-DD HH:MM:SS[.fraction][+/-offset]" string into a naive
    datetime, e.g. "2005-07-29 15:48:00.590358-05" or "2005-07-29 09:56:00-05".

    Any UTC offset is discarded. Input without a space is handed to
    typecast_date(); empty/None input returns None.
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Remove timezone information: cut at the first "-" or, if there is none,
    # at the first "+".
    for marker in ("-", "+"):
        if marker in time_part:
            time_part = time_part.split(marker, 1)[0]
            break
    ymd = date_part.split("-")
    hms = time_part.split(":")
    secs = hms[2]
    if "." in secs:  # seconds carry a fractional part
        secs, frac = secs.split(".")
    else:
        frac = "0"
    return datetime.datetime(
        int(ymd[0]),
        int(ymd[1]),
        int(ymd[2]),
        int(hms[0]),
        int(hms[1]),
        int(secs),
        # Right-pad the fraction to six digits, then keep microseconds only.
        int(frac.ljust(6, "0")[:6]),
    )

240 

241 

242############################################### 

243# Converters from Python to database (string) # 

244############################################### 

245 

246 

def split_identifier(identifier):
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier could be a table, column, or sequence name, and might be
    prefixed by a namespace as in '"NAMESPACE"."NAME"'. Unless the identifier
    contains exactly one '"."' separator, the namespace is "". Surrounding
    double quotes are stripped from both parts.
    """
    parts = identifier.split('"."')
    if len(parts) == 2:
        namespace, name = parts
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')

259 

260 

def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only.

    Args:
        identifier: the (possibly quoted, possibly namespaced) identifier.
        length: maximum length of the name portion; None disables truncation.
        hash_len: number of hex digest characters appended to the truncated
            name. It is capped at ``length`` so the result never exceeds it.

    Returns:
        ``identifier`` unchanged if no truncation is needed. Otherwise the
        unquoted namespace (if any) joined by '"."' to the first
        ``length - hash_len`` characters of the name plus a digest of the
        full name.
    """
    namespace, name = split_identifier(identifier)

    if length is None or len(name) <= length:
        return identifier

    # Cap the digest so a hash_len larger than length can't give a negative
    # slice bound below. That would keep most of the name and overshoot
    # ``length``. For hash_len <= length this is the original behavior.
    digest_len = min(hash_len, length)
    digest = names_digest(name, length=digest_len)
    return "{}{}{}".format(
        f'{namespace}"."' if namespace else "",
        name[: length - digest_len],
        digest,
    )

280 

281 

def names_digest(*args, length):
    """
    Return the first ``length`` hex characters of the MD5 digest of ``args``.

    Each argument is encoded and fed to one MD5 hasher in order, so
    ("a", "bc") and ("abc",) yield the same digest. The result is used to
    shorten identifying names, not for security.
    """
    hasher = md5(usedforsecurity=False)
    for encoded in (part.encode() for part in args):
        hasher.update(encoded)
    return hasher.hexdigest()[:length]

291 

292 

def format_number(value, max_digits, decimal_places):
    """
    Format a Decimal into a fixed-point string with the requisite number of
    digits and decimal places.

    With ``decimal_places`` set, the value is quantized to that many places
    using a context whose precision is ``max_digits`` (if given). Without it,
    the value is re-created under that context with the ``Rounded`` signal
    trapped, so a value needing rounding raises instead. None passes through.
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1
        formatted = ctx.create_decimal(value)
    else:
        exponent = decimal.Decimal(1).scaleb(-decimal_places)
        formatted = value.quantize(exponent, context=ctx)
    return f"{formatted:f}"

311 

312 

def strip_quotes(table_name):
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
    scheme) becomes 'USER"."TABLE'.

    Only one leading and one trailing double quote are removed, and only when
    both are present; any other name is returned unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name