v0.145.0
  1from __future__ import annotations
  2
  3import functools
  4import time
  5from collections.abc import Generator, Iterator, Mapping, Sequence
  6from contextlib import contextmanager
  7from hashlib import md5
  8from types import TracebackType
  9from typing import TYPE_CHECKING, Any, Self
 10
 11import psycopg
 12
 13from plain.logs import get_framework_logger
 14from plain.postgres.otel import db_span
 15from plain.utils.dateparse import parse_time
 16
 17if TYPE_CHECKING:
 18    from plain.postgres.connection import DatabaseConnection
 19
# Framework logger used by CursorDebugWrapper.debug_sql and debug_transaction
# to emit per-statement debug records.
logger = get_framework_logger()
 21
 22
 23def make_model_tuple(model: Any) -> tuple[str, str]:
 24    """
 25    Take a model or a string of the form "package_label.ModelName" and return a
 26    corresponding ("package_label", "modelname") tuple. If a tuple is passed in,
 27    assume it's a valid model tuple already and return it unchanged.
 28    """
 29    try:
 30        if isinstance(model, tuple):
 31            model_tuple = model
 32        elif isinstance(model, str):
 33            package_label, model_name = model.split(".")
 34            model_tuple = package_label, model_name.lower()
 35        else:
 36            model_tuple = (
 37                model.model_options.package_label,
 38                model.model_options.model_name,
 39            )
 40        assert len(model_tuple) == 2
 41        return model_tuple
 42    except (ValueError, AssertionError):
 43        raise ValueError(
 44            f"Invalid model reference '{model}'. String model references "
 45            "must be of the form 'package_label.ModelName'."
 46        )
 47
 48
 49def resolve_callables(
 50    mapping: dict[str, Any],
 51) -> Generator[tuple[str, Any]]:
 52    """
 53    Generate key/value pairs for the given mapping where the values are
 54    evaluated if they're callable.
 55    """
 56    for k, v in mapping.items():
 57        yield k, v() if callable(v) else v
 58
 59
 60class CursorWrapper:
 61    def __init__(self, cursor: Any, db: DatabaseConnection) -> None:
 62        self.cursor = cursor
 63        self.db = db
 64
 65    def __getattr__(self, attr: str) -> Any:
 66        return getattr(self.cursor, attr)
 67
 68    def __iter__(self) -> Iterator[tuple[Any, ...]]:
 69        yield from self.cursor
 70
 71    def fetchone(self) -> tuple[Any, ...] | None:
 72        return self.cursor.fetchone()
 73
 74    def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
 75        if size is None:
 76            return self.cursor.fetchmany()
 77        return self.cursor.fetchmany(size)
 78
 79    def fetchall(self) -> list[tuple[Any, ...]]:
 80        return self.cursor.fetchall()
 81
 82    def __enter__(self) -> Self:
 83        return self
 84
 85    def __exit__(
 86        self,
 87        type: type[BaseException] | None,
 88        value: BaseException | None,
 89        traceback: TracebackType | None,
 90    ) -> None:
 91        # Close instead of passing through to avoid backend-specific behavior
 92        # (#17671). Catch errors liberally because errors in cleanup code
 93        # aren't useful.
 94        try:
 95            self.close()
 96        except psycopg.Error:
 97            pass
 98
 99    def stream(
100        self, sql: str, params: Sequence[Any] | None = None
101    ) -> Generator[tuple[Any, ...]]:
102        self.db.validate_no_broken_transaction()
103        # psycopg's server-side cursor leaves rowcount at -1, so count rows as
104        # they're yielded and feed db_span via the closure.
105        count = 0
106        with db_span(self.db, sql, params=params, row_count_provider=lambda: count):
107            try:
108                iterator = (
109                    self.cursor.stream(sql)
110                    if params is None
111                    else self.cursor.stream(sql, params)
112                )
113                for row in iterator:
114                    count += 1
115                    yield row
116            finally:
117                try:
118                    self.close()
119                except psycopg.Error:
120                    pass
121
122    # execute() and executemany() cannot be implemented in __getattr__ because
123    # the code must run when the method is invoked, not just when it is accessed.
124
125    def execute(
126        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
127    ) -> Self:
128        return self._execute_with_wrappers(
129            sql, params, many=False, executor=self._execute
130        )
131
132    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
133        return self._execute_with_wrappers(
134            sql, param_list, many=True, executor=self._executemany
135        )
136
137    def _execute_with_wrappers(
138        self, sql: str, params: Any, many: bool, executor: Any
139    ) -> Self:
140        context: dict[str, Any] = {"connection": self.db, "cursor": self}
141        for wrapper in reversed(self.db.execute_wrappers):
142            executor = functools.partial(wrapper, executor)
143        executor(sql, params, many, context)
144        return self
145
146    def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
147        with db_span(
148            self.db, sql, params=params, row_count_provider=lambda: self.cursor.rowcount
149        ):
150            self.db.validate_no_broken_transaction()
151            if params is None:
152                self.cursor.execute(sql)
153            else:
154                self.cursor.execute(sql, params)
155
156    def _executemany(
157        self, sql: str, param_list: Any, *ignored_wrapper_args: Any
158    ) -> None:
159        with db_span(
160            self.db,
161            sql,
162            many=True,
163            params=param_list,
164            row_count_provider=lambda: self.cursor.rowcount,
165        ):
166            self.db.validate_no_broken_transaction()
167            self.cursor.executemany(sql, param_list)
168
169
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that additionally records every statement it runs.

    Each statement appends ``{"sql": ..., "time": ...}`` to
    ``self.db.queries_log`` and emits a "Query executed" debug log record.
    """

    def stream(
        self, sql: str, params: Sequence[Any] | None = None
    ) -> Generator[tuple[Any, ...]]:
        # The timer wraps the whole iteration, so it also covers time the
        # consumer spends between rows.
        with self.debug_sql(sql, params, use_last_executed_query=True):
            yield from super().stream(sql, params)

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        with self.debug_sql(sql, params, use_last_executed_query=True):
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        with self.debug_sql(sql, param_list, many=True):
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None]:
        """
        Time the enclosed block and log the statement afterwards.

        Logging happens even if the block raises. With
        ``use_last_executed_query`` the logged SQL is taken from
        ``self.db.last_executed_query``. With ``many`` the log entry is
        prefixed by the parameter-set count, or ``"?"`` when ``params`` has
        no ``len()``.
        """
        started_at = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - started_at
            if use_last_executed_query:
                sql = self.db.last_executed_query(self.cursor, sql, params)  # ty: ignore[invalid-argument-type]
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                logged_sql = f"{times} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append({"sql": logged_sql, "time": f"{elapsed:.3f}"})
            logger.debug(
                "Query executed",
                extra={
                    "duration": round(elapsed, 3),
                    "sql": sql,
                    "params": params,
                },
            )
223            )
224
225
@contextmanager
def debug_transaction(connection: DatabaseConnection, sql: str) -> Generator[None]:
    """
    Time a transaction-control statement and record it if query logging is on.

    The enclosed block always runs. Afterwards, even if the block raised, the
    elapsed time is recorded, but only when ``connection.queries_logged`` is
    truthy. It is appended to ``connection.queries_log`` and emitted as a
    "Transaction command" debug log record.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            entry = {"sql": f"{sql}", "time": f"{elapsed:.3f}"}
            connection.queries_log.append(entry)
            logger.debug(
                "Transaction command",
                extra={"duration": round(elapsed, 3), "sql": sql},
            )
