v0.145.0
  1from __future__ import annotations
  2
  3from datetime import datetime
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.postgres.dialect import (
  7    date_extract_sql,
  8    date_trunc_sql,
  9    datetime_cast_date_sql,
 10    datetime_cast_time_sql,
 11    datetime_extract_sql,
 12    datetime_trunc_sql,
 13    time_extract_sql,
 14    time_trunc_sql,
 15)
 16from plain.postgres.expressions import Func
 17from plain.postgres.fields import (
 18    DateField,
 19    DateTimeField,
 20    DurationField,
 21    Field,
 22    IntegerField,
 23    TimeField,
 24)
 25from plain.postgres.lookups import (
 26    Transform,
 27    YearExact,
 28    YearGt,
 29    YearGte,
 30    YearLt,
 31    YearLte,
 32)
 33from plain.utils import timezone
 34
 35if TYPE_CHECKING:
 36    from plain.postgres.connection import DatabaseConnection
 37    from plain.postgres.sql.compiler import SQLCompiler
 38
 39
 40class TimezoneMixin(Transform):
 41    tzinfo = None
 42
 43    def get_tzname(self) -> str | None:
 44        # Timezone conversions must happen to the input datetime *before*
 45        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
 46        # database as 2016-01-01 01:00:00 +00:00. Any results should be
 47        # based on the input datetime not the stored datetime.
 48        if self.tzinfo is None:
 49            return timezone.get_current_timezone_name()
 50        else:
 51            return timezone._get_timezone_name(self.tzinfo)
 52
 53
 54class Extract(TimezoneMixin, Transform):
 55    lookup_name: str | None = None
 56    output_field = IntegerField()
 57
 58    def __init__(
 59        self,
 60        expression: Any,
 61        lookup_name: str | None = None,
 62        tzinfo: Any = None,
 63        **extra: Any,
 64    ) -> None:
 65        if self.lookup_name is None:
 66            self.lookup_name = lookup_name
 67        if self.lookup_name is None:
 68            raise ValueError("lookup_name must be provided")
 69        self.tzinfo = tzinfo
 70        super().__init__(expression, **extra)
 71
 72    def as_sql(
 73        self,
 74        compiler: SQLCompiler,
 75        connection: DatabaseConnection,
 76        function: str | None = None,
 77        template: str | None = None,
 78        arg_joiner: str | None = None,
 79        **extra_context: Any,
 80    ) -> tuple[str, list[Any]]:
 81        # lookup_name is guaranteed to be str after __init__ validation
 82        assert self.lookup_name is not None
 83        sql, params = compiler.compile(self.lhs)
 84        lhs_output_field = self.lhs.output_field
 85        if isinstance(lhs_output_field, DateTimeField):
 86            tzname = self.get_tzname()
 87            sql, params = datetime_extract_sql(
 88                self.lookup_name, sql, tuple(params), tzname
 89            )
 90        elif self.tzinfo is not None:
 91            raise ValueError("tzinfo can only be used with DateTimeField.")
 92        elif isinstance(lhs_output_field, DateField):
 93            sql, params = date_extract_sql(self.lookup_name, sql, tuple(params))
 94        elif isinstance(lhs_output_field, TimeField):
 95            sql, params = time_extract_sql(self.lookup_name, sql, tuple(params))
 96        elif isinstance(lhs_output_field, DurationField):
 97            # PostgreSQL has native duration (interval) type
 98            sql, params = time_extract_sql(self.lookup_name, sql, tuple(params))
 99        else:
100            # resolve_expression has already validated the output_field so this
101            # assert should never be hit.
102            raise ValueError("Tried to Extract from an invalid type.")
103        return sql, list(params)
104
105    def resolve_expression(
106        self,
107        query: Any = None,
108        allow_joins: bool = True,
109        reuse: Any = None,
110        summarize: bool = False,
111        for_save: bool = False,
112    ) -> Extract:
113        copy = super().resolve_expression(
114            query, allow_joins, reuse, summarize, for_save
115        )
116        field = getattr(copy.lhs, "output_field", None)
117        if field is None:
118            return copy
119        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
120            raise ValueError(
121                "Extract input expression must be DateField, DateTimeField, "
122                "TimeField, or DurationField."
123            )
124        # Passing dates to functions expecting datetimes is most likely a mistake.
125        if type(field) == DateField and copy.lookup_name in (  # noqa: E721
126            "hour",
127            "minute",
128            "second",
129        ):
130            raise ValueError(
131                f"Cannot extract time component '{copy.lookup_name}' from DateField '{field.name}'."
132            )
133        if isinstance(field, DurationField) and copy.lookup_name in (
134            "year",
135            "iso_year",
136            "month",
137            "week",
138            "week_day",
139            "iso_week_day",
140            "quarter",
141        ):
142            raise ValueError(
143                f"Cannot extract component '{copy.lookup_name}' from DurationField '{field.name}'."
144            )
145        return copy
146
147
class ExtractYear(Extract):
    """Extract the ``year`` component.

    Registered below as a lookup on DateField and DateTimeField, and itself
    receives the YearExact/YearGt/YearGte/YearLt/YearLte lookups.
    Extract.resolve_expression rejects this component on a DurationField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "year"
150
151
class ExtractIsoYear(Extract):
    """Return the ISO-8601 week-numbering year.

    Registered below as a lookup on DateField and DateTimeField, and itself
    receives the YearExact/YearGt/YearGte/YearLt/YearLte lookups.
    Extract.resolve_expression rejects this component on a DurationField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "iso_year"
