1from __future__ import annotations
2
3import copy
4import datetime
5import functools
6import inspect
7from collections import defaultdict
8from decimal import Decimal
9from functools import cached_property
10from types import NoneType
11from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable
12from uuid import UUID
13
14import psycopg
15
16from plain.postgres import fields
17from plain.postgres.constants import LOOKUP_SEP
18from plain.postgres.dialect import (
19 CURRENT_ROW,
20 FOLLOWING,
21 PRECEDING,
22 UNBOUNDED_FOLLOWING,
23 UNBOUNDED_PRECEDING,
24 combine_expression,
25 quote_name,
26 subtract_temporals,
27 window_frame_range_start_end,
28 window_frame_rows_start_end,
29)
30from plain.postgres.exceptions import EmptyResultSet, FieldError, FullResultSet
31from plain.postgres.query_utils import Q
32from plain.utils.deconstruct import deconstructible
33from plain.utils.hashable import make_hashable
34
35if TYPE_CHECKING:
36 from collections.abc import Callable, Iterable, Sequence
37
38 from plain.postgres.connection import DatabaseConnection
39 from plain.postgres.fields import Field
40 from plain.postgres.lookups import Lookup, Transform
41 from plain.postgres.query import QuerySet
42 from plain.postgres.sql.compiler import SQLCompilable, SQLCompiler
43 from plain.postgres.sql.query import Query
44
# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = [
    # Core expression classes
    "F",
    "Value",
    "Case",
    "When",
    "Subquery",
    "Exists",
    "OuterRef",
    "Window",
    "ExpressionWrapper",
    "RawSQL",
    "OrderBy",
    # Base classes (for extension)
    "Func",
    "Expression",
    "Combinable",
    # Window frame specs
    "RowRange",
    "ValueRange",
]
66
67
@runtime_checkable
class ResolvableExpression(Protocol):
    """
    Structural type for objects that can be resolved in a query context.

    The protocol is runtime-checkable, so ``isinstance()`` accepts any object
    exposing a ``resolve_expression`` attribute; the signature itself is not
    verified at runtime.
    """

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Return the form of this object to be used within ``query``."""
        ...
80
81
82@runtime_checkable
83class ReplaceableExpression(Protocol):
84 """Protocol for expressions that support expression replacement."""
85
86 def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
87
88
class Combinable:
    """
    Mixin that lets an object be joined with another operand through a
    connector, yielding a combined expression. For example
    F('foo') + F('bar').
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The modulo operator is stored quoted (doubled) because it can end up in
    # strings that also go through parameter substitution.
    MOD = "%%"

    # Bitwise connectors - these are only produced by the .bit*() methods
    # below; the '&', '|' and '^' Python operators are reserved for boolean
    # logic on conditional expressions.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(
        self, other: Any, connector: str, reversed: bool
    ) -> CombinedExpression:
        """
        Build ``self <connector> other`` (or ``other <connector> self`` when
        `reversed` is true), wrapping a non-expression `other` in Value().
        """
        # Both operands must be resolvable expressions.
        operand = other if isinstance(other, ResolvableExpression) else Value(other)
        lhs, rhs = (operand, self) if reversed else (self, operand)
        return CombinedExpression(lhs, connector, rhs)

    ########################
    # ARITHMETIC OPERATORS #
    ########################

    def __neg__(self) -> CombinedExpression:
        # Negation is expressed as a multiplication by -1.
        return self._combine(-1, self.MUL, False)

    def __add__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, False)

    def __sub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, False)

    def __mul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, False)

    def __mod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, False)

    def __pow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, False)

    ##################################
    # REFLECTED ARITHMETIC OPERATORS #
    ##################################

    def __radd__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, True)

    #####################
    # BOOLEAN OPERATORS #
    #####################

    def __and__(self, other: Any) -> Q:
        # Only two conditional operands can be AND-ed into a Q object.
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) & Q(other)

    def __or__(self, other: Any) -> Q:
        # Only two conditional operands can be OR-ed into a Q object.
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) | Q(other)

    def __xor__(self, other: Any) -> Q:
        # Only two conditional operands can be XOR-ed into a Q object.
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) ^ Q(other)

    def __rand__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __ror__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __rxor__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __invert__(self) -> NegatedExpression:
        return NegatedExpression(self)

    #####################
    # BITWISE OPERATORS #
    #####################

    def bitand(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITAND, False)

    def bitor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITRIGHTSHIFT, False)
