v0.145.1
   1from __future__ import annotations
   2
   3import copy
   4import datetime
   5import functools
   6import inspect
   7from collections import defaultdict
   8from decimal import Decimal
   9from functools import cached_property
  10from types import NoneType
  11from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable
  12from uuid import UUID
  13
  14import psycopg
  15
  16from plain.postgres import fields
  17from plain.postgres.constants import LOOKUP_SEP
  18from plain.postgres.dialect import (
  19    CURRENT_ROW,
  20    FOLLOWING,
  21    PRECEDING,
  22    UNBOUNDED_FOLLOWING,
  23    UNBOUNDED_PRECEDING,
  24    combine_expression,
  25    quote_name,
  26    subtract_temporals,
  27    window_frame_range_start_end,
  28    window_frame_rows_start_end,
  29)
  30from plain.postgres.exceptions import EmptyResultSet, FieldError, FullResultSet
  31from plain.postgres.query_utils import Q
  32from plain.utils.deconstruct import deconstructible
  33from plain.utils.hashable import make_hashable
  34
  35if TYPE_CHECKING:
  36    from collections.abc import Callable, Iterable, Sequence
  37
  38    from plain.postgres.connection import DatabaseConnection
  39    from plain.postgres.fields import Field
  40    from plain.postgres.lookups import Lookup, Transform
  41    from plain.postgres.query import QuerySet
  42    from plain.postgres.sql.compiler import SQLCompilable, SQLCompiler
  43    from plain.postgres.sql.query import Query
  44
  45__all__ = [
  46    # Core expression classes
  47    "F",
  48    "Value",
  49    "Case",
  50    "When",
  51    "Subquery",
  52    "Exists",
  53    "OuterRef",
  54    "Window",
  55    "ExpressionWrapper",
  56    "RawSQL",
  57    "OrderBy",
  58    # Base classes (for extension)
  59    "Func",
  60    "Expression",
  61    "Combinable",
  62    # Window frame specs
  63    "RowRange",
  64    "ValueRange",
  65]
  66
  67
  68@runtime_checkable
  69class ResolvableExpression(Protocol):
  70    """Protocol for expressions that can be resolved in query context."""
  71
  72    def resolve_expression(
  73        self,
  74        query: Any = None,
  75        allow_joins: bool = True,
  76        reuse: Any = None,
  77        summarize: bool = False,
  78        for_save: bool = False,
  79    ) -> Any: ...
  80
  81
  82@runtime_checkable
  83class ReplaceableExpression(Protocol):
  84    """Protocol for expressions that support expression replacement."""
  85
  86    def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
  87
  88
  89class Combinable:
  90    """
  91    Provide the ability to combine one or two objects with
  92    some connector. For example F('foo') + F('bar').
  93    """
  94
  95    # Arithmetic connectors
  96    ADD = "+"
  97    SUB = "-"
  98    MUL = "*"
  99    DIV = "/"
 100    POW = "^"
 101    # The following is a quoted % operator - it is quoted because it can be
 102    # used in strings that also have parameter substitution.
 103    MOD = "%%"
 104
 105    # Bitwise operators - note that these are generated by .bitand()
 106    # and .bitor(), the '&' and '|' are reserved for boolean operator
 107    # usage.
 108    BITAND = "&"
 109    BITOR = "|"
 110    BITLEFTSHIFT = "<<"
 111    BITRIGHTSHIFT = ">>"
 112    BITXOR = "#"
 113
 114    def _combine(
 115        self, other: Any, connector: str, reversed: bool
 116    ) -> CombinedExpression:
 117        if not isinstance(other, ResolvableExpression):
 118            # everything must be resolvable to an expression
 119            other = Value(other)
 120
 121        if reversed:
 122            return CombinedExpression(other, connector, self)
 123        return CombinedExpression(self, connector, other)
 124
 125    #############
 126    # OPERATORS #
 127    #############
 128
 129    def __neg__(self) -> CombinedExpression:
 130        return self._combine(-1, self.MUL, False)
 131
 132    def __add__(self, other: Any) -> CombinedExpression:
 133        return self._combine(other, self.ADD, False)
 134
 135    def __sub__(self, other: Any) -> CombinedExpression:
 136        return self._combine(other, self.SUB, False)
 137
 138    def __mul__(self, other: Any) -> CombinedExpression:
 139        return self._combine(other, self.MUL, False)
 140
 141    def __truediv__(self, other: Any) -> CombinedExpression:
 142        return self._combine(other, self.DIV, False)
 143
 144    def __mod__(self, other: Any) -> CombinedExpression:
 145        return self._combine(other, self.MOD, False)
 146
 147    def __pow__(self, other: Any) -> CombinedExpression:
 148        return self._combine(other, self.POW, False)
 149
 150    def __and__(self, other: Any) -> Q:
 151        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 152            return Q(self) & Q(other)
 153        raise NotImplementedError(
 154            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 155        )
 156
 157    def bitand(self, other: Any) -> CombinedExpression:
 158        return self._combine(other, self.BITAND, False)
 159
 160    def bitleftshift(self, other: Any) -> CombinedExpression:
 161        return self._combine(other, self.BITLEFTSHIFT, False)
 162
 163    def bitrightshift(self, other: Any) -> CombinedExpression:
 164        return self._combine(other, self.BITRIGHTSHIFT, False)
 165
 166    def __xor__(self, other: Any) -> Q:
 167        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 168            return Q(self) ^ Q(other)
 169        raise NotImplementedError(
 170            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 171        )
 172
 173    def bitxor(self, other: Any) -> CombinedExpression:
 174        return self._combine(other, self.BITXOR, False)
 175
 176    def __or__(self, other: Any) -> Q:
 177        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 178            return Q(self) | Q(other)
 179        raise NotImplementedError(
 180            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 181        )
 182
 183    def bitor(self, other: Any) -> CombinedExpression:
 184        return self._combine(other, self.BITOR, False)
 185
 186    def __radd__(self, other: Any) -> CombinedExpression:
 187        return self._combine(other, self.ADD, True)
 188
 189    def __rsub__(self, other: Any) -> CombinedExpression:
 190        return self._combine(other, self.SUB, True)
 191
 192    def __rmul__(self, other: Any) -> CombinedExpression:
 193        return self._combine(other, self.MUL, True)
 194
 195    def __rtruediv__(self, other: Any) -> CombinedExpression:
 196        return self._combine(other, self.DIV, True)
 197
 198    def __rmod__(self, other: Any) -> CombinedExpression:
 199        return self._combine(other, self.MOD, True)
 200
 201    def __rpow__(self, other: Any) -> CombinedExpression:
 202        return self._combine(other, self.POW, True)
 203
 204    def __rand__(self, other: Any) -> None:
 205        raise NotImplementedError(
 206            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 207        )
 208
 209    def __ror__(self, other: Any) -> None:
 210        raise NotImplementedError(
 211            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 212        )
 213
 214    def __rxor__(self, other: Any) -> None:
 215        raise NotImplementedError(
 216            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 217        )
 218
 219    def __invert__(self) -> NegatedExpression:
 220        return NegatedExpression(self)
 221
 222
 223class BaseExpression:
 224    """Base class for all query expressions."""
 225
 226    empty_result_set_value = NotImplemented
 227    # aggregate specific fields
 228    is_summary = False
 229    _output_field_resolved_to_none = False
 230    # Can the expression be used in a WHERE clause?
 231    filterable = True
 232    # Can the expression can be used as a source expression in Window?
 233    window_compatible = False
 234
 235    def __init__(self, output_field: Field | None = None):
 236        if output_field is not None:
 237            self.output_field = output_field
 238
 239    def __getstate__(self) -> dict[str, Any]:
 240        state = self.__dict__.copy()
 241        state.pop("convert_value", None)
 242        return state
 243
 244    def get_db_converters(
 245        self, connection: DatabaseConnection
 246    ) -> list[Callable[..., Any]]:
 247        converters = []
 248        if self.convert_value is not self._convert_value_noop:
 249            converters.append(self.convert_value)
 250        converters.extend(self.output_field.get_db_converters(connection))
 251        return converters
 252
 253    def get_source_expressions(self) -> list[Any]:
 254        return []
 255
 256    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 257        assert not exprs
 258
 259    def _parse_expressions(self, *expressions: Any) -> list[Any]:
 260        return [
 261            arg
 262            if isinstance(arg, ResolvableExpression)
 263            else (F(arg) if isinstance(arg, str) else Value(arg))
 264            for arg in expressions
 265        ]
 266
 267    def as_sql(
 268        self, compiler: SQLCompiler, connection: DatabaseConnection
 269    ) -> tuple[str, Sequence[Any]]:
 270        """
 271        Return a (sql, params) tuple to be included in the current query.
 272
 273        Arguments:
 274         * compiler: the query compiler responsible for generating the query.
 275           Must have a compile method, returning a (sql, [params]) tuple.
 276           Calling compiler(value) will return a quoted `value`.
 277
 278         * connection: the database connection used for the current query.
 279
 280        Return: (sql, params)
 281          Where `sql` is a string containing ordered sql parameters to be
 282          replaced with the elements of the list `params`.
 283        """
 284        raise NotImplementedError("Subclasses must implement as_sql()")
 285
 286    @cached_property
 287    def contains_aggregate(self) -> bool:
 288        return any(
 289            expr and expr.contains_aggregate for expr in self.get_source_expressions()
 290        )
 291
 292    @cached_property
 293    def contains_over_clause(self) -> bool:
 294        return any(
 295            expr and expr.contains_over_clause for expr in self.get_source_expressions()
 296        )
 297
 298    @cached_property
 299    def contains_column_references(self) -> bool:
 300        return any(
 301            expr and expr.contains_column_references
 302            for expr in self.get_source_expressions()
 303        )
 304
 305    def resolve_expression(
 306        self,
 307        query: Any = None,
 308        allow_joins: bool = True,
 309        reuse: Any = None,
 310        summarize: bool = False,
 311        for_save: bool = False,
 312    ) -> Self:
 313        """
 314        Provide the chance to do any preprocessing or validation before being
 315        added to the query.
 316
 317        Arguments:
 318         * query: the backend query implementation
 319         * allow_joins: boolean allowing or denying use of joins
 320           in this query
 321         * reuse: a set of reusable joins for multijoins
 322         * summarize: a terminal aggregate clause
 323         * for_save: whether this expression about to be used in a save or update
 324
 325        Return: an Expression to be added to the query.
 326        """
 327        c = self.copy()
 328        c.is_summary = summarize
 329        c.set_source_expressions(
 330            [
 331                expr.resolve_expression(query, allow_joins, reuse, summarize)
 332                if expr
 333                else None
 334                for expr in c.get_source_expressions()
 335            ]
 336        )
 337        return c
 338
 339    @property
 340    def conditional(self) -> bool:
 341        output_field = getattr(self, "output_field", None)
 342        return isinstance(output_field, fields.BooleanField)
 343
 344    @property
 345    def field(self) -> Field:
 346        return self.output_field
 347
 348    @cached_property
 349    def output_field(self) -> Field:
 350        """Return the output type of this expressions."""
 351        output_field = self._resolve_output_field()
 352        if output_field is None:
 353            self._output_field_resolved_to_none = True
 354            raise FieldError("Cannot resolve expression type, unknown output_field")
 355        return output_field
 356
 357    @cached_property
 358    def _output_field_or_none(self) -> Field | None:
 359        """
 360        Return the output field of this expression, or None if
 361        _resolve_output_field() didn't return an output type.
 362        """
 363        try:
 364            return self.output_field
 365        except FieldError:
 366            if not self._output_field_resolved_to_none:
 367                raise
 368            return None
 369
 370    def _resolve_output_field(self) -> Field | None:
 371        """
 372        Attempt to infer the output type of the expression.
 373
 374        As a guess, if the output fields of all source fields match then simply
 375        infer the same type here.
 376
 377        If a source's output field resolves to None, exclude it from this check.
 378        If all sources are None, then an error is raised higher up the stack in
 379        the output_field property.
 380        """
 381        # This guess is mostly a bad idea, but there is quite a lot of code
 382        # (especially 3rd party Func subclasses) that depend on it, we'd need a
 383        # deprecation path to fix it.
 384        sources_iter = (
 385            source for source in self.get_source_fields() if source is not None
 386        )
 387        for output_field in sources_iter:
 388            for source in sources_iter:
 389                if not isinstance(output_field, source.__class__):
 390                    raise FieldError(
 391                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
 392                        "set output_field."
 393                    )
 394            return output_field
 395        return None
 396
 397    @staticmethod
 398    def _convert_value_noop(
 399        value: Any, expression: Any, connection: DatabaseConnection
 400    ) -> Any:
 401        return value
 402
 403    @cached_property
 404    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
 405        """
 406        Expressions provide their own converters because users have the option
 407        of manually specifying the output_field which may be a different type
 408        from the one the database returns.
 409        """
 410        field = self.output_field
 411        if isinstance(field, fields.FloatField):
 412            return (
 413                lambda value, expression, connection: None
 414                if value is None
 415                else float(value)
 416            )
 417        elif isinstance(field, fields.IntegerField | fields.PrimaryKeyField):
 418            return (
 419                lambda value, expression, connection: None
 420                if value is None
 421                else int(value)
 422            )
 423        elif isinstance(field, fields.DecimalField):
 424            return (
 425                lambda value, expression, connection: None
 426                if value is None
 427                else Decimal(value)
 428            )
 429        return self._convert_value_noop
 430
 431    def get_lookup(self, lookup: str) -> type[Lookup] | None:
 432        return self.output_field.get_lookup(lookup)
 433
 434    def get_transform(self, name: str) -> type[Transform] | None:
 435        return self.output_field.get_transform(name)  # ty: ignore[invalid-return-type]
 436
 437    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
 438        clone = self.copy()
 439        clone.set_source_expressions(
 440            [
 441                e.relabeled_clone(change_map) if e is not None else None
 442                for e in self.get_source_expressions()
 443            ]
 444        )
 445        return clone
 446
 447    def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
 448        if replacement := replacements.get(self):
 449            return replacement
 450        clone = self.copy()
 451        source_expressions = clone.get_source_expressions()
 452        clone.set_source_expressions(
 453            [
 454                expr.replace_expressions(replacements) if expr else None
 455                for expr in source_expressions
 456            ]
 457        )
 458        return clone
 459
 460    def get_refs(self) -> set[str]:
 461        refs = set()
 462        for expr in self.get_source_expressions():
 463            refs |= expr.get_refs()
 464        return refs
 465
 466    def copy(self) -> Self:
 467        return copy.copy(self)
 468
 469    def prefix_references(self, prefix: str) -> Self:
 470        clone = self.copy()
 471        clone.set_source_expressions(
 472            [
 473                F(f"{prefix}{expr.name}")
 474                if isinstance(expr, F)
 475                else expr.prefix_references(prefix)
 476                for expr in self.get_source_expressions()
 477            ]
 478        )
 479        return clone
 480
 481    def get_group_by_cols(self) -> list[BaseExpression]:
 482        if not self.contains_aggregate:
 483            return [self]
 484        cols: list[BaseExpression] = []
 485        for source in self.get_source_expressions():
 486            cols.extend(source.get_group_by_cols())
 487        return cols
 488
 489    def get_source_fields(self) -> list[Field | None]:
 490        """Return the underlying field types used by this aggregate."""
 491        return [e._output_field_or_none for e in self.get_source_expressions()]
 492
 493    def asc(self, **kwargs: Any) -> OrderBy:
 494        return OrderBy(self, **kwargs)
 495
 496    def desc(self, **kwargs: Any) -> OrderBy:
 497        return OrderBy(self, descending=True, **kwargs)
 498
 499    def reverse_ordering(self) -> Self:
 500        return self
 501
 502    def flatten(self) -> Iterable[Any]:
 503        """
 504        Recursively yield this expression and all subexpressions, in
 505        depth-first order.
 506        """
 507        yield self
 508        for expr in self.get_source_expressions():
 509            if expr:
 510                if hasattr(expr, "flatten"):
 511                    yield from expr.flatten()
 512                else:
 513                    yield expr
 514
 515    def select_format(
 516        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
 517    ) -> tuple[str, Sequence[Any]]:
 518        """Custom format for select clauses."""
 519        if output_field := getattr(self, "output_field", None):
 520            if select_format := getattr(output_field, "select_format", None):
 521                return select_format(compiler, sql, params)
 522        return sql, params
 523
 524
 525@deconstructible
 526class Expression(BaseExpression, Combinable):
 527    """An expression that can be combined with other expressions."""
 528
 529    # Set by @deconstructible decorator in __new__
 530    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]
 531
 532    @cached_property
 533    def identity(self) -> tuple[Any, ...]:
 534        constructor_signature = inspect.signature(self.__init__)
 535        args, kwargs = self._constructor_args
 536        signature = constructor_signature.bind_partial(*args, **kwargs)
 537        signature.apply_defaults()
 538        arguments = signature.arguments.items()
 539        identity: list[Any] = [self.__class__]
 540        for arg, value in arguments:
 541            if isinstance(value, fields.Field):
 542                if value.name and value.model:
 543                    value = (value.model.model_options.label, value.name)
 544                else:
 545                    value = type(value)
 546            else:
 547                value = make_hashable(value)
 548            identity.append((arg, value))
 549        return tuple(identity)
 550
 551    def __eq__(self, other: object) -> bool:
 552        if not isinstance(other, Expression):
 553            return NotImplemented
 554        return other.identity == self.identity
 555
 556    def __hash__(self) -> int:
 557        return hash(self.identity)
 558
 559
 560# Type inference for CombinedExpression.output_field.
 561# Missing items will result in FieldError, by design.
 562#
 563# The current approach for NULL is based on lowest common denominator behavior
 564# i.e. if one of the supported databases is raising an error (rather than
 565# return NULL) for `val <op> NULL`, then Plain raises FieldError.
 566
 567_connector_combinations = [
 568    # Numeric operations - operands of same type.
 569    {
 570        connector: [
 571            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 572            (fields.FloatField, fields.FloatField, fields.FloatField),
 573            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
 574        ]
 575        for connector in (
 576            Combinable.ADD,
 577            Combinable.SUB,
 578            Combinable.MUL,
 579            Combinable.DIV,
 580            Combinable.MOD,
 581            Combinable.POW,
 582        )
 583    },
 584    # Numeric operations - operands of different type.
 585    {
 586        connector: [
 587            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
 588            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
 589            (fields.IntegerField, fields.FloatField, fields.FloatField),
 590            (fields.FloatField, fields.IntegerField, fields.FloatField),
 591        ]
 592        for connector in (
 593            Combinable.ADD,
 594            Combinable.SUB,
 595            Combinable.MUL,
 596            Combinable.DIV,
 597            Combinable.MOD,
 598        )
 599    },
 600    # Bitwise operators.
 601    {
 602        connector: [
 603            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 604        ]
 605        for connector in (
 606            Combinable.BITAND,
 607            Combinable.BITOR,
 608            Combinable.BITLEFTSHIFT,
 609            Combinable.BITRIGHTSHIFT,
 610            Combinable.BITXOR,
 611        )
 612    },
 613    # Numeric with NULL.
 614    {
 615        connector: [
 616            (field_type, NoneType, field_type),
 617            (NoneType, field_type, field_type),
 618        ]
 619        for connector in (
 620            Combinable.ADD,
 621            Combinable.SUB,
 622            Combinable.MUL,
 623            Combinable.DIV,
 624            Combinable.MOD,
 625            Combinable.POW,
 626        )
 627        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
 628    },
 629    # Date/DateTimeField/DurationField/TimeField.
 630    {
 631        Combinable.ADD: [
 632            # Date/DateTimeField.
 633            (fields.DateField, fields.DurationField, fields.DateTimeField),
 634            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 635            (fields.DurationField, fields.DateField, fields.DateTimeField),
 636            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
 637            # DurationField.
 638            (fields.DurationField, fields.DurationField, fields.DurationField),
 639            # TimeField.
 640            (fields.TimeField, fields.DurationField, fields.TimeField),
 641            (fields.DurationField, fields.TimeField, fields.TimeField),
 642        ],
 643    },
 644    {
 645        Combinable.SUB: [
 646            # Date/DateTimeField.
 647            (fields.DateField, fields.DurationField, fields.DateTimeField),
 648            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 649            (fields.DateField, fields.DateField, fields.DurationField),
 650            (fields.DateField, fields.DateTimeField, fields.DurationField),
 651            (fields.DateTimeField, fields.DateField, fields.DurationField),
 652            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
 653            # DurationField.
 654            (fields.DurationField, fields.DurationField, fields.DurationField),
 655            # TimeField.
 656            (fields.TimeField, fields.DurationField, fields.TimeField),
 657            (fields.TimeField, fields.TimeField, fields.DurationField),
 658        ],
 659    },
 660]
 661
 662_connector_combinators = defaultdict(list)
 663
 664
 665def register_combinable_fields(
 666    lhs: type[Field] | type[None],
 667    connector: str,
 668    rhs: type[Field] | type[None],
 669    result: type[Field],
 670) -> None:
 671    """
 672    Register combinable types:
 673        lhs <connector> rhs -> result
 674    e.g.
 675        register_combinable_fields(
 676            IntegerField, Combinable.ADD, FloatField, FloatField
 677        )
 678    """
 679    _connector_combinators[connector].append((lhs, rhs, result))
 680
 681
 682for d in _connector_combinations:
 683    for connector, field_types in d.items():
 684        for lhs, rhs, result in field_types:
 685            register_combinable_fields(lhs, connector, rhs, result)
 686
 687
 688@functools.lru_cache(maxsize=128)
 689def _resolve_combined_type(
 690    connector: str, lhs_type: type[Field], rhs_type: type[Field]
 691) -> type[Field] | None:
 692    combinators = _connector_combinators.get(connector, ())
 693    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
 694        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
 695            rhs_type, combinator_rhs_type
 696        ):
 697            return combined_type
 698    return None
 699
 700
 701class CombinedExpression(Expression):
 702    def __init__(
 703        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
 704    ):
 705        super().__init__(output_field=output_field)
 706        self.connector = connector
 707        self.lhs = lhs
 708        self.rhs = rhs
 709
 710    def __repr__(self) -> str:
 711        return f"<{self.__class__.__name__}: {self}>"
 712
 713    def __str__(self) -> str:
 714        return f"{self.lhs} {self.connector} {self.rhs}"
 715
 716    def get_source_expressions(self) -> list[Any]:
 717        return [self.lhs, self.rhs]
 718
 719    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 720        self.lhs, self.rhs = exprs
 721
 722    def _resolve_output_field(self) -> Field | None:
 723        # We avoid using super() here for reasons given in
 724        # Expression._resolve_output_field()
 725        combined_type = _resolve_combined_type(
 726            self.connector,
 727            type(self.lhs._output_field_or_none),
 728            type(self.rhs._output_field_or_none),
 729        )
 730        if combined_type is None:
 731            raise FieldError(
 732                f"Cannot infer type of {self.connector!r} expression involving these "
 733                f"types: {self.lhs.output_field.__class__.__name__}, "
 734                f"{self.rhs.output_field.__class__.__name__}. You must set "
 735                f"output_field."
 736            )
 737        return combined_type()
 738
 739    def as_sql(
 740        self, compiler: SQLCompiler, connection: DatabaseConnection
 741    ) -> tuple[str, list[Any]]:
 742        expressions = []
 743        expression_params = []
 744        sql, params = compiler.compile(self.lhs)
 745        expressions.append(sql)
 746        expression_params.extend(params)
 747        sql, params = compiler.compile(self.rhs)
 748        expressions.append(sql)
 749        expression_params.extend(params)
 750        # order of precedence
 751        expression_wrapper = "(%s)"
 752        sql = combine_expression(self.connector, expressions)
 753        return expression_wrapper % sql, expression_params
 754
 755    def resolve_expression(
 756        self,
 757        query: Any = None,
 758        allow_joins: bool = True,
 759        reuse: Any = None,
 760        summarize: bool = False,
 761        for_save: bool = False,
 762    ) -> CombinedExpression | TemporalSubtraction:
 763        lhs = self.lhs.resolve_expression(
 764            query, allow_joins, reuse, summarize, for_save
 765        )
 766        rhs = self.rhs.resolve_expression(
 767            query, allow_joins, reuse, summarize, for_save
 768        )
 769        if not isinstance(self, TemporalSubtraction):
 770            try:
 771                lhs_field = lhs.output_field
 772            except (AttributeError, FieldError):
 773                lhs_field = None
 774            try:
 775                rhs_field = rhs.output_field
 776            except (AttributeError, FieldError):
 777                rhs_field = None
 778            is_temporal = isinstance(
 779                lhs_field, fields.DateField | fields.DateTimeField | fields.TimeField
 780            )
 781            same_type = (
 782                lhs_field is not None
 783                and rhs_field is not None
 784                and type(lhs_field) is type(rhs_field)
 785            )
 786            if self.connector == self.SUB and is_temporal and same_type:
 787                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
 788                    query,
 789                    allow_joins,
 790                    reuse,
 791                    summarize,
 792                    for_save,
 793                )
 794        c = self.copy()
 795        c.is_summary = summarize
 796        c.lhs = lhs
 797        c.rhs = rhs
 798        return c
 799
 800
 801class TemporalSubtraction(CombinedExpression):
 802    output_field = fields.DurationField()
 803
 804    def __init__(self, lhs: Any, rhs: Any):
 805        super().__init__(lhs, self.SUB, rhs)
 806
 807    def as_sql(
 808        self, compiler: SQLCompiler, connection: DatabaseConnection
 809    ) -> tuple[str, list[Any]]:
 810        lhs = compiler.compile(self.lhs)
 811        rhs = compiler.compile(self.rhs)
 812        sql, params = subtract_temporals(self.lhs.output_field, lhs, rhs)
 813        return sql, list(params)
 814
 815
 816@deconstructible(path="plain.postgres.F")
 817class F(Combinable):
 818    """An object capable of resolving references to existing query objects."""
 819
 820    def __init__(self, name: str):
 821        """
 822        Arguments:
 823         * name: the name of the field this expression references
 824        """
 825        self.name = name
 826
 827    def __repr__(self) -> str:
 828        return f"{self.__class__.__name__}({self.name})"
 829
 830    def resolve_expression(
 831        self,
 832        query: Any = None,
 833        allow_joins: bool = True,
 834        reuse: Any = None,
 835        summarize: bool = False,
 836        for_save: bool = False,
 837    ) -> Any:
 838        return query.resolve_ref(self.name, allow_joins, reuse, summarize)
 839
 840    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
 841        return replacements.get(self, self)
 842
 843    def asc(self, **kwargs: Any) -> OrderBy:
 844        return OrderBy(self, **kwargs)
 845
 846    def desc(self, **kwargs: Any) -> OrderBy:
 847        return OrderBy(self, descending=True, **kwargs)
 848
 849    def __eq__(self, other: object) -> bool:
 850        if not isinstance(other, F):
 851            return NotImplemented
 852        return self.__class__ == other.__class__ and self.name == other.name
 853
 854    def __hash__(self) -> int:
 855        return hash(self.name)
 856
 857    def copy(self) -> Self:
 858        return copy.copy(self)
 859
 860
 861class ResolvedOuterRef(F):
 862    """
 863    An object that contains a reference to an outer query.
 864
 865    In this case, the reference to the outer query has been resolved because
 866    the inner query has been used as a subquery.
 867    """
 868
 869    contains_aggregate = False
 870    contains_over_clause = False
 871
 872    def as_sql(self, *args: Any, **kwargs: Any) -> None:
 873        raise ValueError(
 874            "This queryset contains a reference to an outer query and may "
 875            "only be used in a subquery."
 876        )
 877
 878    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
 879        col = super().resolve_expression(*args, **kwargs)
 880        if col.contains_over_clause:
 881            raise psycopg.NotSupportedError(
 882                f"Referencing outer query window expression is not supported: "
 883                f"{self.name}."
 884            )
 885        # FIXME: Rename possibly_multivalued to multivalued and fix detection
 886        # for non-multivalued JOINs (e.g. foreign key fields). This should take
 887        # into account only many-to-many and one-to-many relationships.
 888        col.possibly_multivalued = LOOKUP_SEP in self.name
 889        return col
 890
 891    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
 892        return self
 893
 894    def get_group_by_cols(self) -> list[Any]:
 895        return []
 896
 897
 898class OuterRef(F):
 899    contains_aggregate = False
 900
 901    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
 902        if isinstance(self.name, self.__class__):
 903            return self.name
 904        return ResolvedOuterRef(self.name)
 905
 906    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
 907        return self
 908
 909
 910@deconstructible(path="plain.postgres.expressions.Func")
 911class Func(Expression):
 912    """An SQL function call."""
 913
 914    function = None
 915    template = "%(function)s(%(expressions)s)"
 916    arg_joiner = ", "
 917    arity = None  # The number of arguments the function accepts.
 918
 919    def __init__(
 920        self, *expressions: Any, output_field: Field | None = None, **extra: Any
 921    ):
 922        if self.arity is not None and len(expressions) != self.arity:
 923            raise TypeError(
 924                "'{}' takes exactly {} {} ({} given)".format(
 925                    self.__class__.__name__,
 926                    self.arity,
 927                    "argument" if self.arity == 1 else "arguments",
 928                    len(expressions),
 929                )
 930            )
 931        super().__init__(output_field=output_field)
 932        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
 933        self.extra = extra
 934
 935    def __repr__(self) -> str:
 936        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
 937        extra = {**self.extra, **self._get_repr_options()}
 938        if extra:
 939            extra = ", ".join(
 940                str(key) + "=" + str(val) for key, val in sorted(extra.items())
 941            )
 942            return f"{self.__class__.__name__}({args}, {extra})"
 943        return f"{self.__class__.__name__}({args})"
 944
 945    def _get_repr_options(self) -> dict[str, Any]:
 946        """Return a dict of extra __init__() options to include in the repr."""
 947        return {}
 948
 949    def get_source_expressions(self) -> list[Any]:
 950        return self.source_expressions
 951
 952    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 953        self.source_expressions = list(exprs)
 954
 955    def resolve_expression(
 956        self,
 957        query: Any = None,
 958        allow_joins: bool = True,
 959        reuse: Any = None,
 960        summarize: bool = False,
 961        for_save: bool = False,
 962    ) -> Self:
 963        c = self.copy()
 964        c.is_summary = summarize
 965        for pos, arg in enumerate(c.source_expressions):
 966            c.source_expressions[pos] = arg.resolve_expression(
 967                query, allow_joins, reuse, summarize, for_save
 968            )
 969        return c
 970
 971    def as_sql(
 972        self,
 973        compiler: SQLCompiler,
 974        connection: DatabaseConnection,
 975        function: str | None = None,
 976        template: str | None = None,
 977        arg_joiner: str | None = None,
 978        **extra_context: Any,
 979    ) -> tuple[str, list[Any]]:
 980        sql_parts = []
 981        params = []
 982        for arg in self.source_expressions:
 983            try:
 984                arg_sql, arg_params = compiler.compile(arg)
 985            except EmptyResultSet:
 986                empty_result_set_value = getattr(
 987                    arg, "empty_result_set_value", NotImplemented
 988                )
 989                if empty_result_set_value is NotImplemented:
 990                    raise
 991                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
 992            except FullResultSet:
 993                arg_sql, arg_params = compiler.compile(Value(True))
 994            sql_parts.append(arg_sql)
 995            params.extend(arg_params)
 996        data = {**self.extra, **extra_context}
 997        # Use the first supplied value in this order: the parameter to this
 998        # method, a value supplied in __init__()'s **extra (the value in
 999        # `data`), or the value defined on the class.