156
157
class ExtractMonth(Extract):
    """Extract the ``month`` component.

    Registered below as a lookup on DateField and DateTimeField.
    Extract.resolve_expression rejects this component on a DurationField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "month"
160
161
class ExtractDay(Extract):
    """Extract the ``day`` component.

    Registered below as a lookup on DateField and DateTimeField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "day"
164
165
class ExtractWeek(Extract):
    """
    Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
    week.

    Registered below as a lookup on DateField and DateTimeField.
    Extract.resolve_expression rejects this component on a DurationField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "week"
173
174
class ExtractWeekDay(Extract):
    """
    Return Sunday=1 through Saturday=7.

    To replicate this in Python: (mydatetime.isoweekday() % 7) + 1

    Registered below as a lookup on DateField and DateTimeField.
    Extract.resolve_expression rejects this component on a DurationField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "week_day"
183
184
class ExtractIsoWeekDay(Extract):
    """Return Monday=1 through Sunday=7, based on ISO-8601.

    Registered below as a lookup on DateField and DateTimeField.
    Extract.resolve_expression rejects this component on a DurationField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "iso_week_day"
189
190
class ExtractQuarter(Extract):
    """Extract the ``quarter`` component.

    Registered below as a lookup on DateField and DateTimeField.
    Extract.resolve_expression rejects this component on a DurationField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "quarter"
193
194
class ExtractHour(Extract):
    """Extract the ``hour`` component.

    Registered below as a lookup on DateTimeField and TimeField.
    Extract.resolve_expression rejects this component on a plain DateField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "hour"
197
198
class ExtractMinute(Extract):
    """Extract the ``minute`` component.

    Registered below as a lookup on DateTimeField and TimeField.
    Extract.resolve_expression rejects this component on a plain DateField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "minute"
201
202
class ExtractSecond(Extract):
    """Extract the ``second`` component.

    Registered below as a lookup on DateTimeField and TimeField.
    Extract.resolve_expression rejects this component on a plain DateField.
    """

    # Component name Extract.as_sql passes to the dialect's extract helpers.
    lookup_name = "second"
205
206
# Lookup registration for the Extract transforms. The tuples keep the exact
# registration order; all temporaries are deleted afterwards so the module
# namespace is left untouched.
_DATE_PART_TRANSFORMS = (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
)
_TIME_PART_TRANSFORMS = (ExtractHour, ExtractMinute, ExtractSecond)
_YEAR_COMPARISON_LOOKUPS = (YearExact, YearGt, YearGte, YearLt, YearLte)

# Dates expose only the calendar components.
for _transform in _DATE_PART_TRANSFORMS:
    DateField.register_lookup(_transform)

# Datetimes expose calendar components followed by clock components.
for _transform in _DATE_PART_TRANSFORMS + _TIME_PART_TRANSFORMS:
    DateTimeField.register_lookup(_transform)

# Times expose only the clock components.
for _transform in _TIME_PART_TRANSFORMS:
    TimeField.register_lookup(_transform)

# Both year-like transforms get the dedicated year comparison lookups.
for _year_transform in (ExtractYear, ExtractIsoYear):
    for _year_lookup in _YEAR_COMPARISON_LOOKUPS:
        _year_transform.register_lookup(_year_lookup)

del (
    _DATE_PART_TRANSFORMS,
    _TIME_PART_TRANSFORMS,
    _YEAR_COMPARISON_LOOKUPS,
    _transform,
    _year_transform,
    _year_lookup,
)
243
244
class Now(Func):
    """Database expression rendering ``STATEMENT_TIMESTAMP()``, typed as a DateTimeField."""

    # STATEMENT_TIMESTAMP() returns the time at the start of the current statement,
    # as opposed to CURRENT_TIMESTAMP which returns the time at the start of the
    # transaction.
    template = "STATEMENT_TIMESTAMP()"
    output_field = DateTimeField()