221
222
223class BaseExpression:
224 """Base class for all query expressions."""
225
226 empty_result_set_value = NotImplemented
227 # aggregate specific fields
228 is_summary = False
229 _output_field_resolved_to_none = False
230 # Can the expression be used in a WHERE clause?
231 filterable = True
232 # Can the expression can be used as a source expression in Window?
233 window_compatible = False
234
235 def __init__(self, output_field: Field | None = None):
236 if output_field is not None:
237 self.output_field = output_field
238
239 def __getstate__(self) -> dict[str, Any]:
240 state = self.__dict__.copy()
241 state.pop("convert_value", None)
242 return state
243
244 def get_db_converters(
245 self, connection: DatabaseConnection
246 ) -> list[Callable[..., Any]]:
247 converters = []
248 if self.convert_value is not self._convert_value_noop:
249 converters.append(self.convert_value)
250 converters.extend(self.output_field.get_db_converters(connection))
251 return converters
252
253 def get_source_expressions(self) -> list[Any]:
254 return []
255
256 def set_source_expressions(self, exprs: Sequence[Any]) -> None:
257 assert not exprs
258
259 def _parse_expressions(self, *expressions: Any) -> list[Any]:
260 return [
261 arg
262 if isinstance(arg, ResolvableExpression)
263 else (F(arg) if isinstance(arg, str) else Value(arg))
264 for arg in expressions
265 ]
266
267 def as_sql(
268 self, compiler: SQLCompiler, connection: DatabaseConnection
269 ) -> tuple[str, Sequence[Any]]:
270 """
271 Return a (sql, params) tuple to be included in the current query.
272
273 Arguments:
274 * compiler: the query compiler responsible for generating the query.
275 Must have a compile method, returning a (sql, [params]) tuple.
276 Calling compiler(value) will return a quoted `value`.
277
278 * connection: the database connection used for the current query.
279
280 Return: (sql, params)
281 Where `sql` is a string containing ordered sql parameters to be
282 replaced with the elements of the list `params`.
283 """
284 raise NotImplementedError("Subclasses must implement as_sql()")
285
286 @cached_property
287 def contains_aggregate(self) -> bool:
288 return any(
289 expr and expr.contains_aggregate for expr in self.get_source_expressions()
290 )
291
292 @cached_property
293 def contains_over_clause(self) -> bool:
294 return any(
295 expr and expr.contains_over_clause for expr in self.get_source_expressions()
296 )
297
298 @cached_property
299 def contains_column_references(self) -> bool:
300 return any(
301 expr and expr.contains_column_references
302 for expr in self.get_source_expressions()
303 )
304
305 def resolve_expression(
306 self,
307 query: Any = None,
308 allow_joins: bool = True,
309 reuse: Any = None,
310 summarize: bool = False,
311 for_save: bool = False,
312 ) -> Self:
313 """
314 Provide the chance to do any preprocessing or validation before being
315 added to the query.
316
317 Arguments:
318 * query: the backend query implementation
319 * allow_joins: boolean allowing or denying use of joins
320 in this query
321 * reuse: a set of reusable joins for multijoins
322 * summarize: a terminal aggregate clause
323 * for_save: whether this expression about to be used in a save or update
324
325 Return: an Expression to be added to the query.
326 """
327 c = self.copy()
328 c.is_summary = summarize
329 c.set_source_expressions(
330 [
331 expr.resolve_expression(query, allow_joins, reuse, summarize)
332 if expr
333 else None
334 for expr in c.get_source_expressions()
335 ]
336 )
337 return c
338
339 @property
340 def conditional(self) -> bool:
341 output_field = getattr(self, "output_field", None)
342 return isinstance(output_field, fields.BooleanField)
343
344 @property
345 def field(self) -> Field:
346 return self.output_field
347
348 @cached_property
349 def output_field(self) -> Field:
350 """Return the output type of this expressions."""
351 output_field = self._resolve_output_field()
352 if output_field is None:
353 self._output_field_resolved_to_none = True
354 raise FieldError("Cannot resolve expression type, unknown output_field")
355 return output_field
356
357 @cached_property
358 def _output_field_or_none(self) -> Field | None:
359 """
360 Return the output field of this expression, or None if
361 _resolve_output_field() didn't return an output type.
362 """
363 try:
364 return self.output_field
365 except FieldError:
366 if not self._output_field_resolved_to_none:
367 raise
368 return None
369
370 def _resolve_output_field(self) -> Field | None:
371 """
372 Attempt to infer the output type of the expression.
373
374 As a guess, if the output fields of all source fields match then simply
375 infer the same type here.
376
377 If a source's output field resolves to None, exclude it from this check.
378 If all sources are None, then an error is raised higher up the stack in
379 the output_field property.
380 """
381 # This guess is mostly a bad idea, but there is quite a lot of code
382 # (especially 3rd party Func subclasses) that depend on it, we'd need a
383 # deprecation path to fix it.
384 sources_iter = (
385 source for source in self.get_source_fields() if source is not None
386 )
387 for output_field in sources_iter:
388 for source in sources_iter:
389 if not isinstance(output_field, source.__class__):
390 raise FieldError(
391 f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
392 "set output_field."
393 )
394 return output_field
395 return None
396
397 @staticmethod
398 def _convert_value_noop(
399 value: Any, expression: Any, connection: DatabaseConnection
400 ) -> Any:
401 return value
402
403 @cached_property
404 def convert_value(self) -> Callable[[Any, Any, Any], Any]:
405 """
406 Expressions provide their own converters because users have the option
407 of manually specifying the output_field which may be a different type
408 from the one the database returns.
409 """
410 field = self.output_field
411 if isinstance(field, fields.FloatField):
412 return (
413 lambda value, expression, connection: None
414 if value is None
415 else float(value)
416 )
417 elif isinstance(field, fields.IntegerField | fields.PrimaryKeyField):
418 return (
419 lambda value, expression, connection: None
420 if value is None
421 else int(value)
422 )
423 elif isinstance(field, fields.DecimalField):
424 return (
425 lambda value, expression, connection: None
426 if value is None
427 else Decimal(value)
428 )
429 return self._convert_value_noop
430
431 def get_lookup(self, lookup: str) -> type[Lookup] | None:
432 return self.output_field.get_lookup(lookup)
433
434 def get_transform(self, name: str) -> type[Transform] | None:
435 return self.output_field.get_transform(name) # ty: ignore[invalid-return-type]
436
437 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
438 clone = self.copy()
439 clone.set_source_expressions(
440 [
441 e.relabeled_clone(change_map) if e is not None else None
442 for e in self.get_source_expressions()
443 ]
444 )
445 return clone
446
447 def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
448 if replacement := replacements.get(self):
449 return replacement
450 clone = self.copy()
451 source_expressions = clone.get_source_expressions()
452 clone.set_source_expressions(
453 [
454 expr.replace_expressions(replacements) if expr else None
455 for expr in source_expressions
456 ]
457 )
458 return clone
459
460 def get_refs(self) -> set[str]:
461 refs = set()
462 for expr in self.get_source_expressions():
463 refs |= expr.get_refs()
464 return refs
465
466 def copy(self) -> Self:
467 return copy.copy(self)
468
469 def prefix_references(self, prefix: str) -> Self:
470 clone = self.copy()
471 clone.set_source_expressions(
472 [
473 F(f"{prefix}{expr.name}")
474 if isinstance(expr, F)
475 else expr.prefix_references(prefix)
476 for expr in self.get_source_expressions()
477 ]
478 )
479 return clone
480
481 def get_group_by_cols(self) -> list[BaseExpression]:
482 if not self.contains_aggregate:
483 return [self]
484 cols: list[BaseExpression] = []
485 for source in self.get_source_expressions():
486 cols.extend(source.get_group_by_cols())
487 return cols
488
489 def get_source_fields(self) -> list[Field | None]:
490 """Return the underlying field types used by this aggregate."""
491 return [e._output_field_or_none for e in self.get_source_expressions()]
492
493 def asc(self, **kwargs: Any) -> OrderBy:
494 return OrderBy(self, **kwargs)
495
496 def desc(self, **kwargs: Any) -> OrderBy:
497 return OrderBy(self, descending=True, **kwargs)
498
499 def reverse_ordering(self) -> Self:
500 return self
501
502 def flatten(self) -> Iterable[Any]:
503 """
504 Recursively yield this expression and all subexpressions, in
505 depth-first order.
506 """
507 yield self
508 for expr in self.get_source_expressions():
509 if expr:
510 if hasattr(expr, "flatten"):
511 yield from expr.flatten()
512 else:
513 yield expr
514
515 def select_format(
516 self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
517 ) -> tuple[str, Sequence[Any]]:
518 """Custom format for select clauses."""
519 if output_field := getattr(self, "output_field", None):
520 if select_format := getattr(output_field, "select_format", None):
521 return select_format(compiler, sql, params)
522 return sql, params
523
524
@deconstructible
class Expression(BaseExpression, Combinable):
    """A combinable expression, comparable and hashable by its constructor arguments."""

    # Set by @deconstructible decorator in __new__
    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]

    @cached_property
    def identity(self) -> tuple[Any, ...]:
        """
        Return a hashable tuple made of the class followed by one
        ``(argument name, value)`` pair per ``__init__`` argument, with
        defaults applied.
        """
        init_signature = inspect.signature(self.__init__)
        ctor_args, ctor_kwargs = self._constructor_args
        bound = init_signature.bind_partial(*ctor_args, **ctor_kwargs)
        bound.apply_defaults()
        parts: list[Any] = [self.__class__]
        for arg_name, raw in bound.arguments.items():
            if not isinstance(raw, fields.Field):
                key = make_hashable(raw)
            elif raw.name and raw.model:
                # A field bound to a model is identified by model label + name.
                key = (raw.model.model_options.label, raw.name)
            else:
                # An unbound field is identified by its class only.
                key = type(raw)
            parts.append((arg_name, key))
        return tuple(parts)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)
