Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/functions/comparison.py: 43%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:03 -0500

1"""Database functions that do comparisons or type conversions.""" 

2from plain.models.db import NotSupportedError 

3from plain.models.expressions import Func, Value 

4from plain.models.fields import TextField 

5from plain.models.fields.json import JSONField 

6from plain.utils.regex_helper import _lazy_re_compile 

7 

8 

class Cast(Func):
    """Convert an expression to the database type of ``output_field``.

    The generic rendering is ``CAST(<expr> AS <db_type>)``. The vendor hooks
    below substitute forms better suited to SQLite, MySQL/MariaDB and
    PostgreSQL.
    """

    function = "CAST"
    template = "%(function)s(%(expressions)s AS %(db_type)s)"

    def __init__(self, expression, output_field):
        super().__init__(expression, output_field=output_field)

    def as_sql(self, compiler, connection, **extra_context):
        """Fill the ``db_type`` template slot from the output field's cast type."""
        cast_type = self.output_field.cast_db_type(connection)
        extra_context["db_type"] = cast_type
        return super().as_sql(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render the cast for SQLite.

        Targets whose ``db_type`` is ``"date"`` are rendered with ``date()``.
        Targets of ``"datetime"`` or ``"time"`` go through ``strftime``.
        Every other target falls back to the generic CAST rendering.
        """
        target = self.output_field.db_type(connection)
        if target == "date":
            return super().as_sql(
                compiler, connection, template="date(%(expressions)s)", **extra_context
            )
        if target not in {"datetime", "time"}:
            return self.as_sql(compiler, connection, **extra_context)
        # Use strftime as datetime/time don't keep fractional seconds.
        # ``%%s`` presumably becomes a ``%s`` placeholder after template
        # interpolation; it is bound to the format string prepended below.
        sql, params = super().as_sql(
            compiler,
            connection,
            template="strftime(%%s, %(expressions)s)",
            **extra_context,
        )
        if target == "time":
            fmt = "%H:%M:%f"
        else:
            fmt = "%Y-%m-%d %H:%M:%f"
        params.insert(0, fmt)
        return sql, params

    def as_mysql(self, compiler, connection, **extra_context):
        """Render the cast for MySQL/MariaDB, working around unsupported CASTs."""
        internal_type = self.output_field.get_internal_type()
        if internal_type == "FloatField":
            # MySQL doesn't support explicit cast to float.
            template = "(%(expressions)s + 0.0)"
        elif internal_type == "JSONField" and connection.mysql_is_mariadb:
            # MariaDB doesn't support explicit cast to JSON.
            template = "JSON_EXTRACT(%(expressions)s, '$')"
        else:
            # No override; assumes the base as_sql treats None as "use the
            # class template" — confirm against Func.as_sql.
            template = None
        return self.as_sql(compiler, connection, template=template, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        """Render the cast with PostgreSQL's ``::`` shorthand.

        CAST would be valid too, but ``::`` reads better. The expression is
        parenthesized in case it is a compound expression.
        """
        pg_template = "(%(expressions)s)::%(db_type)s"
        return self.as_sql(compiler, connection, template=pg_template, **extra_context)

61 

62 

class Coalesce(Func):
    """Evaluate to the first non-null expression, scanning left to right."""

    function = "COALESCE"

    def __init__(self, *expressions, **extra):
        if len(expressions) >= 2:
            super().__init__(*expressions, **extra)
        else:
            raise ValueError("Coalesce must take at least two expressions")

    @property
    def empty_result_set_value(self):
        """Combine the sources' ``empty_result_set_value`` (presumably used
        when the query is known to match no rows).

        The first source whose own value is not ``None`` decides the result.
        ``NotImplemented`` is itself non-``None``, so it is propagated as-is.
        If every source reports ``None``, the result is ``None``.
        """
        candidates = (
            source.empty_result_set_value for source in self.get_source_expressions()
        )
        return next((value for value in candidates if value is not None), None)

80 

81 

class Collate(Func):
    """Apply a collation to an expression, rendered as ``<expr> COLLATE <name>``.

    The collation name is validated against ``collation_re`` on construction
    and quoted with ``connection.ops.quote_name`` when rendered.
    """

    function = "COLLATE"
    template = "%(expressions)s %(function)s %(collation)s"
    # Inspired from
    # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
    # ``\Z`` anchors the true end of the string. ``$`` would also match just
    # before a trailing newline, letting names such as "C\n" pass validation.
    collation_re = _lazy_re_compile(r"^[\w\-]+\Z")

    def __init__(self, expression, collation):
        """Wrap ``expression`` with the named collation.

        Raises:
            ValueError: if ``collation`` is not a non-empty string made only
                of word characters and hyphens.
        """
        if not (
            isinstance(collation, str)
            and collation
            and self.collation_re.match(collation)
        ):
            # f-string + !r: ``"%r" % collation`` would itself raise TypeError
            # when ``collation`` is a tuple (e.g. an empty one).
            raise ValueError(f"Invalid collation name: {collation!r}.")
        self.collation = collation
        super().__init__(expression)

    def as_sql(self, compiler, connection, **extra_context):
        # Quote the validated name as an identifier unless the caller already
        # supplied a rendered collation.
        extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
        return super().as_sql(compiler, connection, **extra_context)

98 

99 

class Greatest(Func):
    """
    Return the largest of the given expressions.

    NULL handling is backend-specific: on PostgreSQL the largest non-null
    expression is returned, while on MySQL, Oracle, and SQLite the result is
    null as soon as any expression is null.
    """

    function = "GREATEST"

    def __init__(self, *expressions, **extra):
        if len(expressions) >= 2:
            super().__init__(*expressions, **extra)
            return
        raise ValueError("Greatest must take at least two expressions")

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with ``MAX`` instead of ``GREATEST`` on SQLite."""
        sqlite_function = "MAX"
        return super().as_sqlite(
            compiler, connection, function=sqlite_function, **extra_context
        )