251
252
class TruncBase(TimezoneMixin, Transform):
    """Base transform that truncates a date/time expression to ``kind`` precision.

    Subclasses set ``kind`` as a class attribute (or, like ``Trunc``, per
    instance). ``kind`` is forwarded to the dialect's ``*_trunc_sql`` helper
    selected from the transform's ``output_field``.
    """

    kind: str | None = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """Remember the optional timezone and initialise the transform."""
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the wrapped expression and wrap it in the matching truncation SQL.

        A timezone name is only computed when the input is a DateTimeField.
        ``connection``, ``function``, ``template``, ``arg_joiner`` and
        ``extra_context`` are not used by this implementation.

        Raises:
            ValueError: if ``tzinfo`` was given for a non-DateTimeField input,
                or ``output_field`` is not a date, time or datetime field.
        """
        # kind is guaranteed to be str in subclasses
        assert self.kind is not None
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField is tested before DateField (presumably because it
        # subclasses DateField -- confirm against plain.postgres.fields).
        if isinstance(self.output_field, DateTimeField):
            sql, params = datetime_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, DateField):
            sql, params = date_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, TimeField):
            sql, params = time_trunc_sql(self.kind, sql, tuple(params), tzname)
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """Resolve the expression and validate the input/output field pairing.

        Raises:
            TypeError: if the input expression is not a date, time or
                datetime field.
            ValueError: if ``output_field`` is not a date, time or datetime
                field, or if a plain DateField / a TimeField is truncated to
                a precision or output type it cannot represent.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        if not isinstance(field, DateField | DateTimeField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # An output_field declared on the class itself (a real Field instance
        # rather than an inherited descriptor) takes priority over the
        # resolved one.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Field type named in the error message: the explicit target when one
        # was chosen, otherwise "DateTimeField".
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """Coerce a database value to match ``output_field``.

        For a DateTimeField output, a non-None value has its tzinfo dropped and
        is then made aware in ``self.tzinfo``. For other outputs, a ``datetime``
        value is reduced to its date or time part. ``None`` and any other value
        are returned unchanged. ``expression`` and ``connection`` are not used.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
        elif isinstance(value, datetime):
            # value is a datetime here, so it cannot be None; no None guard
            # is needed in this branch.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
369
370
class Trunc(TruncBase):
    """Truncation transform whose ``kind`` is supplied per instance."""

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """Record ``kind`` on the instance, then delegate to ``TruncBase``."""
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
382
383
class TruncYear(TruncBase):
    """Truncate with kind ``year``.

    TruncBase.resolve_expression rejects this kind for TimeField inputs.
    """

    # Precision TruncBase.as_sql passes to the dialect's trunc helpers.
    kind = "year"
386
387
class TruncQuarter(TruncBase):
    """Truncate with kind ``quarter``.

    TruncBase.resolve_expression rejects this kind for TimeField inputs.
    """

    # Precision TruncBase.as_sql passes to the dialect's trunc helpers.
    kind = "quarter"
390
391
class TruncMonth(TruncBase):
    """Truncate with kind ``month``.

    TruncBase.resolve_expression rejects this kind for TimeField inputs.
    """

    # Precision TruncBase.as_sql passes to the dialect's trunc helpers.
    kind = "month"
394
395
class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week.

    TruncBase.resolve_expression rejects this kind for TimeField inputs.
    """

    # Precision TruncBase.as_sql passes to the dialect's trunc helpers.
    kind = "week"
400
401
class TruncDay(TruncBase):
    """Truncate with kind ``day``.

    TruncBase.resolve_expression rejects this kind for TimeField inputs.
    """

    # Precision TruncBase.as_sql passes to the dialect's trunc helpers.
    kind = "day"
404
405
class TruncDate(TruncBase):
    """Transform that casts a datetime expression to a date.

    Compiles through ``datetime_cast_date_sql`` with the timezone name from
    ``get_tzname()``, and always reports a ``DateField`` output.
    """

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Return the cast-to-date SQL and its parameters as a list.

        Only ``compiler`` is used; the remaining arguments are ignored here.
        """
        # Cast to date rather than truncate to date.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        cast_sql, cast_params = datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), tzname
        )
        return cast_sql, list(cast_params)
425
426
class TruncTime(TruncBase):
    """Transform that casts a datetime expression to a time.

    Compiles through ``datetime_cast_time_sql`` with the timezone name from
    ``get_tzname()``, and always reports a ``TimeField`` output.
    """

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Return the cast-to-time SQL and its parameters as a list.

        Only ``compiler`` is used; the remaining arguments are ignored here.
        """
        # Cast to time rather than truncate to time.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        cast_sql, cast_params = datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), tzname
        )
        return cast_sql, list(cast_params)
446
447
class TruncHour(TruncBase):
    """Truncate with kind ``hour``.

    TruncBase.resolve_expression rejects this kind for plain DateField inputs.
    """

    # Precision TruncBase.as_sql passes to the dialect's trunc helpers.
    kind = "hour"
450
451
class TruncMinute(TruncBase):
    """Truncate with kind ``minute``.

    TruncBase.resolve_expression rejects this kind for plain DateField inputs.
    """

    # Precision TruncBase.as_sql passes to the dialect's trunc helpers.
    kind = "minute"
454
455
class TruncSecond(TruncBase):
    """Truncate with kind ``second``.

    TruncBase.resolve_expression rejects this kind for plain DateField inputs.
    """

    # Precision TruncBase.as_sql passes to the dialect's trunc helpers.
    kind = "second"
458
459
# Register the date/time casts (lookup_name "date" and "time") on datetime
# fields. The loop variable is deleted so the module namespace is unchanged.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
del _cast_transform