558
559
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each entry maps a connector to a list of (lhs type, rhs type, result type)
# tuples.

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # The field types are iterated *inside* the per-connector list. Looping
    # over them at the dict-comprehension level would produce the same key
    # (the connector) once per field type, each value replacing the previous
    # one, so only the last field type would end up registered.
    {
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]

# Registry: connector -> list of (lhs type, rhs type, result type) tuples.
_connector_combinators = defaultdict(list)
663
664
def register_combinable_fields(
    lhs: type[Field] | type[None],
    connector: str,
    rhs: type[Field] | type[None],
    result: type[Field],
) -> None:
    """
    Record that combining two field types with a connector yields a type:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )
    """
    # Entries are appended, so earlier registrations for a connector keep
    # their position ahead of later ones.
    combination = (lhs, rhs, result)
    _connector_combinators[connector].append(combination)
680
681
# Feed every (lhs, rhs, result) tuple of the `_connector_combinations` table
# into the registry. The loop variables are module-level names and stay bound
# after the loop has finished.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
686
687
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(
    connector: str, lhs_type: type[Field], rhs_type: type[Field]
) -> type[Field] | None:
    """
    Return the result type of the first combination registered for
    `connector` whose operand types are superclasses of (or equal to)
    `lhs_type` and `rhs_type`, or None when nothing matches.
    """
    candidates = _connector_combinators.get(connector, ())
    for known_lhs, known_rhs, combined_type in candidates:
        if not issubclass(lhs_type, known_lhs):
            continue
        if issubclass(rhs_type, known_rhs):
            return combined_type
    return None
