1"""
2PostgreSQL-specific SQL generation functions.
3
4All functions in this module are stateless — they don't depend on connection state.
5Higher-level SQL builders that need connections live in ddl.py.
6"""
7
8from __future__ import annotations
9
10import datetime
11import ipaddress
12import json
13from collections.abc import Callable, Iterable
14from functools import lru_cache, partial
15from typing import TYPE_CHECKING, Any
16
17import psycopg
18from psycopg.types.json import Jsonb
19
20from plain.postgres.constants import OnConflict
21from plain.postgres.utils import split_tzname_delta
22from plain.utils import timezone
23from plain.utils.regex_helper import _lazy_re_compile
24
25if TYPE_CHECKING:
26 from plain.postgres.fields import Field
27
# Keywords naming the boundaries of a window-expression frame.
PRECEDING: str = "PRECEDING"
FOLLOWING: str = "FOLLOWING"
UNBOUNDED_PRECEDING: str = f"UNBOUNDED {PRECEDING}"
UNBOUNDED_FOLLOWING: str = f"UNBOUNDED {FOLLOWING}"
CURRENT_ROW: str = "CURRENT ROW"

# Keyword that opens an EXPLAIN statement, the boolean options that may
# accompany it, and the output formats it can be asked for.
EXPLAIN_PREFIX: str = "EXPLAIN"
EXPLAIN_OPTIONS = frozenset(
    (
        "ANALYZE",
        "BUFFERS",
        "COSTS",
        "SETTINGS",
        "SUMMARY",
        "TIMING",
        "VERBOSE",
        "WAL",
    )
)
SUPPORTED_EXPLAIN_FORMATS: set[str] = {
    "JSON",
    "TEXT",
    "XML",
    "YAML",
}

# Longest identifier PostgreSQL keeps by default (63).
MAX_NAME_LENGTH: int = 63

# Keyword written in place of a value during INSERT so the column falls back
# to its default.
PK_DEFAULT_VALUE: str = "DEFAULT"

# Suffix that makes a constraint "initially deferred" during CREATE TABLE.
DEFERRABLE_SQL: str = " DEFERRABLE INITIALLY DEFERRED"
59
# EXTRACT format validation pattern: only A-Z and underscores. The
# upper-cased lookup type must fully match it before date_extract_sql()
# interpolates it into the generated SQL.
_EXTRACT_FORMAT_RE = _lazy_re_compile(r"[A-Z_]+")


# SQL operators for lookups, keyed by lookup name. Each value is the
# comparison operator followed by a single %s placeholder; the
# case-insensitive variants wrap that placeholder in UPPER().
OPERATORS: dict[str, str] = {
    "exact": "= %s",
    "iexact": "= UPPER(%s)",
    "contains": "LIKE %s",
    "icontains": "LIKE UPPER(%s)",
    "regex": "~ %s",
    "iregex": "~* %s",
    "gt": "> %s",
    "gte": ">= %s",
    "lt": "< %s",
    "lte": "<= %s",
    "startswith": "LIKE %s",
    "endswith": "LIKE %s",
    "istartswith": "LIKE UPPER(%s)",
    "iendswith": "LIKE UPPER(%s)",
}

# SQL pattern for escaping special characters in LIKE clauses (backslash,
# percent sign and underscore) in whatever expression replaces {}.
# Used when the right-hand side isn't a raw string (e.g., an expression).
# NOTE(review): the doubled %% presumably collapses to a single % in a later
# %-style parameter interpolation step — confirm against the caller.
PATTERN_ESC = (
    r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
)

# Pattern operators for non-literal LIKE lookups, keyed by lookup name.
# {} marks where the right-hand-side SQL goes; the surrounding '%%' literals
# add the LIKE wildcards that the lookup needs.
PATTERN_OPS: dict[str, str] = {
    "contains": "LIKE '%%' || {} || '%%'",
    "icontains": "LIKE '%%' || UPPER({}) || '%%'",
    "startswith": "LIKE {} || '%%'",
    "istartswith": "LIKE UPPER({}) || '%%'",
    "endswith": "LIKE '%%' || {}",
    "iendswith": "LIKE '%%' || UPPER({})",
}
97
98
@lru_cache
def get_json_dumps(
    encoder: type[json.JSONEncoder] | None,
) -> Callable[..., str]:
    """
    Return a `json.dumps`-compatible callable for the given encoder class.

    Without an encoder this is `json.dumps` itself; otherwise it is a
    `partial` that passes the encoder as `cls`. Results are memoized per
    encoder by `lru_cache`, so repeated calls return the same callable.
    """
    return json.dumps if encoder is None else partial(json.dumps, cls=encoder)
