v0.145.1
  1"""
  2PostgreSQL-specific SQL generation functions.
  3
  4All functions in this module are stateless — they don't depend on connection state.
  5Higher-level SQL builders that need connections live in ddl.py.
  6"""
  7
  8from __future__ import annotations
  9
 10import datetime
 11import ipaddress
 12import json
 13from collections.abc import Callable, Iterable
 14from functools import lru_cache, partial
 15from typing import TYPE_CHECKING, Any
 16
 17import psycopg
 18from psycopg.types.json import Jsonb
 19
 20from plain.postgres.constants import OnConflict
 21from plain.postgres.utils import split_tzname_delta
 22from plain.utils import timezone
 23from plain.utils.regex_helper import _lazy_re_compile
 24
 25if TYPE_CHECKING:
 26    from plain.postgres.fields import Field
 27
 28# Start and end points for window expressions.
 29PRECEDING: str = "PRECEDING"
 30FOLLOWING: str = "FOLLOWING"
 31UNBOUNDED_PRECEDING: str = "UNBOUNDED " + PRECEDING
 32UNBOUNDED_FOLLOWING: str = "UNBOUNDED " + FOLLOWING
 33CURRENT_ROW: str = "CURRENT ROW"
 34
 35# Prefix for EXPLAIN queries.
 36EXPLAIN_PREFIX: str = "EXPLAIN"
 37EXPLAIN_OPTIONS = frozenset(
 38    [
 39        "ANALYZE",
 40        "BUFFERS",
 41        "COSTS",
 42        "SETTINGS",
 43        "SUMMARY",
 44        "TIMING",
 45        "VERBOSE",
 46        "WAL",
 47    ]
 48)
 49SUPPORTED_EXPLAIN_FORMATS: set[str] = {"JSON", "TEXT", "XML", "YAML"}
 50
 51# Maximum length of an identifier (63 by default in PostgreSQL).
 52MAX_NAME_LENGTH: int = 63
 53
 54# Value to use during INSERT to specify that a field should use its default value.
 55PK_DEFAULT_VALUE: str = "DEFAULT"
 56
 57# SQL clause to make a constraint "initially deferred" during CREATE TABLE.
 58DEFERRABLE_SQL: str = " DEFERRABLE INITIALLY DEFERRED"
 59
 60# EXTRACT format validation pattern.
 61_EXTRACT_FORMAT_RE = _lazy_re_compile(r"[A-Z_]+")
 62
 63
 64# SQL operators for lookups.
 65OPERATORS: dict[str, str] = {
 66    "exact": "= %s",
 67    "iexact": "= UPPER(%s)",
 68    "contains": "LIKE %s",
 69    "icontains": "LIKE UPPER(%s)",
 70    "regex": "~ %s",
 71    "iregex": "~* %s",
 72    "gt": "> %s",
 73    "gte": ">= %s",
 74    "lt": "< %s",
 75    "lte": "<= %s",
 76    "startswith": "LIKE %s",
 77    "endswith": "LIKE %s",
 78    "istartswith": "LIKE UPPER(%s)",
 79    "iendswith": "LIKE UPPER(%s)",
 80}
 81
 82# SQL pattern for escaping special characters in LIKE clauses.
 83# Used when the right-hand side isn't a raw string (e.g., an expression).
 84PATTERN_ESC = (
 85    r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
 86)
 87
 88# Pattern operators for non-literal LIKE lookups.
 89PATTERN_OPS: dict[str, str] = {
 90    "contains": "LIKE '%%' || {} || '%%'",
 91    "icontains": "LIKE '%%' || UPPER({}) || '%%'",
 92    "startswith": "LIKE {} || '%%'",
 93    "istartswith": "LIKE UPPER({}) || '%%'",
 94    "endswith": "LIKE '%%' || {}",
 95    "iendswith": "LIKE '%%' || UPPER({})",
 96}
 97
 98
 99@lru_cache
100def get_json_dumps(
101    encoder: type[json.JSONEncoder] | None,
102) -> Callable[..., str]:
103    if encoder is None:
104        return json.dumps
105    return partial(json.dumps, cls=encoder)
106
107
def quote_name(name: str) -> str:
    """
    Wrap a table, index, or column name in double quotes.

    A name that already begins and ends with a double quote is returned
    unchanged, so quoting twice has no further effect.
    """
    already_quoted = name[:1] == '"' and name[-1:] == '"'
    return name if already_quoted else f'"{name}"'
