1from __future__ import annotations
2
3import functools
4import time
5from collections.abc import Generator, Iterator, Mapping, Sequence
6from contextlib import contextmanager
7from hashlib import md5
8from types import TracebackType
9from typing import TYPE_CHECKING, Any, Self
10
11import psycopg
12
13from plain.logs import get_framework_logger
14from plain.postgres.otel import db_span
15from plain.utils.dateparse import parse_time
16
17if TYPE_CHECKING:
18 from plain.postgres.connection import DatabaseConnection
19
# Module-level framework logger; the debug helpers below emit their
# "Query executed" / "Transaction command" records through it.
logger = get_framework_logger()
21
22
def make_model_tuple(model: Any) -> tuple[str, str]:
    """
    Take a model or a string of the form "package_label.ModelName" and return a
    corresponding ("package_label", "modelname") tuple. If a tuple is passed in,
    assume it's a valid model tuple already and return it unchanged.

    Args:
        model: a model class/instance exposing ``model_options``, a
            "package_label.ModelName" string, or an existing 2-tuple.

    Returns:
        A ``(package_label, model_name)`` tuple. String references have their
        model name lower-cased.

    Raises:
        ValueError: if a string reference does not contain exactly one ".",
            or if the resolved value is not a 2-tuple.
    """
    try:
        if isinstance(model, tuple):
            model_tuple = model
        elif isinstance(model, str):
            # Unpacking raises ValueError unless there is exactly one ".".
            package_label, model_name = model.split(".")
            model_tuple = package_label, model_name.lower()
        else:
            model_tuple = (
                model.model_options.package_label,
                model.model_options.model_name,
            )
        # Explicit check rather than `assert`, which is stripped under -O and
        # would let malformed tuples through silently.
        if len(model_tuple) != 2:
            raise ValueError("model tuple must have exactly two elements")
    except ValueError as err:
        raise ValueError(
            f"Invalid model reference '{model}'. String model references "
            "must be of the form 'package_label.ModelName'."
        ) from err
    return model_tuple
47
48
def resolve_callables(
    mapping: dict[str, Any],
) -> Generator[tuple[str, Any]]:
    """
    Lazily yield ``(key, value)`` pairs from ``mapping``.

    Any value that is callable is invoked with no arguments and its result is
    yielded in its place; other values are yielded as-is. Calls happen only as
    the generator is consumed.
    """
    for key, value in mapping.items():
        if callable(value):
            value = value()
        yield key, value