106
107
def quote_name(name: str) -> str:
    """
    Wrap a table, index, or column name in double quotes.

    A name that already starts and ends with a double quote is returned
    unchanged, so quoting is never applied twice.
    """
    already_quoted = name.startswith('"') and name.endswith('"')
    return name if already_quoted else f'"{name}"'
116
117
# Postgres interval literals we accept in timeout settings. Requires an
# explicit unit so a `"1"` typo can't silently become 1ms (Postgres's
# implicit unit for lock_timeout/statement_timeout). Fractional values
# and mixed case are allowed to match what Postgres itself accepts.
# Rejecting single quotes prevents malformed settings from escaping the
# SQL literal.
# The pattern ends in `\Z` rather than `$`: `$` also matches just before a
# trailing newline, which would let a value such as "3s\n" through.
_TIMEOUT_VALUE_RE = _lazy_re_compile(r"(?i)^\d+(\.\d+)?\s*(us|ms|s|min|h|d)\Z")


def _validate_timeout_value(name: str, value: str) -> None:
    """
    Raise ValueError unless `value` is a number followed by an explicit unit.

    `name` is the setting being validated and is used only in the error
    message. The whole string must match `_TIMEOUT_VALUE_RE`; nothing may
    precede or follow the number-and-unit pair.
    """
    # fullmatch (not match) so the entire value is validated, not a prefix.
    if not _TIMEOUT_VALUE_RE.fullmatch(value):
        raise ValueError(
            f"Invalid Postgres interval for {name}: {value!r}. "
            "Expected e.g. '3s', '500ms', '1.5min' — a number with an "
            "explicit unit (us, ms, s, min, h, d)."
        )
134
135
def build_timeout_set_clauses(
    *,
    lock_timeout: str,
    statement_timeout: str | None,
    local: bool = True,
) -> str:
    """Build the `SET [LOCAL] lock_timeout = '...'; [SET [LOCAL] statement_timeout = '...'; ]`
    prelude that is prepended to a DDL statement.

    With `local=True` the clauses use `SET LOCAL` (transaction-scoped,
    auto-restores on commit). With `local=False` they use session-level `SET`
    (used in autocommit mode; requires a matching RESET).

    Passing `statement_timeout=None` leaves that SET out — used for SHARE
    UPDATE EXCLUSIVE operations (CONCURRENTLY, VALIDATE CONSTRAINT) that are
    non-blocking and should run to completion on any table size.

    Every supplied value is validated first; an invalid one raises ValueError.
    """
    settings = [("lock_timeout", lock_timeout)]
    if statement_timeout is not None:
        settings.append(("statement_timeout", statement_timeout))

    # Validate everything before any SQL text is assembled.
    for setting, value in settings:
        _validate_timeout_value(setting, value)

    verb = "SET LOCAL" if local else "SET"
    return "".join(f"{verb} {setting} = '{value}'; " for setting, value in settings)