1000        if function is not None:
1001            data["function"] = function
1002        else:
1003            data.setdefault("function", self.function)
1004        template = template or data.get("template", self.template)
1005        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
1006        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
1007        return template % data, params
1008
1009    def copy(self) -> Self:
1010        clone = super().copy()
1011        clone.source_expressions = self.source_expressions[:]
1012        clone.extra = self.extra.copy()
1013        return clone
1014
1015
@deconstructible(path="plain.postgres.expressions.Value")
class Value(Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or TextField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        val = self.value
        output_field = self._output_field_or_none
        if output_field is not None:
            # Saving and querying may prepare values differently.
            prepare = (
                output_field.get_db_prep_save
                if self.for_save
                else output_field.get_db_prep_value
            )
            val = prepare(val, connection=connection)
            # Fields that define a custom placeholder control the SQL snippet.
            if hasattr(output_field, "get_placeholder"):
                return output_field.get_placeholder(val, compiler, connection), [val]  # ty: ignore[call-non-callable]
        return ("NULL", []) if val is None else ("%s", [val])

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        return []

    def _resolve_output_field(self) -> Field | None:
        # Ordered checks: bool precedes int (bool subclasses int) and
        # datetime precedes date (datetime subclasses date).
        for value_type, field_class in (
            (str, fields.TextField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        ):
            if isinstance(self.value, value_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        return self.value
1097
1098
class RawSQL(Expression):
    """A literal SQL fragment with its own parameters, emitted in parentheses."""

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        # Without an explicit output field, fall back to a generic Field().
        field = fields.Field() if output_field is None else output_field
        self.sql = sql
        self.params = params
        super().__init__(output_field=field)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        wrapped = f"({self.sql})"
        return wrapped, self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        # The whole fragment is grouped on as a unit.
        return [self]
1118
1119
class Star(Expression):
    """An expression that compiles to a bare ``*`` with no parameters."""

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        sql, params = "*", []
        return sql, params

    def __repr__(self) -> str:
        return "'*'"
1128
1129
class Col(Expression):
    """A table column reference, optionally qualified by a table alias."""

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        # The target field doubles as the output field unless one is given.
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "{}({})".format(self.__class__.__name__, ", ".join(parts))

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        alias = self.alias
        column = self.target.column
        names = (alias, column) if alias else (column,)
        quoted = [compiler.quote_name_unless_alias(name) for name in names]
        return ".".join(quoted), []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        # Unqualified columns have nothing to relabel.
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]

    def get_db_converters(
        self, connection: DatabaseConnection
    ) -> list[Callable[..., Any]]:
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        # Distinct output field: apply its converters, then the target's.
        return converters + self.target.get_db_converters(connection)
1173
1174
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Exactly one expression is expected; unpacking enforces that.
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # `source` is already resolved; this object only names it.
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        # Emitted as the quoted alias name, not the source expression.
        return quote_name(self.refs), []

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1219
1220
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if not expressions:
            cls_name = self.__class__.__name__
            raise ValueError(f"{cls_name} requires at least one expression.")
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))
1239
1240
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from a list of ordering expressions.

    String arguments starting with ``"-"`` become descending
    ``OrderBy(F(name))`` expressions; every other argument is passed to
    ``Func`` unchanged. An empty list compiles to an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # startswith() instead of expr[0] so an empty string does not raise
        # an opaque IndexError here; it is passed through like any other
        # argument and reported by normal resolution.
        expressions_tuple = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions_tuple, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        # No expressions: emit nothing rather than a bare "ORDER BY".
        if not self.source_expressions:
            return "", []
        sql, params = super().as_sql(*args, **kwargs)
        return sql, list(params)

    def get_group_by_cols(self) -> list[Any]:
        # Flatten the GROUP BY columns contributed by each ordering term.
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1266
1267
@deconstructible(path="plain.postgres.expressions.ExpressionWrapper")
class ExpressionWrapper(Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Ask a copy of the inner expression, carrying this wrapper's output
        # field, so the original inner expression is left untouched.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        # The wrapper adds no SQL of its own.
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression})"
1301
1302
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        # Negating a negation yields (a copy of) the wrapped expression.
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """
        Compile as ``NOT <expr>``.

        Short-circuits the inner expression's sentinels:
        EmptyResultSet (matches nothing) compiles to TRUE, and
        FullResultSet (matches everything) compiles to FALSE.
        """
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            return compiler.compile(Value(True))
        except FullResultSet:
            # Must not propagate: callers treat FullResultSet as "always
            # true", which is the opposite of NOT(always true).
            return compiler.compile(Value(False))
        return f"NOT {sql}", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Only expressions flagged as conditional can be negated.
        if not getattr(resolved.expression, "conditional", False):
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Boolean expressions work directly in SELECT
        return sql, params