699
700
class CombinedExpression(Expression):
    """
    Two expressions joined by a connector, e.g. ``lhs + rhs``.

    The output field is inferred from the registered connector combinations
    unless one is passed explicitly.
    """

    def __init__(
        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self) -> str:
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self) -> list[Any]:
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self) -> Field | None:
        """
        Infer the output field from the operand field types.

        Raise FieldError when no combination is registered for the connector
        and these operand types.
        """
        # Deliberately not delegating to the base class' same-type guess; the
        # registered connector combinations are the source of truth here.
        lhs_type = type(self.lhs._output_field_or_none)
        rhs_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(self.connector, lhs_type, rhs_type)
        if combined_type is None:
            # Report the types computed above instead of reading
            # `output_field` again: for an operand without a resolvable
            # output field that access raises its own, less specific
            # FieldError and this message would never be shown.
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_type.__name__}, "
                f"{rhs_type.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Compile both operands and join them with the connector, in parentheses."""
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # Parenthesize the result so the order of precedence is explicit.
        expression_wrapper = "(%s)"
        sql = combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> CombinedExpression | TemporalSubtraction:
        """
        Resolve both operands and return a resolved copy.

        A subtraction whose operands are temporal (date, datetime or time)
        fields of the exact same type is returned as a resolved
        TemporalSubtraction instead.
        """
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # A TemporalSubtraction must not be rewritten into itself again.
        if not isinstance(self, TemporalSubtraction):
            try:
                lhs_field = lhs.output_field
            except (AttributeError, FieldError):
                lhs_field = None
            try:
                rhs_field = rhs.output_field
            except (AttributeError, FieldError):
                rhs_field = None
            is_temporal = isinstance(
                lhs_field, fields.DateField | fields.DateTimeField | fields.TimeField
            )
            same_type = (
                lhs_field is not None
                and rhs_field is not None
                and type(lhs_field) is type(rhs_field)
            )
            if self.connector == self.SUB and is_temporal and same_type:
                # Built from the unresolved operands: the new expression
                # resolves them itself.
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
799
800
class TemporalSubtraction(CombinedExpression):
    """Subtraction of two temporal expressions; the result is a duration."""

    output_field = fields.DurationField()

    def __init__(self, lhs: Any, rhs: Any):
        # The connector is always subtraction.
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Compile both operands and hand them to subtract_temporals()."""
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        sql, params = subtract_temporals(
            self.lhs.output_field, compiled_lhs, compiled_rhs
        )
        return sql, list(params)
814
815
@deconstructible(path="plain.postgres.F")
class F(Combinable):
    """A reference, by name, that is resolved against a query when used."""

    def __init__(self, name: str):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.name)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Return whatever `query.resolve_ref()` resolves this name to."""
        resolved = query.resolve_ref(self.name, allow_joins, reuse, summarize)
        return resolved

    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
        """Return the replacement registered for this reference, else self."""
        return replacements.get(self, self)

    def asc(self, **kwargs: Any) -> OrderBy:
        """Wrap this reference in an OrderBy."""
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        """Wrap this reference in an OrderBy with descending=True."""
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other: object) -> bool:
        # Equal only to references of the exact same class with the same name.
        if isinstance(other, F):
            return self.__class__ == other.__class__ and self.name == other.name
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

    def copy(self) -> Self:
        """Return a shallow copy of this reference."""
        return copy.copy(self)