161
162
def date_extract_sql(
    lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Return the SQL and params that EXTRACT the `lookup_type` component
    (e.g. 'year', 'month', 'day') from the date expression `sql`.

    'week_day', 'iso_week_day' and 'iso_year' map to dedicated PostgreSQL
    fields. Any other lookup type is upper-cased and must consist solely of
    A-Z and underscores, otherwise ValueError is raised. `params` is returned
    unchanged.
    """
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
    special_templates = {
        # PostgreSQL DOW returns 0=Sunday, 6=Saturday; we return 1=Sunday, 7=Saturday.
        "week_day": "EXTRACT(DOW FROM {}) + 1",
        "iso_week_day": "EXTRACT(ISODOW FROM {})",
        "iso_year": "EXTRACT(ISOYEAR FROM {})",
    }
    template = special_templates.get(lookup_type)
    if template is not None:
        return template.format(sql), params

    extract_field = lookup_type.upper()
    if _EXTRACT_FORMAT_RE.fullmatch(extract_field) is None:
        raise ValueError(f"Invalid lookup type: {extract_field!r}")
    return f"EXTRACT({extract_field} FROM {sql})", params
183
184
def _prepare_tzname_delta(tzname: str) -> str:
    """
    Return `tzname` with the sign of its offset suffix inverted.

    `split_tzname_delta` splits the input into name, sign and offset. When
    there is no offset, the bare name is returned as-is.
    """
    name, sign, offset = split_tzname_delta(tzname)
    if not offset:
        return name
    # NOTE(review): the flip presumably accounts for PostgreSQL reading
    # offsets with the opposite sign convention — confirm.
    flipped = "-" if sign == "+" else "+"
    return f"{name}{flipped}{offset}"
191
192
def _convert_sql_to_tz(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Append `AT TIME ZONE %s` to `sql` when a timezone name is given.

    The prepared timezone name is added as the last parameter (the params
    come back as a new tuple). With a falsy `tzname`, `sql` and `params` are
    returned untouched.
    """
    if not tzname:
        return sql, params
    return f"{sql} AT TIME ZONE %s", (*params, _prepare_tzname_delta(tzname))
200
201
def date_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Return the SQL and params that truncate the date or datetime expression
    `sql` to the precision named by `lookup_type` (e.g. 'year', 'month',
    'day').

    `lookup_type` is sent as the first query parameter rather than being
    interpolated into the SQL text. If `tzname` is provided, the value is
    converted to that timezone before truncation.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    return f"DATE_TRUNC(%s, {zoned_sql})", (lookup_type, *zoned_params)
218
219
def datetime_cast_date_sql(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return the SQL and params casting a datetime expression to a date,
    after converting it to `tzname` when one is given."""
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    return f"({zoned_sql})::date", zoned_params
226
227
def datetime_cast_time_sql(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return the SQL and params casting a datetime expression to a time,
    after converting it to `tzname` when one is given."""
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    return f"({zoned_sql})::time", zoned_params
234
235
def datetime_extract_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Return the SQL and params that extract the `lookup_type` component
    ('year', 'month', 'day', 'hour', 'minute', 'second', ...) from the
    datetime expression `sql`, converted to `tzname` when one is given.

    Everything except 'second' is delegated to `date_extract_sql`.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    if lookup_type != "second":
        return date_extract_sql(lookup_type, zoned_sql, zoned_params)
    # Truncate to whole seconds first so the fractional part is dropped.
    extract_sql = f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {zoned_sql}))"
    return extract_sql, ("second", *zoned_params)
252
253
def datetime_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Return the SQL and params that truncate the datetime expression `sql`
    to the precision named by `lookup_type` ('year', 'month', 'day', 'hour',
    'minute', or 'second'), converted to `tzname` when one is given.

    `lookup_type` is sent as the first query parameter rather than being
    interpolated into the SQL text.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    return f"DATE_TRUNC(%s, {zoned_sql})", (lookup_type, *zoned_params)