1341
1342
@deconstructible(path="plain.postgres.expressions.When")
class When(Expression):
    """
    A single ``WHEN <condition> THEN <result>`` branch of a Case() expression.

    The condition may be a Q object, a conditional expression, or keyword
    lookups (which are turned into a Q object).
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        """
        Arguments:
         * condition: a Q object or conditional expression (optional when
           lookups are given).
         * then: the result of this branch.
         * lookups: keyword lookups folded into a Q() condition.

        Raises TypeError for a missing/non-conditional condition, and
        ValueError for an empty Q().
        """
        lookups_dict: dict[str, Any] | None = lookups or None
        if lookups_dict:
            # Fold the lookups into a Q(); a conditional `condition` is
            # included in that same Q(). Consumed lookups are set to None.
            if condition is None:
                condition, lookups_dict = Q(**lookups_dict), None
            elif getattr(condition, "conditional", False):
                condition, lookups_dict = Q(condition, **lookups_dict), None
        # Leftover lookups here mean they were paired with a
        # non-conditional `condition`.
        if (
            condition is None
            or not getattr(condition, "conditional", False)
            or lookups_dict
        ):
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # ty: ignore[invalid-assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        c = self.copy()
        c.is_summary = summarize
        # The condition is resolved with for_save=False; only the result
        # receives the caller's for_save flag.
        if isinstance(c.condition, ResolvableExpression):
            c.condition = c.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        c.result = c.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        template_params = extra_context
        # Never populated in this method; it only prefixes the returned params.
        sql_params = []
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        template_params["result"] = result_sql
        template = template or self.template
        # Params follow placeholder order: condition first, then result.
        return template % template_params, (
            *sql_params,
            *condition_params,
            *result_params,
        )

    def get_group_by_cols(self) -> list[Any]:
        # This is not a complete expression and cannot be used in GROUP BY
        # itself; instead its condition and result contribute their columns.
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols
1436
1437
@deconstructible(path="plain.postgres.expressions.Case")
class Case(Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be When() objects; ``default`` supplies the ELSE
    branch, and extra keyword arguments are merged into the SQL template
    context.
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        # The default is always the last source expression.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self) -> Self:
        c = super().copy()
        # Own list, so resolve_expression can replace entries in the clone
        # without touching the original's cases.
        # NOTE(review): `extra` is not copied here (unlike Func.copy), so the
        # clone shares the original's dict.
        c.cases = c.cases[:]
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # With no WHEN branches the expression is just its default.
        if not self.cases:
            sql, params = compiler.compile(self.default)
            return sql, list(params)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # A branch that can never match is dropped.
                continue
            except FullResultSet:
                # A branch that always matches makes every later branch
                # unreachable: its result becomes the ELSE value.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        # Every branch was dropped (or the first always matched): no CASE
        # is needed.
        if not case_parts:
            return default_sql, list(default_params)
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # The ELSE params come after all WHEN params, matching SQL order.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        # With a known output field, wrap the SQL using the connection's
        # unification cast template (presumably to coerce branch types).
        if self._output_field_or_none is not None:
            sql = connection.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        # Without WHEN branches only the default contributes columns.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1549
1550
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Import here to avoid circular import
        from plain.postgres.sql.query import Query

        # Accept either an sql.Query directly or a QuerySet wrapping one.
        sql_query = query if isinstance(query, Query) else query.sql_query
        # Work on a private clone, flagged as a subquery.
        self.query = sql_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        return self.query.output_field

    def copy(self) -> Self:
        clone = super().copy()
        # Clones must not share the underlying query object.
        clone.query = clone.query.clone()
        return clone

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        context = {**self.extra, **extra_context}
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the compiled query
        # (presumably its enclosing parentheses, since the query is flagged
        # as a subquery) so the template controls the wrapping.
        context["subquery"] = subquery_sql[1:-1]
        chosen_template = template or context.get("template", self.template)
        return chosen_template % context, sql_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1620
1621
class Exists(Subquery):
    """A boolean ``EXISTS(...)`` test over a subquery."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    # Value substituted when the subquery can match nothing.
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Swap the cloned query for the form returned by its exists().
        exists_query = self.query.exists()
        self.query = exists_query

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Boolean expressions work directly in SELECT
        return sql, params