58
59
60class CursorWrapper:
61 def __init__(self, cursor: Any, db: DatabaseConnection) -> None:
62 self.cursor = cursor
63 self.db = db
64
65 def __getattr__(self, attr: str) -> Any:
66 return getattr(self.cursor, attr)
67
68 def __iter__(self) -> Iterator[tuple[Any, ...]]:
69 yield from self.cursor
70
71 def fetchone(self) -> tuple[Any, ...] | None:
72 return self.cursor.fetchone()
73
74 def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
75 if size is None:
76 return self.cursor.fetchmany()
77 return self.cursor.fetchmany(size)
78
79 def fetchall(self) -> list[tuple[Any, ...]]:
80 return self.cursor.fetchall()
81
82 def __enter__(self) -> Self:
83 return self
84
85 def __exit__(
86 self,
87 type: type[BaseException] | None,
88 value: BaseException | None,
89 traceback: TracebackType | None,
90 ) -> None:
91 # Close instead of passing through to avoid backend-specific behavior
92 # (#17671). Catch errors liberally because errors in cleanup code
93 # aren't useful.
94 try:
95 self.close()
96 except psycopg.Error:
97 pass
98
99 def stream(
100 self, sql: str, params: Sequence[Any] | None = None
101 ) -> Generator[tuple[Any, ...]]:
102 self.db.validate_no_broken_transaction()
103 # psycopg's server-side cursor leaves rowcount at -1, so count rows as
104 # they're yielded and feed db_span via the closure.
105 count = 0
106 with db_span(self.db, sql, params=params, row_count_provider=lambda: count):
107 try:
108 iterator = (
109 self.cursor.stream(sql)
110 if params is None
111 else self.cursor.stream(sql, params)
112 )
113 for row in iterator:
114 count += 1
115 yield row
116 finally:
117 try:
118 self.close()
119 except psycopg.Error:
120 pass
121
122 # execute() and executemany() cannot be implemented in __getattr__ because
123 # the code must run when the method is invoked, not just when it is accessed.
124
125 def execute(
126 self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
127 ) -> Self:
128 return self._execute_with_wrappers(
129 sql, params, many=False, executor=self._execute
130 )
131
132 def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
133 return self._execute_with_wrappers(
134 sql, param_list, many=True, executor=self._executemany
135 )
136
137 def _execute_with_wrappers(
138 self, sql: str, params: Any, many: bool, executor: Any
139 ) -> Self:
140 context: dict[str, Any] = {"connection": self.db, "cursor": self}
141 for wrapper in reversed(self.db.execute_wrappers):
142 executor = functools.partial(wrapper, executor)
143 executor(sql, params, many, context)
144 return self
145
146 def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
147 with db_span(
148 self.db, sql, params=params, row_count_provider=lambda: self.cursor.rowcount
149 ):
150 self.db.validate_no_broken_transaction()
151 if params is None:
152 self.cursor.execute(sql)
153 else:
154 self.cursor.execute(sql, params)
155
156 def _executemany(
157 self, sql: str, param_list: Any, *ignored_wrapper_args: Any
158 ) -> None:
159 with db_span(
160 self.db,
161 sql,
162 many=True,
163 params=param_list,
164 row_count_provider=lambda: self.cursor.rowcount,
165 ):
166 self.db.validate_no_broken_transaction()
167 self.cursor.executemany(sql, param_list)
168
169
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that additionally records every statement it runs.

    Each statement's SQL and elapsed wall-clock time are appended to
    ``db.queries_log`` and emitted as a "Query executed" debug log record.
    """

    def stream(
        self, sql: str, params: Sequence[Any] | None = None
    ) -> Generator[tuple[Any, ...]]:
        # The timing context spans the whole iteration, so the recorded
        # duration includes time the consumer spends between rows.
        with self.debug_sql(sql, params, use_last_executed_query=True):
            yield from super().stream(sql, params)

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        with self.debug_sql(sql, params, use_last_executed_query=True):
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        with self.debug_sql(sql, param_list, many=True):
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None]:
        """
        Time the enclosed block and log the statement afterwards.

        The entry is recorded even if the block raises.

        Args:
            sql: the statement text as passed by the caller.
            params: the bound parameters (a parameter list when ``many``).
            use_last_executed_query: when true, replace ``sql`` with the
                connection's ``last_executed_query`` rendering for this cursor.
            many: when true, prefix the logged SQL with the number of
                parameter sets ("?" if ``params`` has no length).
        """
        started = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - started
            if use_last_executed_query:
                sql = self.db.last_executed_query(self.cursor, sql, params)  # ty: ignore[invalid-argument-type]
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                logged_sql = f"{times} times: {sql}"
            else:
                logged_sql = sql
            entry = {
                "sql": logged_sql,
                "time": f"{elapsed:.3f}",
            }
            self.db.queries_log.append(entry)
            details = {
                "duration": round(elapsed, 3),
                "sql": sql,
                "params": params,
            }
            logger.debug("Query executed", extra=details)
224
225
@contextmanager
def debug_transaction(connection: DatabaseConnection, sql: str) -> Generator[None]:
    """
    Time a transaction-control command run inside the ``with`` block.

    When ``connection.queries_logged`` is true, the command text and elapsed
    time are appended to ``connection.queries_log`` and a "Transaction command"
    debug record is logged. This happens even if the block raises; the
    exception itself still propagates.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        # Deliberately an `if` rather than an early `return`: returning from a
        # generator's `finally` would swallow any in-flight exception.
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append({"sql": f"{sql}", "time": f"{elapsed:.3f}"})
            logger.debug(
                "Transaction command",
                extra={"duration": round(elapsed, 3), "sql": sql},
            )