859
860
class ResolvedOuterRef(F):
    """
    A reference to an outer query that has already been resolved.

    The reference counts as resolved because the inner query has been used
    as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args: Any, **kwargs: Any) -> None:
        # Compiling this object directly is always an error.
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        """
        Resolve the reference like F does, rejecting window expressions and
        flagging the resolved column as possibly multivalued.
        """
        resolved = super().resolve_expression(*args, **kwargs)
        if not resolved.contains_over_clause:
            # FIXME: Rename possibly_multivalued to multivalued and fix detection
            # for non-multivalued JOINs (e.g. foreign key fields). This should take
            # into account only many-to-many and one-to-many relationships.
            resolved.possibly_multivalued = LOOKUP_SEP in self.name
            return resolved
        raise psycopg.NotSupportedError(
            f"Referencing outer query window expression is not supported: "
            f"{self.name}."
        )

    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
        # Relabeling does not apply here; the same object is returned.
        return self

    def get_group_by_cols(self) -> list[Any]:
        # Contributes no columns to GROUP BY.
        return []
896
897
class OuterRef(F):
    """A reference to a field of an outer query, for use inside a subquery."""

    contains_aggregate = False

    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
        """
        Return the wrapped reference when the name is itself an instance of
        this class, otherwise a ResolvedOuterRef for the name.
        """
        name = self.name
        return name if isinstance(name, self.__class__) else ResolvedOuterRef(name)

    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
        # Relabeling does not apply here; the same object is returned.
        return self
908
909
@deconstructible(path="plain.postgres.expressions.Func")
class Func(Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        """
        Arguments:
         * expressions: the function arguments; strings become F() references
           and other non-expression values are wrapped in Value().
         * output_field: optional explicit output field.
         * extra: additional template context, kept in `self.extra`.

        Raise TypeError when `arity` is set and the argument count differs.
        """
        expected = self.arity
        received = len(expressions)
        if expected is not None and received != expected:
            noun = "argument" if expected == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {expected} {noun} "
                f"({received} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        rendered_args = self.arg_joiner.join(
            [str(arg) for arg in self.source_expressions]
        )
        options = {**self.extra, **self._get_repr_options()}
        class_name = self.__class__.__name__
        if not options:
            return f"{class_name}({rendered_args})"
        rendered_options = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{class_name}({rendered_args}, {rendered_options})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Return a copy whose arguments have each been resolved."""
        resolved = self.copy()
        resolved.is_summary = summarize
        # The copy owns its own argument list, so it is updated in place.
        for index, source in enumerate(resolved.source_expressions):
            resolved.source_expressions[index] = source.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return resolved

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the arguments and render them through the template.

        An argument that raises EmptyResultSet is replaced by its
        `empty_result_set_value` when it defines one (otherwise the exception
        propagates); one that raises FullResultSet is replaced by Value(True).
        """
        fragments = []
        params = []
        for source in self.source_expressions:
            try:
                source_sql, source_params = compiler.compile(source)
            except EmptyResultSet:
                fallback = getattr(source, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                source_sql, source_params = compiler.compile(Value(fallback))
            except FullResultSet:
                source_sql, source_params = compiler.compile(Value(True))
            fragments.append(source_sql)
            params.extend(source_params)
        context = {**self.extra, **extra_context}
        # Precedence for each setting: the argument given to this method,
        # then a value passed through __init__()'s **extra (now in
        # `context`), then the class attribute.
        if function is None:
            context.setdefault("function", self.function)
        else:
            context["function"] = function
        template = template or context.get("template", self.template)
        arg_joiner = arg_joiner or context.get("arg_joiner", self.arg_joiner)
        joined = arg_joiner.join(fragments)
        context["expressions"] = joined
        context["field"] = joined
        return template % context, params

    def copy(self) -> Self:
        """Return a copy with its own argument list and `extra` dict."""
        duplicate = super().copy()
        duplicate.source_expressions = self.source_expressions[:]
        duplicate.extra = self.extra.copy()
        return duplicate
1014
1015
1016@deconstructible(path="plain.postgres.expressions.Value")
1017class Value(Expression):
1018 """Represent a wrapped value as a node within an expression."""
1019
1020 # Provide a default value for `for_save` in order to allow unresolved
1021 # instances to be compiled until a decision is taken in #25425.
1022 for_save = False
1023
1024 def __init__(self, value: Any, output_field: Field | None = None):
1025 """
1026 Arguments:
1027 * value: the value this expression represents. The value will be
1028 added into the sql parameter list and properly quoted.
1029
1030 * output_field: an instance of the model field type that this
1031 expression will return, such as IntegerField() or TextField().
1032 """
1033 super().__init__(output_field=output_field)
1034 self.value = value
1035
1036 def __repr__(self) -> str:
1037 return f"{self.__class__.__name__}({self.value!r})"
1038
1039 def as_sql(
1040 self, compiler: SQLCompiler, connection: DatabaseConnection
1041 ) -> tuple[str, list[Any]]:
1042 val = self.value
1043 output_field = self._output_field_or_none
1044 if output_field is not None:
1045 if self.for_save:
1046 val = output_field.get_db_prep_save(val, connection=connection)
1047 else:
1048 val = output_field.get_db_prep_value(val, connection=connection)
1049 if hasattr(output_field, "get_placeholder"):
1050 return output_field.get_placeholder(val, compiler, connection), [val] # ty: ignore[call-non-callable]
1051 if val is None:
1052 return "NULL", []
1053 return "%s", [val]
1054
1055 def resolve_expression(
1056 self,
1057 query: Any = None,
1058 allow_joins: bool = True,
1059 reuse: Any = None,
1060 summarize: bool = False,
1061 for_save: bool = False,
1062 ) -> Value:
1063 c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
1064 c.for_save = for_save
1065 return c
1066
1067 def get_group_by_cols(self) -> list[Any]:
1068 return []
1069
1070 def _resolve_output_field(self) -> Field | None:
1071 if isinstance(self.value, str):
1072 return fields.TextField()
1073 if isinstance(self.value, bool):
1074 return fields.BooleanField()
1075 if isinstance(self.value, int):
1076 return fields.IntegerField()
1077 if isinstance(self.value, float):
1078 return fields.FloatField()
1079 if isinstance(self.value, datetime.datetime):
1080 return fields.DateTimeField()
1081 if isinstance(self.value, datetime.date):
1082 return fields.DateField()
1083 if isinstance(self.value, datetime.time):
1084 return fields.TimeField()
1085 if isinstance(self.value, datetime.timedelta):
1086 return fields.DurationField()
1087 if isinstance(self.value, Decimal):
1088 return fields.DecimalField()
1089 if isinstance(self.value, bytes):
1090 return fields.BinaryField()
1091 if isinstance(self.value, UUID):
1092 return fields.UUIDField()
1093
1094 @property
1095 def empty_result_set_value(self) -> Any:
1096 return self.value
1097
1098
class RawSQL(Expression):
    """Expression made of a caller-supplied SQL fragment and its parameters."""

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        """
        Store ``sql`` and ``params``; a generic ``fields.Field()`` is used as
        the output field when none is given.
        """
        self.sql = sql
        self.params = params
        if output_field is None:
            output_field = fields.Field()
        super().__init__(output_field=output_field)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """Return the stored SQL wrapped in parentheses, with the stored params."""
        wrapped_sql = f"({self.sql})"
        return wrapped_sql, self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        """The expression itself is its only GROUP BY column."""
        return [self]
1118
1119
class Star(Expression):
    """Expression that compiles to a bare ``*``."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Return ``*`` with an empty parameter list."""
        no_params: list[Any] = []
        return "*", no_params