247            )
248
249
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Break a time zone name into ``(name, sign, offset)``.

    ``+`` is tried before ``-``. For each sign present, the text after its
    last occurrence counts as an offset only if it is non-empty and
    ``parse_time`` accepts it. If no sign yields an offset, the result is
    ``(tzname, None, None)``.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, _, offset = tzname.rpartition(sign)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
260
261
262###############################################
263# Converters from Python to database (string) #
264###############################################
265
266
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into a ``(namespace, name)`` pair.

    The identifier (a table, column, or sequence name) may be prefixed by a
    namespace in the form ``"namespace"."name"``. The split only happens when
    the ``"."`` separator occurs exactly once; otherwise the namespace is
    empty. Surrounding double quotes are stripped from both parts.
    """
    separator = '"."'
    if identifier.count(separator) == 1:
        namespace, _, name = identifier.partition(separator)
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
279
280
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable, mangled form of ``length`` chars.

    Only the name part is measured and shortened; a namespace (as in
    ``USERNAME"."TABLE``) is kept in front. When no ``length`` is given, or
    the name already fits, ``identifier`` is returned unchanged. Otherwise the
    name is cut to ``length - hash_len`` characters and a ``hash_len``-char
    digest of the full name is appended.
    """
    namespace, name = split_identifier(identifier)

    if length is not None and len(name) > length:
        prefix = f'{namespace}"."' if namespace else ""
        return prefix + name[: length - hash_len] + names_digest(name, length=hash_len)

    return identifier
300
301
def names_digest(*args: str, length: int) -> str:
    """
    Return the first ``length`` hex characters of the MD5 digest of ``args``.

    The arguments are encoded with ``str.encode()`` (UTF-8) and hashed as one
    concatenated byte string. MD5 is requested with
    ``usedforsecurity=False``, because the digest only serves to build short,
    deterministic name fragments. ``length=8`` yields 32 bits of the digest.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
311
312
def generate_identifier_name(
    table_name: str, column_names: list[str], suffix: str = ""
) -> str:
    """
    Build a deterministic name for an index or constraint.

    The name is ``<table>_<columns>_<digest><suffix>``, where the digest is an
    8-char hash of the table and column names. If that exceeds
    ``MAX_NAME_LENGTH``, the digest+suffix part is capped at a third of the
    limit. The table and column parts then share the remaining budget
    equally. A shortened name that would start with ``_`` or a digit gets a
    leading ``D``, and its last character is dropped to keep the length.
    """
    from .dialect import MAX_NAME_LENGTH

    _, table = split_identifier(table_name)
    joined_columns = "_".join(column_names)
    tail = f"{names_digest(table, *column_names, length=8)}{suffix}"

    full_name = f"{table}_{joined_columns}_{tail}"
    if len(full_name) <= MAX_NAME_LENGTH:
        return full_name

    if len(tail) > MAX_NAME_LENGTH / 3:
        tail = tail[: MAX_NAME_LENGTH // 3]
    share = (MAX_NAME_LENGTH - len(tail)) // 2 - 1
    shortened = f"{table[:share]}_{joined_columns[:share]}_{tail}"

    if shortened.startswith("_") or shortened[0].isdigit():
        shortened = f"D{shortened[:-1]}"
    return shortened
336
337
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of surrounding double quotes from a table name.

    The result is suitable for embedding in derived names such as index or
    sequence names. The name is returned unchanged unless it both starts and
    ends with ``"``.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name
344    return table_name[1:-1] if has_quotes else table_name