Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/functions/text.py: 69%

150 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:03 -0500

1from plain.models.expressions import Func, Value 

2from plain.models.fields import CharField, IntegerField, TextField 

3from plain.models.functions import Cast, Coalesce 

4from plain.models.lookups import Transform 

5 

6 

class MySQLSHA2Mixin:
    """
    Render SHA-2 family hashes on MySQL through ``SHA2(expr, bits)``.

    The digest size is taken from the concrete class's ``function`` name by
    dropping its "SHA" prefix, e.g. "SHA256" -> "256".
    """

    def as_mysql(self, compiler, connection, **extra_context):
        bits = self.function[3:]
        # "%(expressions)s" is a literal placeholder here; as_sql() fills it
        # in later, only the digest size is interpolated now.
        mysql_template = f"SHA2(%(expressions)s, {bits})"
        return super().as_sql(
            compiler,
            connection,
            template=mysql_template,
            **extra_context,
        )

15 

16 

class PostgreSQLSHAMixin:
    """
    Render SHA hashes on PostgreSQL as ``ENCODE(DIGEST(expr, algo), 'hex')``.

    ``algo`` is the concrete class's ``function`` name lower-cased,
    e.g. "SHA256" -> "sha256".
    """

    def as_postgresql(self, compiler, connection, **extra_context):
        # NOTE(review): DIGEST() is provided by PostgreSQL's pgcrypto
        # extension; confirm it is installed wherever these functions are used.
        algorithm = self.function.lower()
        sql_template = "ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')"
        return super().as_sql(
            compiler,
            connection,
            template=sql_template,
            function=algorithm,
            **extra_context,
        )

26 

27 

class Chr(Transform):
    """
    Convert an integer expression to a character (lookup ``chr``).

    Rendered as ``CHR(...)`` by default and as ``CHAR(...)`` on MySQL and
    SQLite.
    """

    lookup_name = "chr"
    function = "CHR"

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL gets CHAR() with an explicit "USING utf16" clause.
        # NOTE(review): presumably so the integer is interpreted as a UTF-16
        # code unit independent of the connection charset — confirm.
        mysql_template = "%(function)s(%(expressions)s USING utf16)"
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            template=mysql_template,
            **extra_context,
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            **extra_context,
        )

43 

44 

class ConcatPair(Func):
    """
    Concatenate exactly two arguments.

    ``Concat`` chains these together instead of emitting a single
    multi-argument ``CONCAT()`` because not all backend databases support
    more than two arguments. Per-backend rendering:

    * SQLite: each argument wrapped in ``COALESCE(arg, '')`` and joined
      with ``||``.
    * PostgreSQL: each argument cast to text, then ``CONCAT()``.
    * MySQL: ``CONCAT_WS('', ...)`` so NULL arguments are ignored.
    """

    function = "CONCAT"

    def as_sqlite(self, compiler, connection, **extra_context):
        # The explicit super(ConcatPair, ...) form makes the parent's
        # as_sql() render the NULL-safe copy rather than self.
        null_safe = self.coalesce()
        return super(ConcatPair, null_safe).as_sql(
            compiler,
            connection,
            template="%(expressions)s",
            arg_joiner=" || ",
            **extra_context,
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # Cast every argument to text before calling CONCAT().
        # NOTE(review): presumably so CONCAT() always receives uniform text
        # inputs regardless of the source field types — confirm.
        as_text = self.copy()
        text_args = [
            Cast(arg, TextField()) for arg in as_text.get_source_expressions()
        ]
        as_text.set_source_expressions(text_args)
        return super(ConcatPair, as_text).as_sql(
            compiler,
            connection,
            **extra_context,
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # CONCAT_WS with an empty separator concatenates like CONCAT but
        # ignores NULL arguments.
        mysql_template = "%(function)s('', %(expressions)s)"
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template=mysql_template,
            **extra_context,
        )

    def coalesce(self):
        """
        Return a copy whose arguments are each wrapped in ``COALESCE(arg, '')``.

        Used where a NULL on either side would otherwise make the whole
        concatenation NULL.
        """
        wrapped = self.copy()
        wrapped.set_source_expressions(
            [Coalesce(arg, Value("")) for arg in wrapped.get_source_expressions()]
        )
        return wrapped

97 

98 

class Concat(Func):
    """
    Concatenate text fields together. Backends that result in an entire
    null expression when any arguments are null will wrap each argument in
    coalesce functions to ensure a non-null result.

    The arguments are nested into a right-leaning chain of ``ConcatPair``
    nodes, e.g. ``Concat(a, b, c)`` holds ``ConcatPair(a, ConcatPair(b, c))``.
    The template then emits that single nested expression unchanged.

    Raises ValueError when given fewer than two expressions.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        paired = self._paired(expressions)
        super().__init__(paired, **extra)

    def _paired(self, expressions):
        # Wrap pairs of expressions in successive ConcatPair nodes:
        #   [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        if len(expressions) == 2:
            return ConcatPair(*expressions)
        return ConcatPair(expressions[0], self._paired(expressions[1:]))

122 

123 

class Left(Func):
    """
    Return the first ``length`` characters of a string expression.

    Rendered as ``LEFT(...)`` by default. On SQLite it is rewritten to the
    equivalent ``Substr`` returned by ``get_substr()``.
    """

    output_field = CharField()
    arity = 2
    function = "LEFT"

    def __init__(self, expression, length, **extra):
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string

        A literal ``length`` below 1 raises ValueError. Expression lengths are
        not checked here.
        """
        is_literal = not hasattr(length, "resolve_expression")
        if is_literal and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self):
        # SUBSTR(expression, 1, length): the same characters LEFT() returns.
        text = self.source_expressions[0]
        count = self.source_expressions[1]
        return Substr(text, Value(1), count)

    def as_sqlite(self, compiler, connection, **extra_context):
        substr = self.get_substr()
        return substr.as_sqlite(compiler, connection, **extra_context)