1128
1129
class Col(Expression):
    """Reference to the column of ``target``, optionally qualified by an alias."""

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        """``target`` doubles as the output field when none is given."""
        super().__init__(
            output_field=target if output_field is None else output_field
        )
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        identifiers = [str(self.target)]
        if self.alias:
            identifiers.insert(0, self.alias)
        return "{}({})".format(self.__class__.__name__, ", ".join(identifiers))

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Return the quoted ``alias.column`` (or just ``column`` without an alias)."""
        alias = self.alias
        column = self.target.column
        names = [alias, column] if alias else [column]
        quote = compiler.quote_name_unless_alias
        return ".".join(quote(name) for name in names), []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        """
        Return a new instance using the alias mapped by ``change_map``.

        An instance without an alias is returned unchanged; an alias missing
        from the map is kept as is.
        """
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        """The column itself is its only GROUP BY column."""
        return [self]

    def get_db_converters(
        self, connection: DatabaseConnection
    ) -> list[Callable[..., Any]]:
        """
        Return the output field's converters, followed by the target's
        converters when the target is a different field.
        """
        target_is_output = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if target_is_output:
            return converters
        return converters + self.target.get_db_converters(connection)
1173
1174
class Ref(Expression):
    """
    Reference to a column alias of the query, e.g. Ref('sum_cost') for the
    annotation created by qs.annotate(sum_cost=Sum('cost')).
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        """The referenced source is the only source expression."""
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the source; ``exprs`` must contain exactly one item."""
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # Nothing to do: `source` was resolved already and this node only
        # refers to it by name.
        return self

    def get_refs(self) -> set[str]:
        """Return the set holding the single referenced alias."""
        return {self.refs}

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        """Relabeling does not affect an alias reference; return self."""
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Return the quoted alias name with no parameters."""
        quoted_alias = quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[BaseExpression]:
        """The reference itself is its only GROUP BY column."""
        return [self]
1219
1220
class ExpressionList(Func):
    """
    A group of expressions rendered as one joined list. Useful for handing
    several expressions to another expression as a single argument, for
    example as a partition clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        """Raise ValueError when called without any expression."""
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))
1239
1240
class OrderByList(Func):
    """A list of ordering expressions rendered as an ``ORDER BY`` clause."""

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        """
        Accept ordering expressions; a string starting with ``-`` is turned
        into a descending OrderBy on the field named by the rest of the
        string. Everything else is passed through untouched.
        """
        normalized = []
        for expr in expressions:
            if isinstance(expr, str) and expr[0] == "-":
                expr = OrderBy(F(expr[1:]), descending=True)
            normalized.append(expr)
        super().__init__(*normalized, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        """Return empty SQL when there is nothing to order by."""
        if self.source_expressions:
            sql, params = super().as_sql(*args, **kwargs)
            return sql, list(params)
        return "", []

    def get_group_by_cols(self) -> list[Any]:
        """Collect the GROUP BY columns of every ordering expression, in order."""
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1266
1267
@deconstructible(path="plain.postgres.expressions.ExpressionWrapper")
class ExpressionWrapper(Expression):
    """
    Wrap another expression in order to give it extra context, such as an
    explicit output_field. The wrapper compiles to exactly the SQL of the
    wrapped expression.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the wrapped expression with the first item of ``exprs``."""
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        """The wrapped expression is the only source expression."""
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        """
        Return the GROUP BY columns of the wrapped expression, evaluated on a
        copy that carries this wrapper's output field.
        """
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """Compile the wrapped expression and return its SQL unchanged."""
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression})"
1301
1302
class NegatedExpression(ExpressionWrapper):
    """Logical ``NOT`` applied to a conditional expression."""

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        """Negating again yields a copy of the wrapped expression."""
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """
        Prefix the wrapped expression's SQL with ``NOT``.

        If compiling the wrapped expression raises EmptyResultSet, the
        negation compiles as ``Value(True)`` instead.
        """
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            return compiler.compile(Value(True))
        else:
            return f"NOT {sql}", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        """
        Resolve via the parent class, then raise TypeError unless the wrapped
        expression is marked ``conditional``.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Boolean expressions can be selected directly, so the SQL is
        # returned untouched.
        return sql, params