268
269
def time_extract_sql(
    lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Return the SQL and params that extract the `lookup_type` component
    ('hour', 'minute', or 'second') from the time expression `sql`.

    Everything except 'second' is delegated to `date_extract_sql`.
    """
    if lookup_type != "second":
        return date_extract_sql(lookup_type, sql, params)
    # Truncate to whole seconds first so the fractional part is dropped.
    extract_sql = f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))"
    return extract_sql, ("second", *params)
281
282
def time_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Return the SQL and params that truncate the time or datetime expression
    `sql` to the precision named by `lookup_type` ('hour', 'minute' or
    'second') and cast the result to a time.

    `lookup_type` is sent as the first query parameter. If `tzname` is
    provided, the value is converted to that timezone before truncation.
    """
    zoned_sql, zoned_params = _convert_sql_to_tz(sql, params, tzname)
    truncated = f"DATE_TRUNC(%s, {zoned_sql})::time"
    return truncated, (lookup_type, *zoned_params)
298
299
def distinct_sql(
    fields: list[str], params: list[Any] | tuple[Any, ...]
) -> tuple[list[str], list[Any]]:
    """
    Return the DISTINCT clause (as a one-element list) and its params.

    With no fields the clause is a plain `DISTINCT` with no params. With
    fields it is `DISTINCT ON (...)` over them, and `params` — a sequence of
    per-field parameter sequences — is flattened into a single list.
    """
    if not fields:
        return ["DISTINCT"], []

    flat_params: list[Any] = []
    for field_params in params:
        flat_params.extend(field_params)
    column_list = ", ".join(fields)
    return [f"DISTINCT ON ({column_list})"], flat_params
313
314
def for_update_sql(
    nowait: bool = False,
    skip_locked: bool = False,
    of: tuple[str, ...] = (),
    no_key: bool = False,
) -> str:
    """Return the row-locking clause: `FOR [NO KEY] UPDATE`, optionally
    followed by `OF ...`, `NOWAIT` and `SKIP LOCKED`, in that order."""
    words = ["FOR"]
    if no_key:
        words.append("NO KEY")
    words.append("UPDATE")
    if of:
        words.append("OF " + ", ".join(of))
    if nowait:
        words.append("NOWAIT")
    if skip_locked:
        words.append("SKIP LOCKED")
    return " ".join(words)
328
329
def limit_offset_sql(low_mark: int | None, high_mark: int | None) -> str:
    """
    Return the LIMIT/OFFSET SQL clause for a slice `[low_mark:high_mark]`.

    `low_mark` is the number of rows to skip (None or 0 means none) and
    `high_mark` is the exclusive upper bound (None means unbounded). The
    limit is `high_mark - low_mark`. An empty string is returned when neither
    clause is needed.

    A limit of zero is emitted as `LIMIT 0`. Leaving the clause out in that
    case would turn an empty slice into an unbounded query.
    """
    offset = low_mark or 0
    clauses = []
    # Test against None, not truthiness: a computed limit of 0 is meaningful.
    if high_mark is not None:
        clauses.append("LIMIT %d" % (high_mark - offset))  # noqa: UP031
    if offset:
        clauses.append("OFFSET %d" % offset)  # noqa: UP031
    return " ".join(clauses)
345
346
def lookup_cast(lookup_type: str, field: Field | None = None) -> str:
    """
    Return the template applied to the column being searched for the given
    lookup ("contains", "iexact", etc.). The result contains one '%s'
    placeholder for that column.
    """
    from plain.postgres.fields import (
        GenericIPAddressField,
    )

    text_lookups = (
        "iexact",
        "contains",
        "icontains",
        "startswith",
        "istartswith",
        "endswith",
        "iendswith",
        "regex",
        "iregex",
    )
    case_insensitive_lookups = ("iexact", "icontains", "istartswith", "iendswith")

    if lookup_type not in text_lookups:
        template = "%s"
    elif isinstance(field, GenericIPAddressField):
        template = "HOST(%s)"
    else:
        # Cast to text to allow things like filter(x__contains=4).
        template = "%s::text"

    # Use UPPER(x) for case-insensitive lookups; it's faster.
    if lookup_type in case_insensitive_lookups:
        template = f"UPPER({template})"

    return template
381
382
def return_insert_columns(fields: list[Field]) -> tuple[str, tuple[Any, ...]]:
    """Return the RETURNING clause SQL and params to append to an INSERT query.

    Each column is written as a quoted `"table"."column"` pair. With no
    fields the clause is empty. The params are always an empty tuple.
    """
    if not fields:
        return "", ()

    qualified_columns = []
    for field in fields:
        table = quote_name(field.model.model_options.db_table)
        column = quote_name(field.column)
        qualified_columns.append(f"{table}.{column}")
    return "RETURNING " + ", ".join(qualified_columns), ()
392
393
def bulk_insert_sql(fields: list[Field], placeholder_rows: list[list[str]]) -> str:
    """Return the VALUES clause for a bulk insert.

    Each inner list of placeholders becomes one parenthesized, comma-separated
    row. `fields` is accepted but not used here.
    """
    row_groups = ["(" + ", ".join(row) + ")" for row in placeholder_rows]
    return "VALUES " + ", ".join(row_groups)
399
400
def regex_lookup(lookup_type: str) -> str:
    """
    Return the template for a regular expression lookup, with '%s'
    placeholders for both operands.

    "regex" gets the case-sensitive operator; any other lookup type
    (i.e. "iregex") gets the case-insensitive one.
    """
    # PostgreSQL uses ~ for regex and ~* for case-insensitive regex
    operator = "~" if lookup_type == "regex" else "~*"
    return f"%s {operator} %s"
410
411
def prep_for_like_query(x: str) -> str:
    """Escape backslash, percent sign and underscore in `x` (converted with
    `str()`) so they are taken literally in a LIKE query."""
    like_escapes = str.maketrans({"\\": "\\\\", "%": r"\%", "_": r"\_"})
    return str(x).translate(like_escapes)
415
416
def adapt_ipaddressfield_value(
    value: str | None,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
    """
    Convert the string form of an IP address into an `ipaddress` object for
    the backend driver. Falsy input (None or an empty string) gives None.
    """
    if not value:
        return None
    return ipaddress.ip_address(value)
427
428
def adapt_json_value(value: Any, encoder: type[json.JSONEncoder] | None) -> Jsonb:
    """Wrap `value` in a `Jsonb` whose serializer comes from
    `get_json_dumps(encoder)`."""
    dumps = get_json_dumps(encoder)
    return Jsonb(value, dumps=dumps)
431
432
def year_lookup_bounds_for_date_field(
    value: int, iso_year: bool = False
) -> list[datetime.date]:
    """
    Return `[lower, upper]` bounds for matching a DateField against a year
    with a BETWEEN operator.

    `value` is the looked-up year as an int. By default the bounds are
    January 1 and December 31 of that year. If `iso_year` is True they cover
    the ISO-8601 week-numbering year: from the first day of its week 1 to the
    day before the next ISO year's week 1.
    """
    if not iso_year:
        return [datetime.date(value, 1, 1), datetime.date(value, 12, 31)]

    lower = datetime.date.fromisocalendar(value, 1, 1)
    next_iso_year_start = datetime.date.fromisocalendar(value + 1, 1, 1)
    return [lower, next_iso_year_start - datetime.timedelta(days=1)]
452
453
def year_lookup_bounds_for_datetime_field(
    value: int, iso_year: bool = False
) -> list[datetime.datetime]:
    """
    Return `[lower, upper]` bounds for matching a DateTimeField against a
    year with a BETWEEN operator.

    `value` is the looked-up year as an int. By default the bounds run from
    midnight on January 1 to the last microsecond of December 31. If
    `iso_year` is True they cover the ISO-8601 week-numbering year instead.
    Both bounds are made aware in the current timezone.
    """
    if iso_year:
        lower = datetime.datetime.fromisocalendar(value, 1, 1)
        next_iso_year_start = datetime.datetime.fromisocalendar(value + 1, 1, 1)
        upper = next_iso_year_start - datetime.timedelta(microseconds=1)
    else:
        lower = datetime.datetime(value, 1, 1)
        upper = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)

    # Make sure that datetimes are aware in the current timezone
    current_tz = timezone.get_current_timezone()
    return [timezone.make_aware(bound, current_tz) for bound in (lower, upper)]
478
479
def combine_expression(connector: str, sub_expressions: list[str]) -> str:
    """
    Join the subexpressions into a single expression, placing the connector
    (with one space on each side) between neighbours.
    """
    return f" {connector} ".join(sub_expressions)
487
488
def subtract_temporals(
    field: Field,
    lhs: tuple[str, list[Any] | tuple[Any, ...]],
    rhs: tuple[str, list[Any] | tuple[Any, ...]],
) -> tuple[str, tuple[Any, ...]]:
    """
    Return the SQL and params for `lhs - rhs` on temporal expressions.

    `lhs` and `rhs` are `(sql, params)` pairs; the params are concatenated
    left then right. For a DateField that is not a DateTimeField the
    difference is multiplied by `interval '1 day'`; otherwise the native
    subtraction is used as-is.
    """
    from plain.postgres.fields import DateField, DateTimeField

    (lhs_sql, lhs_params), (rhs_sql, rhs_params) = lhs, rhs
    combined_params = (*lhs_params, *rhs_params)
    difference = f"({lhs_sql} - {rhs_sql})"

    # DateField (but not DateTimeField) needs interval conversion
    is_plain_date = isinstance(field, DateField) and not isinstance(
        field, DateTimeField
    )
    if is_plain_date:
        difference = f"(interval '1 day' * {difference})"
    return difference, combined_params
504
505
def window_frame_start(start: int | None) -> str:
    """
    Return the SQL for the start of a window frame.

    None maps to UNBOUNDED PRECEDING, zero to CURRENT ROW, and a negative
    integer to `<abs(start)> PRECEDING`. Anything else raises ValueError.
    """
    if start is None:
        return UNBOUNDED_PRECEDING
    if isinstance(start, int):
        if start == 0:
            return CURRENT_ROW
        if start < 0:
            return "%d %s" % (abs(start), PRECEDING)  # noqa: UP031
    raise ValueError(
        f"start argument must be a negative integer, zero, or None, but got '{start}'."
    )
517
518
def window_frame_end(end: int | None) -> str:
    """
    Return the SQL for the end of a window frame.

    None maps to UNBOUNDED FOLLOWING, zero to CURRENT ROW, and a positive
    integer to `<end> FOLLOWING`. Anything else raises ValueError.
    """
    if end is None:
        return UNBOUNDED_FOLLOWING
    if isinstance(end, int):
        if end > 0:
            return "%d %s" % (end, FOLLOWING)  # noqa: UP031
        if end == 0:
            return CURRENT_ROW
    raise ValueError(
        f"end argument must be a positive integer, zero, or None, but got '{end}'."
    )
530
531
def window_frame_rows_start_end(
    start: int | None = None, end: int | None = None
) -> tuple[str, str]:
    """Return the (start, end) SQL for the frame of an OVER clause, as built
    by `window_frame_start` and `window_frame_end`."""
    frame_start = window_frame_start(start)
    frame_end = window_frame_end(end)
    return frame_start, frame_end
537
538
def window_frame_range_start_end(
    start: int | None = None, end: int | None = None
) -> tuple[str, str]:
    """
    Return the (start, end) SQL for a RANGE window frame.

    The bounds are first built (and validated) like ROWS bounds. A negative
    `start` or a positive `end` — a numeric PRECEDING/FOLLOWING offset — then
    raises `psycopg.NotSupportedError`.
    """
    bounds = window_frame_rows_start_end(start, end)
    # PostgreSQL only supports UNBOUNDED with PRECEDING/FOLLOWING
    has_numeric_offset = (start and start < 0) or (end and end > 0)
    if has_numeric_offset:
        raise psycopg.NotSupportedError(
            "PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING."
        )
    return bounds
549
550
def explain_query_prefix(format: str | None = None, **options: Any) -> str:
    """
    Return the `EXPLAIN` prefix for a query, with any options in parentheses.

    `options` are boolean EXPLAIN options given case-insensitively
    (e.g. `analyze=True`); each is rendered as `NAME true` or `NAME false`.
    `format`, when given, must be one of SUPPORTED_EXPLAIN_FORMATS
    (case-insensitive) and is rendered upper-cased as `FORMAT <name>`.

    Options are emitted in alphabetical order, then FORMAT, so the generated
    SQL is identical from run to run.

    Raises ValueError for an unrecognized format or for option names that are
    not in EXPLAIN_OPTIONS.
    """
    extra: dict[str, str] = {}
    # Normalize options: upper-case names, SQL boolean literals as values.
    pending = {
        name.upper(): "true" if value else "false" for name, value in options.items()
    }
    # EXPLAIN_OPTIONS is a frozenset, whose iteration order for strings can
    # differ between interpreter runs; sort for stable output.
    for valid_option in sorted(EXPLAIN_OPTIONS):
        if valid_option in pending:
            extra[valid_option] = pending.pop(valid_option)
    if format:
        normalized_format = format.upper()
        if normalized_format not in SUPPORTED_EXPLAIN_FORMATS:
            msg = "{} is not a recognized format. Allowed formats: {}".format(
                normalized_format, ", ".join(sorted(SUPPORTED_EXPLAIN_FORMATS))
            )
            raise ValueError(msg)
        # Emit the value that was validated, not the caller's raw spelling.
        extra["FORMAT"] = normalized_format
    # Whatever is left was not a recognized EXPLAIN option.
    if pending:
        raise ValueError("Unknown options: {}".format(", ".join(sorted(pending))))
    prefix = EXPLAIN_PREFIX
    if extra:
        rendered = ", ".join(f"{name} {value}" for name, value in extra.items())
        prefix += f" ({rendered})"
    return prefix
579
580
def on_conflict_suffix_sql(
    fields: list[Field],
    on_conflict: OnConflict | None,
    update_fields: Iterable[str],
    unique_fields: Iterable[str],
) -> str:
    """
    Return the ON CONFLICT suffix for an INSERT.

    `OnConflict.IGNORE` gives `ON CONFLICT DO NOTHING`. `OnConflict.UPDATE`
    gives an upsert keyed on the quoted `unique_fields` that sets each quoted
    name in `update_fields` from `EXCLUDED`. Any other value gives an empty
    string. `fields` is accepted but not used here.
    """
    if on_conflict == OnConflict.IGNORE:
        return "ON CONFLICT DO NOTHING"
    if on_conflict != OnConflict.UPDATE:
        return ""

    conflict_target = ", ".join(quote_name(name) for name in unique_fields)
    assignments = ", ".join(
        f"{column} = EXCLUDED.{column}" for column in map(quote_name, update_fields)
    )
    return f"ON CONFLICT({conflict_target}) DO UPDATE SET {assignments}"