144 

145 

class Length(Transform):
    """Return the number of characters in the expression."""

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()

    def as_mysql(self, compiler, connection, **extra_context):
        # On MySQL, LENGTH() counts bytes; CHAR_LENGTH() counts characters.
        return super().as_sql(
            compiler,
            connection,
            function="CHAR_LENGTH",
            **extra_context,
        )
156 ) 

157 

158 

class Lower(Transform):
    """Apply SQL ``LOWER()`` to a string expression (lookup ``lower``)."""

    lookup_name = "lower"
    function = "LOWER"

162 

163 

class LPad(Func):
    """
    Pad ``expression`` on the left to ``length`` using ``fill_text``.

    ``fill_text`` defaults to a single space. A literal negative ``length``
    raises ValueError. ``None`` and expression lengths are passed through
    unchecked.
    """

    output_field = CharField()
    function = "LPAD"

    # NOTE: the default Value(" ") is created once at definition time and the
    # same instance is shared by every call that omits fill_text.
    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        length_is_literal = length is not None and not hasattr(
            length, "resolve_expression"
        )
        if length_is_literal and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)

176 

177 

class LTrim(Transform):
    """Apply SQL ``LTRIM()`` to a string expression (lookup ``ltrim``)."""

    lookup_name = "ltrim"
    function = "LTRIM"

181 

182 

class MD5(Transform):
    """Hash a string expression with SQL ``MD5()`` (lookup ``md5``)."""

    lookup_name = "md5"
    function = "MD5"

186 

187 

class Ord(Transform):
    """
    Return an integer character code for a string expression (lookup ``ord``).

    Rendered as ``ASCII()`` by default, ``ORD()`` on MySQL and ``UNICODE()``
    on SQLite.
    """

    output_field = IntegerField()
    lookup_name = "ord"
    function = "ASCII"

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="ORD",
            **extra_context,
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="UNICODE",
            **extra_context,
        )

195 

196 def as_sqlite(self, compiler, connection, **extra_context): 

197 return super().as_sql(compiler, connection, function="UNICODE", **extra_context) 

198 

199 

class Repeat(Func):
    """
    Repeat a string expression ``number`` times via SQL ``REPEAT()``.

    A literal negative ``number`` raises ValueError. ``None`` and expression
    values are passed through unchecked.
    """

    output_field = CharField()
    function = "REPEAT"

    def __init__(self, expression, number, **extra):
        number_is_literal = number is not None and not hasattr(
            number, "resolve_expression"
        )
        if number_is_literal and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)