1341
1342
@deconstructible(path="plain.postgres.expressions.When")
class When(Expression):
    """One ``WHEN <condition> THEN <result>`` arm, meant to be used in Case()."""

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        """
        Build the arm from a condition and the result to yield when it holds.

        The condition may be a Q object, an object marked ``conditional``,
        keyword lookups, or a conditional object combined with keyword
        lookups (the two are merged into one Q).

        Raises:
            TypeError: no usable condition was supplied, or lookups were
                given next to a non-conditional ``condition``.
            ValueError: the condition is an empty Q().
        """
        pending_lookups: dict[str, Any] | None = lookups or None
        if pending_lookups:
            if condition is None:
                condition = Q(**pending_lookups)
                pending_lookups = None
            elif getattr(condition, "conditional", False):
                condition = Q(condition, **pending_lookups)
                pending_lookups = None
        # Lookups still pending here could not be merged into the condition.
        usable = (
            condition is not None
            and getattr(condition, "conditional", False)
            and not pending_lookups
        )
        if not usable:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # ty: ignore[invalid-assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        """Return ``[condition, result]``."""
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace condition and result; ``exprs`` must hold exactly two items."""
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # Only the result expression's field is of interest here; the
        # condition is left out.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        """
        Return a copy with condition and result resolved against ``query``.

        The condition is always resolved with ``for_save=False``; only the
        result receives the caller's ``for_save``.
        """
        resolved = self.copy()
        resolved.is_summary = summarize
        if isinstance(resolved.condition, ResolvableExpression):
            resolved.condition = resolved.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        resolved.result = resolved.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Render the arm; parameters are the condition's followed by the
        result's. ``extra_context`` entries are available to the template.
        """
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        extra_context["condition"] = condition_sql
        extra_context["result"] = result_sql
        chosen_template = template or self.template
        return chosen_template % extra_context, (*condition_params, *result_params)

    def get_group_by_cols(self) -> list[Any]:
        # A When() is not a complete expression, so rather than grouping by
        # itself it reports the columns of its condition and result.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1436
1437
@deconstructible(path="plain.postgres.expressions.Case")
class Case(Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Arguments:
         * cases: the When() arms, in evaluation order.
         * default: the value or expression used for the ELSE branch.
         * output_field: optional field describing the result type.
         * extra: additional template context, kept for as_sql().

        Raises TypeError if any positional argument is not a When.
        """
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(map(str, self.cases)),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        """Return all arms followed by the default."""
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """The last item becomes the default; everything before it the arms."""
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy with every arm and the default resolved."""
        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.cases = [
            case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for case in resolved.cases
        ]
        resolved.default = resolved.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def copy(self) -> Self:
        """Copy via the parent class and give the copy its own list of arms."""
        duplicate = super().copy()
        duplicate.cases = duplicate.cases[:]
        return duplicate

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Render the CASE expression.

        Without arms only the default is compiled. An arm that raises
        EmptyResultSet is dropped. An arm that raises FullResultSet ends the
        scan and its result replaces the default. If no arm is left, the
        (possibly replaced) default is returned on its own. When an output
        field is known, the SQL is wrapped with the connection's unification
        cast.
        """
        if not self.cases:
            sql, params = compiler.compile(self.default)
            return sql, list(params)

        default_sql, default_params = compiler.compile(self.default)
        when_clauses = []
        sql_params = []
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                continue
            except FullResultSet:
                default_sql, default_params = compiler.compile(case.result)
                break
            else:
                when_clauses.append(case_sql)
                sql_params.extend(case_params)

        if not when_clauses:
            return default_sql, list(default_params)

        template_params = {**self.extra, **extra_context}
        joiner = case_joiner or self.case_joiner
        template_params["cases"] = joiner.join(when_clauses)
        template_params["default"] = default_sql
        sql_params.extend(default_params)
        chosen_template = template or template_params.get("template", self.template)
        sql = chosen_template % template_params
        if self._output_field_or_none is not None:
            sql = connection.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        """Without arms, defer to the default's GROUP BY columns."""
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()
1549
1550
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. Any OuterRef() references it contains point at the
    outer query and are resolved once the subquery is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Arguments:
         * query: a QuerySet or an sql.Query. A clone of the underlying
           sql.Query is stored and flagged with ``subquery = True``.
         * output_field: optional field describing the result type.
         * extra: additional template context, kept for as_sql().
        """
        # Imported locally to avoid a circular import.
        from plain.postgres.sql.query import Query

        # A Query is used as is; a QuerySet contributes its sql.Query.
        source_query = query if isinstance(query, Query) else query.sql_query
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        """The inner query is the only source expression."""
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the inner query with the first item of ``exprs``."""
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        """Use the inner query's output field."""
        return self.query.output_field

    def copy(self) -> Self:
        """Copy via the parent class and clone the inner query for the copy."""
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self) -> dict[str, bool]:
        """The inner query's external aliases."""
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        """Delegate to the inner query."""
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Render the inner query through the template.

        The template is, in order of precedence: the ``template`` argument, a
        ``template`` entry in the merged extra context, or the class
        attribute.
        """
        context = {**self.extra, **extra_context}
        inner_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last character of the compiled query
        # (presumably its own enclosing parentheses; verify against
        # Query.as_sql), since the template supplies the wrapping.
        context["subquery"] = inner_sql[1:-1]

        chosen_template = template or context.get("template", self.template)
        return chosen_template % context, sql_params

    def get_group_by_cols(self) -> list[Any]:
        """Ask the inner query for GROUP BY columns, passing self as wrapper."""
        return self.query.get_group_by_cols(wrapper=self)
1620
1621
class Exists(Subquery):
    """A subquery rendered as ``EXISTS(...)`` with a boolean output field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        """Set up the subquery, then swap in the result of its ``exists()`` call."""
        super().__init__(query, **kwargs)
        exists_query = self.query.exists()
        self.query = exists_query

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Boolean expressions can be selected directly, so the SQL is
        # returned untouched.
        return sql, params
