1from __future__ import annotations
2
3from datetime import datetime
4from typing import TYPE_CHECKING, Any
5
6from plain.postgres.dialect import (
7 date_extract_sql,
8 date_trunc_sql,
9 datetime_cast_date_sql,
10 datetime_cast_time_sql,
11 datetime_extract_sql,
12 datetime_trunc_sql,
13 time_extract_sql,
14 time_trunc_sql,
15)
16from plain.postgres.expressions import Func
17from plain.postgres.fields import (
18 DateField,
19 DateTimeField,
20 DurationField,
21 Field,
22 IntegerField,
23 TimeField,
24)
25from plain.postgres.lookups import (
26 Transform,
27 YearExact,
28 YearGt,
29 YearGte,
30 YearLt,
31 YearLte,
32)
33from plain.utils import timezone
34
35if TYPE_CHECKING:
36 from plain.postgres.connection import DatabaseConnection
37 from plain.postgres.sql.compiler import SQLCompiler
38
39
class TimezoneMixin(Transform):
    """Mixin for transforms that must shift a datetime into a timezone first."""

    # Explicit timezone to convert into; None means "use the active timezone".
    tzinfo = None

    def get_tzname(self) -> str | None:
        """
        Return the name of the timezone the input should be converted into.

        The conversion has to happen *before* any function is applied: a value
        entered as 2015-12-31 23:00:00 -02:00 is stored as
        2016-01-01 01:00:00 +00:00, and results must reflect the value as it
        was entered, not as it is stored.
        """
        explicit_tz = self.tzinfo
        if explicit_tz is not None:
            return timezone._get_timezone_name(explicit_tz)
        return timezone.get_current_timezone_name()
52
53
class Extract(TimezoneMixin, Transform):
    """
    Extract a single integer component (year, hour, ...) from a date/time value.

    Subclasses pin ``lookup_name`` at class level; the generic class takes it
    as a constructor argument. Accepted inputs are DateField, DateTimeField,
    TimeField and DurationField expressions.
    """

    lookup_name: str | None = None
    output_field = IntegerField()

    def __init__(
        self,
        expression: Any,
        lookup_name: str | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        # A class-level lookup_name (set by subclasses) wins over the argument.
        if self.lookup_name is None:
            self.lookup_name = lookup_name
            if self.lookup_name is None:
                raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the extraction, choosing the dialect helper by input type.

        Raises:
            ValueError: if ``tzinfo`` is given for a non-DateTimeField input,
                or the input type is not supported.
        """
        component = self.lookup_name
        # __init__ refuses a missing lookup_name; this only narrows the type.
        assert component is not None
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        args = tuple(lhs_params)
        source_field = self.lhs.output_field

        if isinstance(source_field, DateTimeField):
            # Datetimes are shifted into the target timezone before extracting.
            out_sql, out_params = datetime_extract_sql(
                component, lhs_sql, args, self.get_tzname()
            )
            return out_sql, list(out_params)

        if self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")

        if isinstance(source_field, DateField):
            out_sql, out_params = date_extract_sql(component, lhs_sql, args)
        elif isinstance(source_field, TimeField | DurationField):
            # PostgreSQL intervals support the same extraction as times.
            out_sql, out_params = time_extract_sql(component, lhs_sql, args)
        else:
            # Normally rejected by resolve_expression, which only skips
            # validation when the input's output type is not known yet.
            raise ValueError("Tried to Extract from an invalid type.")
        return out_sql, list(out_params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Extract:
        """
        Resolve the expression and validate the input/component combination.

        Raises:
            ValueError: if the input is not a date/time/duration type, a time
                component is requested from a plain DateField, or a calendar
                component is requested from a DurationField.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        source = getattr(resolved.lhs, "output_field", None)
        if source is None:
            # Output type unknown at this point; nothing to validate.
            return resolved

        if not isinstance(source, DateField | DateTimeField | TimeField | DurationField):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )

        component = resolved.lookup_name
        # Exact type match: only a plain DateField lacks a time of day; any
        # subclass (presumably including DateTimeField) is not rejected here.
        if type(source) is DateField and component in ("hour", "minute", "second"):
            raise ValueError(
                f"Cannot extract time component '{component}' from DateField '{source.name}'."
            )
        calendar_components = (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        )
        if isinstance(source, DurationField) and component in calendar_components:
            raise ValueError(
                f"Cannot extract component '{component}' from DurationField '{source.name}'."
            )
        return resolved
146
147
class ExtractYear(Extract):
    """Return the calendar year."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Return the week-numbering year as defined by ISO-8601."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Return the month of the year."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Return the day component."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Return the ISO-8601 week number (1-52 or 53); weeks start on Monday.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Return the day of the week, numbered Sunday=1 through Saturday=7.

    The Python equivalent is ``(mydatetime.isoweekday() % 7) + 1``.
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Return the ISO-8601 day of the week, numbered Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Return the quarter of the year."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Return the hour component."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Return the minute component."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Return the second component."""

    lookup_name = "second"