116
117
# Postgres interval literals we accept in timeout settings. Requires an
# explicit unit so a `"1"` typo can't silently become 1ms (Postgres's
# implicit unit for lock_timeout/statement_timeout). Fractional values
# and mixed case are allowed to match what Postgres itself accepts.
# Rejecting single quotes (and anything else outside this shape) prevents
# malformed settings from escaping the SQL literal.
# NOTE(review): PostgreSQL's documentation describes unit names as
# case-sensitive — confirm that mixed-case units are really accepted.
_TIMEOUT_VALUE_RE = _lazy_re_compile(r"(?i)^\d+(\.\d+)?\s*(us|ms|s|min|h|d)$")


def _validate_timeout_value(name: str, value: str) -> None:
    """
    Raise ValueError unless `value` is a number followed by an explicit unit.

    `name` is the setting being checked; it is only used in the error
    message. Returns None when the value is acceptable.
    """
    # fullmatch(), not match(): `$` also matches just before a trailing
    # newline, so match() would let a value such as "3s\n" through.
    if not _TIMEOUT_VALUE_RE.fullmatch(value):
        raise ValueError(
            f"Invalid Postgres interval for {name}: {value!r}. "
            "Expected e.g. '3s', '500ms', '1.5min' — a number with an "
            "explicit unit (us, ms, s, min, h, d)."
        )
134
135
def build_timeout_set_clauses(
    *,
    lock_timeout: str,
    statement_timeout: str | None,
    local: bool = True,
) -> str:
    """Build the `SET [LOCAL] lock_timeout = '...'; [SET [LOCAL] statement_timeout = '...'; ]`
    prelude that gets prepended to a DDL statement.

    With `local=True` each statement is `SET LOCAL` (transaction-scoped,
    auto-restores on commit). With `local=False` it is a session-level `SET`
    (used in autocommit mode; requires a matching RESET).

    Passing `statement_timeout=None` leaves that SET out — used for SHARE
    UPDATE EXCLUSIVE operations (CONCURRENTLY, VALIDATE CONSTRAINT) that are
    non-blocking and should run to completion on any table size.

    Both values are validated first; an invalid one raises ValueError.
    """
    settings = [("lock_timeout", lock_timeout)]
    if statement_timeout is not None:
        settings.append(("statement_timeout", statement_timeout))

    # Validate everything before any SQL text is assembled.
    for setting, value in settings:
        _validate_timeout_value(setting, value)

    keyword = "SET LOCAL" if local else "SET"
    return "".join(f"{keyword} {setting} = '{value}'; " for setting, value in settings)