1636
1637
@deconstructible(path="plain.postgres.expressions.OrderBy")
class OrderBy(Expression):
    """
    One ordering term: an expression plus a direction (ASC/DESC) and an
    optional NULLS FIRST / NULLS LAST modifier.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Raises ValueError if both null modifiers are set, if either is an
        explicit False (only True or None are accepted), or if `expression`
        is not a resolvable expression.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        template = template or self.template
        # Handle NULLS FIRST/LAST modifiers
        if self.nulls_last:
            template = f"{template} NULLS LAST"
        elif self.nulls_first:
            template = f"{template} NULLS FIRST"
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Repeat the params once per occurrence of the expression in the
        # template. Build a new sequence rather than multiplying in place:
        # compile() may return a list owned by the expression (RawSQL.as_sql
        # returns its own `params`), and `*=` would mutate that object.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self) -> OrderBy:
        """Flip the direction and swap NULLS FIRST/LAST in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:  # ty: ignore[invalid-method-override]
        self.descending = False

    def desc(self) -> None:  # ty: ignore[invalid-method-override]
        self.descending = True
1713
1714
1715class Window(Expression):
1716    template = "%(expression)s OVER (%(window)s)"
1717    # Although the main expression may either be an aggregate or an
1718    # expression with an aggregate function, the GROUP BY that will
1719    # be introduced in the query as a result is not desired.
1720    contains_aggregate = False
1721    contains_over_clause = True
1722    partition_by: ExpressionList | None
1723    order_by: OrderByList | None
1724
1725    def __init__(
1726        self,
1727        expression: Any,
1728        partition_by: Any = None,
1729        order_by: Any = None,
1730        frame: Any = None,
1731        output_field: Field | None = None,
1732    ):
1733        self.partition_by = partition_by
1734        self.order_by = order_by
1735        self.frame = frame
1736
1737        if not getattr(expression, "window_compatible", False):
1738            raise ValueError(
1739                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
1740            )
1741
1742        if self.partition_by is not None:
1743            partition_by_values = (
1744                self.partition_by
1745                if isinstance(self.partition_by, tuple | list)
1746                else (self.partition_by,)
1747            )
1748            self.partition_by = ExpressionList(*partition_by_values)
1749
1750        if self.order_by is not None:
1751            if isinstance(self.order_by, list | tuple):
1752                self.order_by = OrderByList(*self.order_by)
1753            elif isinstance(self.order_by, BaseExpression | str):
1754                self.order_by = OrderByList(self.order_by)
1755            else:
1756                raise ValueError(
1757                    "Window.order_by must be either a string reference to a "
1758                    "field, an expression, or a list or tuple of them."
1759                )
1760        super().__init__(output_field=output_field)
1761        self.source_expression = self._parse_expressions(expression)[0]
1762
1763    def _resolve_output_field(self) -> Field | None:
1764        return self.source_expression.output_field
1765
1766    def get_source_expressions(self) -> list[Any]:
1767        return [self.source_expression, self.partition_by, self.order_by, self.frame]
1768
1769    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
1770        self.source_expression, self.partition_by, self.order_by, self.frame = exprs
1771
1772    def as_sql(
1773        self,
1774        compiler: SQLCompiler,
1775        connection: DatabaseConnection,
1776        template: str | None = None,
1777    ) -> tuple[str, tuple[Any, ...]]:
1778        expr_sql, params = compiler.compile(self.source_expression)
1779        window_sql, window_params = [], ()
1780
1781        if self.partition_by is not None:
1782            sql_expr, sql_params = self.partition_by.as_sql(
1783                compiler=compiler,
1784                connection=connection,
1785                template="PARTITION BY %(expressions)s",
1786            )
1787            window_sql.append(sql_expr)
1788            window_params += tuple(sql_params)
1789
1790        if self.order_by is not None:
1791            order_sql, order_params = compiler.compile(self.order_by)
1792            window_sql.append(order_sql)
1793            window_params += tuple(order_params)
1794
1795        if self.frame:
1796            frame_sql, frame_params = compiler.compile(self.frame)
1797            window_sql.append(frame_sql)
1798            window_params += tuple(frame_params)
1799
1800        template = template or self.template
1801
1802        return (
1803            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
1804            (*params, *window_params),
1805        )
1806
1807    def __str__(self) -> str:
1808        return "{} OVER ({}{}{})".format(
1809            str(self.source_expression),
1810            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
1811            str(self.order_by or ""),
1812            str(self.frame or ""),
1813        )
1814
1815    def __repr__(self) -> str:
1816        return f"<{self.__class__.__name__}: {self}>"
1817
1818    def get_group_by_cols(self) -> list[Any]:
1819        group_by_cols = []
1820        if self.partition_by:
1821            group_by_cols.extend(self.partition_by.get_group_by_cols())
1822        if self.order_by is not None:
1823            group_by_cols.extend(self.order_by.get_group_by_cols())
1824        return group_by_cols
1825
1826
1827class WindowFrame(Expression):
1828    """
1829    Model the frame clause in window expressions. There are two types of frame
1830    clauses which are subclasses, however, all processing and validation (by no
1831    means intended to be complete) is done here. Thus, providing an end for a
1832    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
1833    row in the frame).
1834    """
1835
1836    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
1837    frame_type: str
1838
1839    def __init__(self, start: int | None = None, end: int | None = None):
1840        self.start = Value(start)
1841        self.end = Value(end)
1842
1843    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
1844        self.start, self.end = exprs
1845
1846    def get_source_expressions(self) -> list[Any]:
1847        return [self.start, self.end]
1848
1849    def as_sql(
1850        self, compiler: SQLCompiler, connection: DatabaseConnection
1851    ) -> tuple[str, list[Any]]:
1852        start, end = self.window_frame_start_end(
1853            connection, self.start.value, self.end.value
1854        )
1855        return (
1856            self.template
1857            % {
1858                "frame_type": self.frame_type,
1859                "start": start,
1860                "end": end,
1861            },
1862            [],
1863        )
1864
1865    def __repr__(self) -> str:
1866        return f"<{self.__class__.__name__}: {self}>"
1867
1868    def get_group_by_cols(self) -> list[Any]:
1869        return []
1870
1871    def __str__(self) -> str:
1872        if self.start.value is not None and self.start.value < 0:
1873            start = f"{abs(self.start.value)} {PRECEDING}"
1874        elif self.start.value is not None and self.start.value == 0:
1875            start = CURRENT_ROW
1876        else:
1877            start = UNBOUNDED_PRECEDING
1878
1879        if self.end.value is not None and self.end.value > 0:
1880            end = f"{self.end.value} {FOLLOWING}"
1881        elif self.end.value is not None and self.end.value == 0:
1882            end = CURRENT_ROW
1883        else:
1884            end = UNBOUNDED_FOLLOWING
1885        return self.template % {
1886            "frame_type": self.frame_type,
1887            "start": start,
1888            "end": end,
1889        }
1890
1891    def window_frame_start_end(
1892        self, connection: DatabaseConnection, start: int | None, end: int | None
1893    ) -> tuple[str, str]:
1894        """Return the window frame start and end for the given connection."""
1895        raise NotImplementedError("Subclasses must implement window_frame_start_end()")
1896
1897
1898class RowRange(WindowFrame):
1899    frame_type = "ROWS"
1900
1901    def window_frame_start_end(
1902        self, connection: DatabaseConnection, start: int | None, end: int | None
1903    ) -> tuple[str, str]:
1904        return window_frame_rows_start_end(start, end)
1905
1906
1907class ValueRange(WindowFrame):
1908    frame_type = "RANGE"
1909
1910    def window_frame_start_end(
1911        self, connection: DatabaseConnection, start: int | None, end: int | None
1912    ) -> tuple[str, str]:
1913        return window_frame_range_start_end(start, end)