119 

120 

class JSONObject(Func):
    """Build a JSON object in SQL from keyword arguments.

    Each ``key=value`` pair contributes two source expressions, ``Value(key)``
    followed by ``value``. The SQL function therefore receives alternating
    key/value arguments, with keys at the even positions.
    """

    function = "JSON_OBJECT"
    output_field = JSONField()

    def __init__(self, **fields):
        flattened = [
            item for key, value in fields.items() for item in (Value(key), value)
        ]
        super().__init__(*flattened)

    def as_sql(self, compiler, connection, **extra_context):
        """Render the generic ``JSON_OBJECT(...)`` form.

        Raises:
            NotSupportedError: if the backend's
                ``features.has_json_object_function`` is false.
        """
        supported = connection.features.has_json_object_function
        if not supported:
            raise NotSupportedError(
                "JSONObject() is not supported on this database backend."
            )
        return super().as_sql(compiler, connection, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        """Render with ``JSONB_BUILD_OBJECT`` on PostgreSQL.

        Works on ``self.copy()`` rather than rewriting ``self``'s source
        expressions. Keys (even positions) are wrapped in ``Cast(..., TextField())``,
        presumably so PostgreSQL can infer a type for the key parameters.
        ``super(JSONObject, clone)`` bypasses this class's ``as_sql`` and its
        feature check.
        """
        clone = self.copy()
        rewritten = []
        for position, source in enumerate(clone.get_source_expressions()):
            is_key = position % 2 == 0
            rewritten.append(Cast(source, TextField()) if is_key else source)
        clone.set_source_expressions(rewritten)
        return super(JSONObject, clone).as_sql(
            compiler, connection, function="JSONB_BUILD_OBJECT", **extra_context
        )

152 

153 

class Least(Func):
    """
    Return the smallest of the given expressions.

    NULL handling is backend-specific: on PostgreSQL the smallest non-null
    expression is returned, while on MySQL, Oracle, and SQLite the result is
    null as soon as any expression is null.
    """

    function = "LEAST"

    def __init__(self, *expressions, **extra):
        if len(expressions) >= 2:
            super().__init__(*expressions, **extra)
            return
        raise ValueError("Least must take at least two expressions")

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with ``MIN`` instead of ``LEAST`` on SQLite."""
        sqlite_function = "MIN"
        return super().as_sqlite(
            compiler, connection, function=sqlite_function, **extra_context
        )

173 

174 

class NullIf(Func):
    """Render SQL ``NULLIF`` over exactly two expressions.

    In standard SQL, ``NULLIF(a, b)`` yields NULL when ``a`` equals ``b`` and
    ``a`` otherwise.
    """

    arity = 2
    function = "NULLIF"