161
162
def date_extract_sql(
    lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Build the EXTRACT expression that pulls `lookup_type` (e.g. 'year',
    'month', or 'day') out of the date expression `sql`.

    `params` is returned unchanged alongside the new SQL. A lookup type
    that is not made up solely of letters and underscores raises ValueError.
    """
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
    special_cases = {
        # PostgreSQL DOW returns 0=Sunday, 6=Saturday; we return 1=Sunday, 7=Saturday.
        "week_day": "EXTRACT(DOW FROM {}) + 1",
        "iso_week_day": "EXTRACT(ISODOW FROM {})",
        "iso_year": "EXTRACT(ISOYEAR FROM {})",
    }
    template = special_cases.get(lookup_type)
    if template is not None:
        return template.format(sql), params

    # Every other lookup is used as the EXTRACT field itself, so it has to
    # pass the format check before being placed in the SQL.
    extract_field = lookup_type.upper()
    if _EXTRACT_FORMAT_RE.fullmatch(extract_field) is None:
        raise ValueError(f"Invalid lookup type: {extract_field!r}")
    return f"EXTRACT({extract_field} FROM {sql})", params
183
184
def _prepare_tzname_delta(tzname: str) -> str:
    """
    Return `tzname` with the sign of its offset part flipped.

    The name is split with `split_tzname_delta`; when there is no offset
    part, only the name part is returned.
    """
    # NOTE(review): the flip presumably matches how PostgreSQL reads the
    # sign of such offsets — confirm.
    name, sign, offset = split_tzname_delta(tzname)
    if not offset:
        return name
    flipped = "-" if sign == "+" else "+"
    return f"{name}{flipped}{offset}"
191
192
def _convert_sql_to_tz(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Append `AT TIME ZONE %s` to `sql` when a timezone name is given.

    The prepared timezone name is added as the last parameter. Without a
    timezone name, `sql` and `params` are returned untouched.
    """
    if not tzname:
        return sql, params
    zone = _prepare_tzname_delta(tzname)
    return f"{sql} AT TIME ZONE %s", (*params, zone)
200
201
def date_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Build the DATE_TRUNC expression that truncates the date or datetime
    expression `sql` to the precision named by `lookup_type` (e.g. 'year',
    'month', or 'day').

    `lookup_type` is sent as the first query parameter rather than written
    into the SQL. If `tzname` is provided, the value is truncated in that
    timezone.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    all_params = (lookup_type, *zoned_params)
    return f"DATE_TRUNC(%s, {zoned_sql})", all_params
218
219
def datetime_cast_date_sql(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Return the SQL that casts a datetime expression to a date.

    If `tzname` is given, the value is first converted to that timezone.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    cast_sql = f"({zoned_sql})::date"
    return cast_sql, zoned_params
226
227
def datetime_cast_time_sql(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Return the SQL that casts a datetime expression to a time.

    If `tzname` is given, the value is first converted to that timezone.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    cast_sql = f"({zoned_sql})::time"
    return cast_sql, zoned_params
234
235
def datetime_extract_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Build the SQL that extracts `lookup_type` (e.g. 'year', 'month', 'day',
    'hour', 'minute', or 'second') from the datetime expression `sql`.

    If `tzname` is given, the value is converted to that timezone before
    anything is extracted.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    if lookup_type != "second":
        return date_extract_sql(lookup_type, zoned_sql, zoned_params)
    # Truncate to whole seconds first so the fractional part is dropped.
    return (
        f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {zoned_sql}))",
        ("second", *zoned_params),
    )
252
253
def datetime_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Build the DATE_TRUNC expression that truncates the datetime expression
    `sql` to the precision named by `lookup_type` (e.g. 'year', 'month',
    'day', 'hour', 'minute', or 'second').

    `lookup_type` is sent as the first query parameter. If `tzname` is
    given, the value is converted to that timezone before truncation.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    all_params = (lookup_type, *zoned_params)
    return f"DATE_TRUNC(%s, {zoned_sql})", all_params
268
269
def time_extract_sql(
    lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Build the SQL that extracts `lookup_type` (e.g. 'hour', 'minute', or
    'second') from the time expression `sql`.
    """
    if lookup_type != "second":
        return date_extract_sql(lookup_type, sql, params)
    # Truncate to whole seconds first so the fractional part is dropped.
    return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
281
282
def time_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Build the SQL that truncates the time or datetime expression `sql` to
    the precision named by `lookup_type` (e.g. 'hour', 'minute' or
    'second') and casts the result to a time.

    `lookup_type` is sent as the first query parameter. If `tzname` is
    provided, the value is truncated in that timezone.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    all_params = (lookup_type, *zoned_params)
    return f"DATE_TRUNC(%s, {zoned_sql})::time", all_params
298
299
def distinct_sql(
    fields: list[str], params: list[Any] | tuple[Any, ...]
) -> tuple[list[str], list[Any]]:
    """
    Return the DISTINCT clause, as a list of SQL fragments, together with
    its parameters.

    With no fields the clause is a plain DISTINCT with no parameters. With
    fields it is DISTINCT ON over those fields, and `params` — one sequence
    of parameters per field — is flattened into a single list.
    """
    if not fields:
        return ["DISTINCT"], []

    flattened: list[Any] = []
    for field_params in params:
        flattened.extend(field_params)
    return [f"DISTINCT ON ({', '.join(fields)})"], flattened
313
314
def for_update_sql(
    nowait: bool = False,
    skip_locked: bool = False,
    of: tuple[str, ...] = (),
    no_key: bool = False,
) -> str:
    """
    Return the FOR UPDATE clause used to lock rows for an update.

    `no_key` makes it FOR NO KEY UPDATE, `of` adds an OF list, and
    `nowait` / `skip_locked` append NOWAIT / SKIP LOCKED.
    """
    clause = ["FOR"]
    if no_key:
        clause.append("NO KEY")
    clause.append("UPDATE")
    if of:
        clause.append("OF " + ", ".join(of))
    if nowait:
        clause.append("NOWAIT")
    if skip_locked:
        clause.append("SKIP LOCKED")
    return " ".join(clause)
328
329
def limit_offset_sql(low_mark: int | None, high_mark: int | None) -> str:
    """
    Return the LIMIT/OFFSET SQL clause for a row range.

    `low_mark` is the number of rows to skip; None or 0 emits no OFFSET.
    `high_mark` is the end of the range; None emits no LIMIT, otherwise the
    limit is `high_mark` minus the offset. Returns an empty string when
    neither part applies.
    """
    offset = low_mark or 0
    clauses = []
    if high_mark is not None:
        # Test against None, not truthiness: a zero-width range has to emit
        # LIMIT 0, otherwise the query would select every remaining row.
        clauses.append("LIMIT %d" % (high_mark - offset))  # noqa: UP031
    if offset:
        clauses.append("OFFSET %d" % offset)  # noqa: UP031
    return " ".join(clauses)
345
346
def lookup_cast(lookup_type: str, field: Field | None = None) -> str:
    """
    Return the template used for the searched column in a lookup
    ("contains", "like", etc.). It contains a '%s' placeholder for the
    column being searched against.

    Text-style lookups cast the column to text (or take HOST() of it for a
    GenericIPAddressField); case-insensitive ones additionally wrap it in
    UPPER(). Every other lookup gets the bare placeholder.
    """
    from plain.postgres.fields import (
        GenericIPAddressField,
    )

    case_insensitive = ("iexact", "icontains", "istartswith", "iendswith")
    text_lookups = (
        *case_insensitive,
        "contains",
        "startswith",
        "endswith",
        "regex",
        "iregex",
    )
    if lookup_type not in text_lookups:
        return "%s"

    # Cast text lookups to text to allow things like filter(x__contains=4)
    if isinstance(field, GenericIPAddressField):
        column = "HOST(%s)"
    else:
        column = "%s::text"

    # Use UPPER(x) for case-insensitive lookups; it's faster.
    return f"UPPER({column})" if lookup_type in case_insensitive else column
381
382
def return_insert_columns(fields: list[Field]) -> tuple[str, tuple[Any, ...]]:
    """
    Return the RETURNING clause SQL and params to append to an INSERT query.

    Each column is written as a quoted `"table"."column"` pair. With no
    fields the SQL is empty; the params are always an empty tuple.
    """
    if not fields:
        return "", ()
    qualified = ", ".join(
        "{}.{}".format(
            quote_name(field.model.model_options.db_table),
            quote_name(field.column),
        )
        for field in fields
    )
    return f"RETURNING {qualified}", ()
392
393
def bulk_insert_sql(fields: list[Field], placeholder_rows: list[list[str]]) -> str:
    """
    Return the VALUES clause for inserting several rows at once.

    Each entry of `placeholder_rows` becomes one parenthesized,
    comma-separated group. `fields` is accepted but not used.
    """
    row_groups = [f"({', '.join(row)})" for row in placeholder_rows]
    return "VALUES " + ", ".join(row_groups)
399
400
def regex_lookup(lookup_type: str) -> str:
    """
    Return the template used for regular expression lookups: the
    case-sensitive form for "regex", the case-insensitive form for anything
    else (i.e. "iregex").
    """
    # PostgreSQL uses ~ for regex and ~* for case-insensitive regex
    return "%s ~ %s" if lookup_type == "regex" else "%s ~* %s"
410
411
def prep_for_like_query(x: str) -> str:
    """
    Prepare a value for use in a LIKE query.

    The value is converted to a string and every backslash, percent sign
    and underscore in it is prefixed with a backslash.
    """
    escapes = str.maketrans({"\\": "\\\\", "%": r"\%", "_": r"\_"})
    return str(x).translate(escapes)
415
416
def adapt_ipaddressfield_value(
    value: str | None,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
    """
    Convert the string form of an IP address into the `ipaddress` object
    handed to the backend driver. Empty values (None or "") become None.
    """
    if not value:
        return None
    return ipaddress.ip_address(value)
427
428
def adapt_json_value(value: Any, encoder: type[json.JSONEncoder] | None) -> Jsonb:
    """Wrap `value` in a `Jsonb` that serializes with the dumps function chosen for `encoder`."""
    dumps = get_json_dumps(encoder)
    return Jsonb(value, dumps=dumps)
431
432
def year_lookup_bounds_for_date_field(
    value: int, iso_year: bool = False
) -> list[datetime.date]:
    """
    Return the `[lower, upper]` dates to use with a BETWEEN operator when
    querying a DateField by year.

    `value` is the looked-up year as an int. If `iso_year` is True, the
    bounds are those of the ISO-8601 week-numbering year: from the first
    day of its week 1 to the day before week 1 of the following year.
    """
    if not iso_year:
        return [datetime.date(value, 1, 1), datetime.date(value, 12, 31)]

    one_day = datetime.timedelta(days=1)
    lower = datetime.date.fromisocalendar(value, 1, 1)
    next_lower = datetime.date.fromisocalendar(value + 1, 1, 1)
    return [lower, next_lower - one_day]
452
453
def year_lookup_bounds_for_datetime_field(
    value: int, iso_year: bool = False
) -> list[datetime.datetime]:
    """
    Return the `[lower, upper]` datetimes to use with a BETWEEN operator
    when querying a DateTimeField by year.

    `value` is the looked-up year as an int. If `iso_year` is True, the
    bounds are those of the ISO-8601 week-numbering year. The upper bound
    is the last microsecond of the year, and both bounds are made aware in
    the current timezone.
    """
    if iso_year:
        lower = datetime.datetime.fromisocalendar(value, 1, 1)
        next_lower = datetime.datetime.fromisocalendar(value + 1, 1, 1)
        upper = next_lower - datetime.timedelta(microseconds=1)
    else:
        lower = datetime.datetime(value, 1, 1)
        upper = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)

    # Make sure that datetimes are aware in the current timezone
    current_tz = timezone.get_current_timezone()
    return [timezone.make_aware(naive, current_tz) for naive in (lower, upper)]
478
479
def combine_expression(connector: str, sub_expressions: list[str]) -> str:
    """
    Join the subexpressions into one expression, placing the connecting
    operator, surrounded by single spaces, between each pair.
    """
    return f" {connector} ".join(sub_expressions)
487
488
def subtract_temporals(
    field: Field,
    lhs: tuple[str, list[Any] | tuple[Any, ...]],
    rhs: tuple[str, list[Any] | tuple[Any, ...]],
) -> tuple[str, tuple[Any, ...]]:
    """
    Return the SQL and params for subtracting one temporal expression from
    another.

    `lhs` and `rhs` are `(sql, params)` pairs; the params of the result are
    those of `lhs` followed by those of `rhs`. `field` decides whether the
    difference has to be turned into an interval.
    """
    from plain.postgres.fields import DateField, DateTimeField

    (lhs_sql, lhs_params), (rhs_sql, rhs_params) = lhs, rhs
    combined_params = (*lhs_params, *rhs_params)

    # Native temporal subtraction is the starting point for every field.
    difference = f"({lhs_sql} - {rhs_sql})"
    # DateField (but not DateTimeField) needs interval conversion
    is_plain_date = isinstance(field, DateField) and not isinstance(
        field, DateTimeField
    )
    if is_plain_date:
        difference = f"(interval '1 day' * {difference})"
    return difference, combined_params
504
505
def window_frame_start(start: int | None) -> str:
    """
    Return the SQL for the start point of a window frame.

    None gives UNBOUNDED PRECEDING, zero gives CURRENT ROW and a negative
    integer gives that many rows PRECEDING. Anything else raises ValueError.
    """
    if start is None:
        return UNBOUNDED_PRECEDING
    if isinstance(start, int):
        if start == 0:
            return CURRENT_ROW
        if start < 0:
            return f"{abs(start)} {PRECEDING}"
    raise ValueError(
        f"start argument must be a negative integer, zero, or None, but got '{start}'."
    )
517
518
def window_frame_end(end: int | None) -> str:
    """
    Return the SQL for the end point of a window frame.

    None gives UNBOUNDED FOLLOWING, zero gives CURRENT ROW and a positive
    integer gives that many rows FOLLOWING. Anything else raises ValueError.
    """
    if end is None:
        return UNBOUNDED_FOLLOWING
    if isinstance(end, int):
        if end > 0:
            return f"{end} {FOLLOWING}"
        if end == 0:
            return CURRENT_ROW
    raise ValueError(
        f"end argument must be a positive integer, zero, or None, but got '{end}'."
    )
530
531
def window_frame_rows_start_end(
    start: int | None = None, end: int | None = None
) -> tuple[str, str]:
    """
    Return the SQL for the start and end points of an OVER clause window
    frame, as a `(start, end)` pair. The start is rendered first, so an
    invalid start is reported before an invalid end.
    """
    frame_start = window_frame_start(start)
    frame_end = window_frame_end(end)
    return frame_start, frame_end
537
538
def window_frame_range_start_end(
    start: int | None = None, end: int | None = None
) -> tuple[str, str]:
    """
    Return the `(start, end)` SQL for a window frame in RANGE mode.

    The points are rendered exactly as for ROWS mode, so invalid values
    raise ValueError first. A valid but numeric bound (negative start or
    positive end) then raises `psycopg.NotSupportedError`.
    """
    frame_bounds = window_frame_rows_start_end(start, end)
    # PostgreSQL only supports UNBOUNDED with PRECEDING/FOLLOWING
    has_numeric_start = bool(start) and start < 0
    has_numeric_end = bool(end) and end > 0
    if has_numeric_start or has_numeric_end:
        raise psycopg.NotSupportedError(
            "PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING."
        )
    return frame_bounds
549
550
def explain_query_prefix(format: str | None = None, **options: Any) -> str:
    """
    Return the prefix that turns a query into an EXPLAIN query.

    `format`, if given, must be one of SUPPORTED_EXPLAIN_FORMATS (compared
    case-insensitively) and is emitted upper-cased as the last option.
    Each keyword option must name an entry of EXPLAIN_OPTIONS (also
    case-insensitive); its value is reduced to "true" or "false" by
    truthiness. Options are emitted in alphabetical order.

    Raises ValueError for an unrecognized format or for unknown options.
    """
    # Normalize options.
    normalized = {
        name.upper(): "true" if value else "false" for name, value in options.items()
    }
    extra = {}
    # Walk the known options in sorted order: EXPLAIN_OPTIONS is a frozenset,
    # whose iteration order for strings can differ between interpreter runs,
    # and the generated SQL should not.
    for valid_option in sorted(EXPLAIN_OPTIONS):
        if valid_option in normalized:
            extra[valid_option] = normalized.pop(valid_option)
    if format:
        normalized_format = format.upper()
        if normalized_format not in SUPPORTED_EXPLAIN_FORMATS:
            msg = "{} is not a recognized format. Allowed formats: {}".format(
                normalized_format, ", ".join(sorted(SUPPORTED_EXPLAIN_FORMATS))
            )
            raise ValueError(msg)
        # Emit the spelling that was validated, not the caller's raw text.
        extra["FORMAT"] = normalized_format
    # Whatever was not claimed above is not a known option.
    if normalized:
        raise ValueError("Unknown options: {}".format(", ".join(sorted(normalized))))
    prefix = EXPLAIN_PREFIX
    if extra:
        rendered = ", ".join(f"{name} {value}" for name, value in extra.items())
        prefix += f" ({rendered})"
    return prefix
579
580
def on_conflict_suffix_sql(
    fields: list[Field],
    on_conflict: OnConflict | None,
    update_fields: Iterable[str],
    unique_fields: Iterable[str],
) -> str:
    """
    Return the ON CONFLICT clause to append to an INSERT.

    `OnConflict.IGNORE` gives `ON CONFLICT DO NOTHING`. `OnConflict.UPDATE`
    gives an upsert: `unique_fields` form the conflict target and every
    name in `update_fields` is set from EXCLUDED. Any other value gives an
    empty string. `fields` is accepted but not used.
    """
    if on_conflict == OnConflict.IGNORE:
        return "ON CONFLICT DO NOTHING"
    if on_conflict != OnConflict.UPDATE:
        return ""

    conflict_target = ", ".join(quote_name(name) for name in unique_fields)
    assignments = ", ".join(
        f"{column} = EXCLUDED.{column}" for column in map(quote_name, update_fields)
    )
    return f"ON CONFLICT({conflict_target}) DO UPDATE SET {assignments}"