1636
1637
@deconstructible(path="plain.postgres.expressions.OrderBy")
class OrderBy(Expression):
    """An ordering term: an expression plus direction and NULLS placement."""

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Arguments:
         * expression: the expression to order by.
         * descending: order DESC instead of ASC.
         * nulls_first / nulls_last: True to request the matching NULLS
           modifier, None to leave it out.

        Raises ValueError when both modifiers are set, when either is
        explicitly False, or when ``expression`` is not a resolvable
        expression.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the ordered expression with the first item of ``exprs``."""
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        """The ordered expression is the only source expression."""
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Render ``<expression> ASC|DESC`` with an optional NULLS modifier.

        ``extra_context`` entries override the default placeholders.
        """
        template = template or self.template
        # Append the NULLS FIRST/LAST modifier when one was requested.
        if self.nulls_last:
            template = f"{template} NULLS LAST"
        elif self.nulls_first:
            template = f"{template} NULLS FIRST"
        expression_sql, params = compiler.compile(self.expression)
        direction = "DESC" if self.descending else "ASC"
        placeholders = {
            "expression": expression_sql,
            "ordering": direction,
            **extra_context,
        }
        # The parameters are needed once per occurrence of the expression
        # placeholder in the template.
        params *= template.count("%(expression)s")
        rendered = template % placeholders
        return rendered.rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        """Collect the GROUP BY columns of the ordered expression."""
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self) -> OrderBy:
        """Flip direction and NULLS placement in place, then return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_first, self.nulls_last = None, True
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self) -> None:  # ty: ignore[invalid-method-override]
        """Switch this ordering to ascending, in place."""
        self.descending = False

    def desc(self) -> None:  # ty: ignore[invalid-method-override]
        """Switch this ordering to descending, in place."""
        self.descending = True
1713
1714
class Window(Expression):
    """
    An expression evaluated over a window: ``<expression> OVER (<window>)``.

    The window is assembled from an optional PARTITION BY list, an optional
    ORDER BY list and an optional frame, in that order.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Arguments:
         * expression: the windowed expression; it must be marked
           ``window_compatible``.
         * partition_by: a single item or a list/tuple of items, wrapped in
           an ExpressionList.
         * order_by: a string, an expression, or a list/tuple of them,
           wrapped in an OrderByList.
         * frame: optional frame expression.
         * output_field: optional field describing the result type.

        Raises ValueError if ``expression`` is not window compatible or
        ``order_by`` has an unsupported type.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        """Use the windowed expression's output field."""
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        """Return ``[source_expression, partition_by, order_by, frame]``."""
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace all four sources; ``exprs`` must hold exactly four items."""
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Render the window expression.

        Parameters are the windowed expression's, followed by those of the
        partition, ordering and frame clauses in that order.
        """
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def __str__(self) -> str:
        # Collect only the clauses that are present and separate them with a
        # space, so that multi-clause windows do not render as fused text
        # (e.g. "PARTITION BY xORDER BY y").
        clauses = []
        if self.partition_by:
            clauses.append("PARTITION BY " + str(self.partition_by))
        if self.order_by:
            clauses.append(str(self.order_by))
        if self.frame:
            clauses.append(str(self.frame))
        return "{} OVER ({})".format(str(self.source_expression), " ".join(clauses))

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        """Collect GROUP BY columns from the partition and ordering clauses."""
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1825
1826
class WindowFrame(Expression):
    """
    Model the frame clause of a window expression. The two frame types are
    subclasses, but all processing and validation (by no means intended to
    be complete) lives here. Supplying an end for a frame is optional; the
    default is UNBOUNDED FOLLOWING, i.e. the last row in the frame.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        """Wrap both bounds in Value() nodes."""
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace both bounds; ``exprs`` must hold exactly two items."""
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        """Return ``[start, end]``."""
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """
        Render the frame clause; the bounds come from the subclass's
        window_frame_start_end() and no query parameters are produced.
        """
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        context = {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }
        return self.template % context, []

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        """A frame clause contributes no GROUP BY columns."""
        return []

    def __str__(self) -> str:
        # Start bound: negative -> "<n> PRECEDING", zero -> CURRENT ROW,
        # anything else (including None) -> UNBOUNDED PRECEDING.
        start = UNBOUNDED_PRECEDING
        if self.start.value is not None:
            if self.start.value < 0:
                start = f"{abs(self.start.value)} {PRECEDING}"
            elif self.start.value == 0:
                start = CURRENT_ROW

        # End bound: positive -> "<n> FOLLOWING", zero -> CURRENT ROW,
        # anything else (including None) -> UNBOUNDED FOLLOWING.
        end = UNBOUNDED_FOLLOWING
        if self.end.value is not None:
            if self.end.value > 0:
                end = f"{self.end.value} {FOLLOWING}"
            elif self.end.value == 0:
                end = CURRENT_ROW

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the window frame start and end for the given connection."""
        raise NotImplementedError("Subclasses must implement window_frame_start_end()")
1896
1897
class RowRange(WindowFrame):
    """Window frame of type ``ROWS``."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Delegate to the dialect's ROWS helper; ``connection`` is not used."""
        bounds = window_frame_rows_start_end(start, end)
        return bounds
1905
1906
class ValueRange(WindowFrame):
    """Window frame of type ``RANGE``."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Delegate to the dialect's RANGE helper; ``connection`` is not used."""
        bounds = window_frame_range_start_end(start, end)
        return bounds