248
249
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Break a time zone name into ``(name, sign, offset)``.

    "+" is tried before "-", and the name is split at the last occurrence of
    that sign. The split is accepted only when the trailing part is non-empty
    and ``parse_time`` accepts it. If neither sign yields such an offset,
    ``(tzname, None, None)`` is returned.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, offset = tzname.rsplit(sign, 1)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
260
261
262###############################################
263# Converters from Python to database (string) #
264###############################################
265
266
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into a ``(namespace, name)`` pair.

    The identifier may be a table, column, or sequence name, optionally
    prefixed by a namespace, as in ``"schema"."table"``. Surrounding double
    quotes are stripped from both parts. If the identifier does not contain
    exactly one ``"."`` separator, the namespace is empty and the whole
    identifier is treated as the name.
    """
    separator = '"."'
    if identifier.count(separator) == 1:
        namespace, _, name = identifier.partition(separator)
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
279
280
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable mangled version of ``length``.

    If a quote-stripped name contains a namespace, e.g. USERNAME"."TABLE, only
    the table portion is truncated. Identifiers whose name already fits (or
    when ``length`` is None) are returned unchanged. Otherwise the name keeps
    its first ``length - hash_len`` characters, followed by a ``hash_len``-char
    digest of the full name. The shortened result is rebuilt from the
    quote-stripped parts, so it carries no outer quotes.
    """
    namespace, name = split_identifier(identifier)

    needs_truncation = length is not None and len(name) > length
    if not needs_truncation:
        return identifier

    prefix = f'{namespace}"."' if namespace else ""
    kept = name[: length - hash_len]
    return prefix + kept + names_digest(name, length=hash_len)
300
301
def names_digest(*args: str, length: int) -> str:
    """
    Return the first ``length`` hex characters of an MD5 digest of ``args``.

    The arguments are encoded with ``str.encode()`` (UTF-8) and hashed as one
    concatenated byte string, so the result is deterministic across runs. Each
    hex character carries 4 bits; for example, ``length=8`` yields a 32-bit
    digest. MD5 is used only to shorten names, not for security.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
311
312
def generate_identifier_name(
    table_name: str, column_names: list[str], suffix: str = ""
) -> str:
    """
    Build a deterministic name for an index or constraint.

    The name has the form ``<table>_<col1>_<col2>..._<hash><suffix>``. The hash
    is an 8-char digest of the table and column names. If that exceeds
    ``MAX_NAME_LENGTH``, the hash+suffix part is capped at a third of the
    budget, and the table and joined-column parts are each cut to an equal
    share of the remainder. A truncated name that would start with "_" or a
    digit is prefixed with "D", dropping its last character to keep the length.
    """
    from .dialect import MAX_NAME_LENGTH

    _, bare_table = split_identifier(table_name)
    tail = f"{names_digest(bare_table, *column_names, length=8)}{suffix}"
    limit = MAX_NAME_LENGTH
    joined_columns = "_".join(column_names)

    candidate = f"{bare_table}_{joined_columns}_{tail}"
    if len(candidate) <= limit:
        return candidate

    # Don't let a long suffix eat more than a third of the available length.
    if len(tail) > limit / 3:
        tail = tail[: limit // 3]
    # Split what's left evenly; the -1 accounts for a joining underscore.
    part_length = (limit - len(tail)) // 2 - 1
    candidate = (
        f"{bare_table[:part_length]}_{joined_columns[:part_length]}_{tail}"
    )
    starts_badly = candidate[0] == "_" or candidate[0].isdigit()
    if starts_badly:
        candidate = f"D{candidate[:-1]}"
    return candidate
336
337
def strip_quotes(table_name: str) -> str:
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc.

    Exactly one pair of surrounding double quotes is removed. Names that are
    not wrapped in quotes are returned unchanged. This includes a lone '"',
    which is a single character rather than a quoted (empty) name.
    """
    # Require at least two characters: for '"' alone, startswith and endswith
    # both match the same character, and slicing would wrongly return "".
    has_quotes = (
        len(table_name) >= 2
        and table_name.startswith('"')
        and table_name.endswith('"')
    )
    return table_name[1:-1] if has_quotes else table_name