205
206
# Calendar components are available on both DateField and DateTimeField.
# Registration order matches a flat, field-by-field listing.
for _field_cls in (DateField, DateTimeField):
    for _extract_cls in (
        ExtractYear,
        ExtractMonth,
        ExtractDay,
        ExtractWeekDay,
        ExtractIsoWeekDay,
        ExtractWeek,
        ExtractIsoYear,
        ExtractQuarter,
    ):
        _field_cls.register_lookup(_extract_cls)

# Clock components are available on DateTimeField and TimeField.
for _field_cls in (DateTimeField, TimeField):
    for _extract_cls in (ExtractHour, ExtractMinute, ExtractSecond):
        _field_cls.register_lookup(_extract_cls)

# Year-valued extracts get the specialized Year* comparison lookups.
for _extract_cls in (ExtractYear, ExtractIsoYear):
    for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _extract_cls.register_lookup(_year_lookup)

# Keep the module namespace free of loop temporaries.
del _field_cls, _extract_cls, _year_lookup
243
244
class Now(Func):
    """
    Current date and time, as of the start of the running SQL statement.

    STATEMENT_TIMESTAMP() is used rather than CURRENT_TIMESTAMP, which would
    report the start of the enclosing transaction instead.
    """

    template = "STATEMENT_TIMESTAMP()"
    output_field = DateTimeField()
251
252
class TruncBase(TimezoneMixin, Transform):
    """
    Base transform that truncates a date, time, or datetime expression.

    Subclasses set ``kind`` to the truncation unit (``Trunc`` accepts it as a
    constructor argument instead). The SQL emitted depends on the resolved
    ``output_field``: DateTimeField, DateField, or TimeField.
    """

    # Truncation unit handed to the dialect helpers, e.g. "year" or "hour".
    kind: str | None = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        # tzinfo overrides the active timezone; as_sql rejects it for inputs
        # that are not DateTimeField.
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the truncation, choosing the dialect helper by output type.

        Raises:
            ValueError: if ``tzinfo`` is set for a non-DateTimeField input, or
                the output field is not a date/time type.
        """
        # kind is set on every concrete subclass (and by Trunc.__init__).
        assert self.kind is not None
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            # Timezone conversion only applies to DateTimeField inputs.
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        if isinstance(self.output_field, DateTimeField):
            sql, params = datetime_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, DateField):
            sql, params = date_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, TimeField):
            sql, params = time_trunc_sql(self.kind, sql, tuple(params), tzname)
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """
        Resolve the expression and validate the input/output/kind combination.

        Raises:
            TypeError: if the input is not a DateField, TimeField, or
                DateTimeField.
            ValueError: if the output field is not a date/time type, or the
                truncation would invent information the input lacks (time
                parts of a plain date, calendar parts of a time).
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        if not isinstance(field, DateField | DateTimeField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Name reported in error messages: the explicit/class output type when
        # there is one, otherwise the implied DateTimeField.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact type match: only a plain DateField lacks a time of day.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(f"Cannot truncate DateField '{field.name}' to {target_name}.")
        if isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(f"Cannot truncate TimeField '{field.name}' to {target_name}.")
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """
        Post-process a database value into the declared output type.

        DateTimeField results are made aware in ``self.tzinfo`` (or the active
        timezone when it is None). A ``datetime`` returned for a DateField or
        TimeField output is narrowed with ``.date()`` / ``.time()``. None and
        any other value pass through unchanged.
        """
        output_field = self.output_field
        if isinstance(output_field, DateTimeField):
            if value is not None:
                # Drop any tzinfo from the driver and re-attach the timezone
                # the truncation was computed in.
                value = timezone.make_aware(value.replace(tzinfo=None), self.tzinfo)
        elif isinstance(value, datetime):
            # isinstance already excludes None, so no separate None check.
            if isinstance(output_field, DateField):
                value = value.date()
            elif isinstance(output_field, TimeField):
                value = value.time()
        return value
369
370
class Trunc(TruncBase):
    """Truncate to a unit chosen per instance (``kind``), not per subclass."""

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        # Store the unit on the instance, shadowing the class-level default.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
382
383
class TruncYear(TruncBase):
    """Truncate to the start of the year."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to the start of the quarter."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to the start of the month."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to the start of the day."""

    kind = "day"
405
class TruncDate(TruncBase):
    """Return the date part of the input, evaluated in the target timezone."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile as a timezone-aware cast to date rather than a truncation."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        target_tz = self.get_tzname()
        cast_sql, cast_params = datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), target_tz
        )
        return cast_sql, list(cast_params)
425
426
class TruncTime(TruncBase):
    """Return the time-of-day part of the input, evaluated in the target timezone."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile as a timezone-aware cast to time rather than a truncation."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        target_tz = self.get_tzname()
        cast_sql, cast_params = datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), target_tz
        )
        return cast_sql, list(cast_params)
446
447
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to the start of the second."""

    kind = "second"
458
459
# Register the "date" and "time" transforms (by their lookup_name) on
# DateTimeField, in that order.
for _trunc_cls in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_trunc_cls)
# Keep the module namespace free of the loop temporary.
del _trunc_cls