212 

213 

class Replace(Func):
    """
    Replace occurrences of ``text`` in ``expression`` with ``replacement``.

    ``replacement`` defaults to the empty string, which removes ``text``.
    """

    function = "REPLACE"

    # NOTE: the default Value("") instance is shared by every call that
    # omits replacement.
    def __init__(self, expression, text, replacement=Value(""), **extra):
        arguments = (expression, text, replacement)
        super().__init__(*arguments, **extra)

219 

220 

class Reverse(Transform):
    """Apply SQL ``REVERSE()`` to a string expression (lookup ``reverse``)."""

    lookup_name = "reverse"
    function = "REVERSE"

224 

225 

class Right(Left):
    """
    Return the last ``length`` characters of a string expression.

    Inherits argument validation and the SQLite fallback from ``Left``. Only
    the ``Substr`` equivalent differs.
    """

    function = "RIGHT"

    def get_substr(self):
        # SUBSTR(expression, -length): a negative start position counts back
        # from the end of the string, so no length argument is passed.
        text = self.source_expressions[0]
        negated_length = self.source_expressions[1] * Value(-1)
        return Substr(text, negated_length)
232 ) 

233 

234 

class RPad(LPad):
    """Pad ``expression`` on the right; takes the same arguments as ``LPad``."""

    function = "RPAD"

237 

238 

class RTrim(Transform):
    """Apply SQL ``RTRIM()`` to a string expression (lookup ``rtrim``)."""

    lookup_name = "rtrim"
    function = "RTRIM"

242 

243 

class SHA1(PostgreSQLSHAMixin, Transform):
    """
    Hash with SHA-1 (lookup ``sha1``).

    Rendered as ``SHA1(...)`` by default. PostgreSQL goes through
    ``PostgreSQLSHAMixin``.
    """

    lookup_name = "sha1"
    function = "SHA1"

247 

248 

class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """Hash with SHA-224 (lookup ``sha224``); MySQL/PostgreSQL via the mixins."""

    lookup_name = "sha224"
    function = "SHA224"

252 

253 

class SHA256(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """Hash with SHA-256 (lookup ``sha256``); MySQL/PostgreSQL via the mixins."""

    lookup_name = "sha256"
    function = "SHA256"

257 

258 

class SHA384(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """Hash with SHA-384 (lookup ``sha384``); MySQL/PostgreSQL via the mixins."""

    lookup_name = "sha384"
    function = "SHA384"

262 

263 

class SHA512(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """Hash with SHA-512 (lookup ``sha512``); MySQL/PostgreSQL via the mixins."""

    lookup_name = "sha512"
    function = "SHA512"

267 

268 

class StrIndex(Func):
    """
    Return a positive integer corresponding to the 1-indexed position of the
    first occurrence of a substring inside another string, or 0 if the
    substring is not found.

    Rendered as ``INSTR(...)`` by default and ``STRPOS(...)`` on PostgreSQL.
    """

    output_field = IntegerField()
    arity = 2
    function = "INSTR"

    def as_postgresql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="STRPOS",
            **extra_context,
        )

282 

283 

class Substr(Func):
    """
    Extract part of a string expression starting at the 1-indexed ``pos``.

    Rendered as ``SUBSTRING(...)`` by default and ``SUBSTR(...)`` on SQLite.
    """

    output_field = CharField()
    function = "SUBSTRING"

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return

        A literal ``pos`` below 1 raises ValueError. Expression positions are
        not checked. ``length`` is passed to SQL only when it is not None.
        """
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        args = (expression, pos) if length is None else (expression, pos, length)
        super().__init__(*args, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="SUBSTR",
            **extra_context,
        )

304 

305 

class Trim(Transform):
    """Apply SQL ``TRIM()`` to a string expression (lookup ``trim``)."""

    lookup_name = "trim"
    function = "TRIM"

309 

310 

class Upper(Transform):
    """Apply SQL ``UPPER()`` to a string expression (lookup ``upper``)."""

    lookup_name = "upper"
    function = "UPPER"