1from __future__ import annotations
2
import importlib.metadata
import re
import sys
import time
import traceback
import weakref
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
11
12from opentelemetry import context as otel_context
13from opentelemetry import metrics, trace
14from opentelemetry.metrics import CallbackOptions, Observation
15from opentelemetry.semconv.metrics.db_metrics import DB_CLIENT_OPERATION_DURATION
16
17if TYPE_CHECKING:
18 from opentelemetry.trace import Span
19 from psycopg import Connection as PsycopgConnection
20
21 from plain.postgres.connection import DatabaseConnection
22 from plain.postgres.sources import PoolSource
23
24from opentelemetry.semconv._incubating.attributes.db_attributes import (
25 DB_CLIENT_CONNECTION_POOL_NAME,
26 DB_CLIENT_CONNECTION_STATE,
27 DB_QUERY_PARAMETER_TEMPLATE,
28 DbClientConnectionStateValues,
29)
30from opentelemetry.semconv._incubating.metrics.db_metrics import (
31 DB_CLIENT_CONNECTION_COUNT,
32 DB_CLIENT_CONNECTION_IDLE_MAX,
33 DB_CLIENT_CONNECTION_IDLE_MIN,
34 DB_CLIENT_CONNECTION_MAX,
35 DB_CLIENT_CONNECTION_PENDING_REQUESTS,
36 DB_CLIENT_CONNECTION_TIMEOUTS,
37 DB_CLIENT_CONNECTION_USE_TIME,
38 DB_CLIENT_CONNECTION_WAIT_TIME,
39 DB_CLIENT_RESPONSE_RETURNED_ROWS,
40)
41from opentelemetry.semconv.attributes.code_attributes import (
42 CODE_COLUMN_NUMBER,
43 CODE_FILE_PATH,
44 CODE_FUNCTION_NAME,
45 CODE_LINE_NUMBER,
46 CODE_STACKTRACE,
47)
48from opentelemetry.semconv.attributes.db_attributes import (
49 DB_COLLECTION_NAME,
50 DB_NAMESPACE,
51 DB_OPERATION_NAME,
52 DB_QUERY_SUMMARY,
53 DB_QUERY_TEXT,
54 DB_SYSTEM_NAME,
55 DbSystemNameValues,
56)
57from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
58from opentelemetry.semconv.attributes.network_attributes import (
59 NETWORK_PEER_ADDRESS,
60 NETWORK_PEER_PORT,
61)
62from opentelemetry.semconv.attributes.server_attributes import (
63 SERVER_ADDRESS,
64 SERVER_PORT,
65)
66from opentelemetry.trace import SpanKind
67
68from plain.runtime import settings
69from plain.utils.otel import format_exception_type
70
# Context key read by `db_span`: while its value is truthy, queries are not
# traced. A plain string is used so OpenTelemetry's context API receives the
# key type it expects.
_SUPPRESS_KEY = "plain.postgres.suppress_db_tracing"

# Instrumentation scope name shared by the tracer, the meter, and the
# package-version lookup.
_INSTRUMENTATION_SCOPE = "plain.postgres"

try:
    _package_version = importlib.metadata.version(_INSTRUMENTATION_SCOPE)
except importlib.metadata.PackageNotFoundError:
    # No installed distribution metadata to read a version from.
    _package_version = "dev"

tracer = trace.get_tracer(_INSTRUMENTATION_SCOPE, _package_version)
meter = metrics.get_meter(_INSTRUMENTATION_SCOPE, version=_package_version)

# Query-level instruments.
query_duration_histogram = meter.create_histogram(
    name=DB_CLIENT_OPERATION_DURATION,
    description="Duration of database client operations.",
    unit="s",
)
returned_rows_histogram = meter.create_histogram(
    name=DB_CLIENT_RESPONSE_RETURNED_ROWS,
    description="Number of rows returned by the operation.",
    unit="{row}",
)

# Connection-pool instruments, fed by the record_connection_* helpers.
connection_wait_time_histogram = meter.create_histogram(
    name=DB_CLIENT_CONNECTION_WAIT_TIME,
    description="The time it took to obtain an open connection from the pool.",
    unit="s",
)
connection_use_time_histogram = meter.create_histogram(
    name=DB_CLIENT_CONNECTION_USE_TIME,
    description="The time between borrowing a connection and returning it to the pool.",
    unit="s",
)
connection_timeouts_counter = meter.create_counter(
    name=DB_CLIENT_CONNECTION_TIMEOUTS,
    description="The number of connection timeouts that have occurred trying to obtain a connection from the pool.",
    unit="{timeout}",
)

# Checkout timestamp of each borrowed connection, popped on release to compute
# use time. Weak keys keep a connection that is garbage-collected without an
# explicit release() from leaking an entry.
_use_start: weakref.WeakKeyDictionary[PsycopgConnection[Any], float] = (
    weakref.WeakKeyDictionary()
)

# Value for the `db.system.name` attribute on every span and metric.
DB_SYSTEM = DbSystemNameValues.POSTGRESQL.value
114
115
def record_connection_acquire(
    pool_name: str,
    conn: PsycopgConnection[Any],
    wait_seconds: float,
    checkout_time: float,
) -> None:
    """Record how long a pool checkout waited and remember when `conn` was borrowed.

    The stored `checkout_time` is later consumed by `record_connection_release`
    to report how long the connection was in use.
    """
    attributes = {DB_CLIENT_CONNECTION_POOL_NAME: pool_name}
    connection_wait_time_histogram.record(wait_seconds, attributes)
    _use_start[conn] = checkout_time
126
127
def record_connection_release(
    pool_name: str, conn: PsycopgConnection[Any], return_time: float
) -> None:
    """Record how long `conn` was held, measured from its stored checkout time.

    The stored checkout time is removed so it is only counted once. If none
    was stored for `conn` (it never went through `record_connection_acquire`),
    nothing is recorded.
    """
    checkout_time = _use_start.pop(conn, None)
    if checkout_time is not None:
        connection_use_time_histogram.record(
            return_time - checkout_time,
            {DB_CLIENT_CONNECTION_POOL_NAME: pool_name},
        )
137
138
def record_connection_timeout(pool_name: str) -> None:
    """Count one timeout while waiting for a connection from `pool_name`."""
    attributes = {DB_CLIENT_CONNECTION_POOL_NAME: pool_name}
    connection_timeouts_counter.add(1, attributes)
141
142
# Pools reported by the connection-pool instruments, keyed by pool name.
# The instruments are created once and their callbacks report every pool in
# this registry. Re-registering an instrument name with the OTel SDK returns
# the existing instrument and discards the new callbacks, so creating the
# instruments per pool would only ever observe the first pool.
_observed_pools: dict[str, PoolSource] = {}
_pool_observables_registered = False


def _pool_stats() -> list[tuple[str, Any]]:
    """Snapshot `(pool name, stats)` for every observed pool whose stats are available."""
    snapshot: list[tuple[str, Any]] = []
    # Iterate a copy so a concurrent registration cannot break collection.
    for pool_name, source in list(_observed_pools.items()):
        stats = source.get_stats()
        if stats is not None:
            snapshot.append((pool_name, stats))
    return snapshot


def _observe_connection_count(_options: CallbackOptions) -> list[Observation]:
    """Report used and idle connection counts for each observed pool."""
    used_state = DbClientConnectionStateValues.USED.value
    idle_state = DbClientConnectionStateValues.IDLE.value
    observations: list[Observation] = []
    for pool_name, stats in _pool_stats():
        size = stats.get("pool_size", 0)
        available = stats.get("pool_available", 0)
        observations.append(
            Observation(
                max(size - available, 0),
                {
                    DB_CLIENT_CONNECTION_POOL_NAME: pool_name,
                    DB_CLIENT_CONNECTION_STATE: used_state,
                },
            )
        )
        observations.append(
            Observation(
                available,
                {
                    DB_CLIENT_CONNECTION_POOL_NAME: pool_name,
                    DB_CLIENT_CONNECTION_STATE: idle_state,
                },
            )
        )
    return observations


def _observe_stat(stats_key: str) -> Callable[[CallbackOptions], list[Observation]]:
    """Build a callback reporting `stats[stats_key]` (default 0) for each observed pool."""

    def callback(_options: CallbackOptions) -> list[Observation]:
        return [
            Observation(
                stats.get(stats_key, 0), {DB_CLIENT_CONNECTION_POOL_NAME: pool_name}
            )
            for pool_name, stats in _pool_stats()
        ]

    return callback


def register_pool_observables(pool_source: PoolSource) -> None:
    """Report `pool_source.get_stats()` through observable instruments at collection time.

    The `db.client.connection.*` instruments are created on the first call
    only. Every call adds `pool_source` to the set of pools the callbacks
    read, replacing any earlier source registered under the same pool name.
    Safe to call repeatedly and for multiple pools.
    """
    global _pool_observables_registered

    _observed_pools[pool_source.name] = pool_source
    if _pool_observables_registered:
        return
    _pool_observables_registered = True

    meter.create_observable_up_down_counter(
        name=DB_CLIENT_CONNECTION_COUNT,
        unit="{connection}",
        description="The number of connections that are currently in state described by the state attribute.",
        callbacks=[_observe_connection_count],
    )
    for name, unit, description, stats_key in (
        (
            DB_CLIENT_CONNECTION_MAX,
            "{connection}",
            "The maximum number of open connections allowed.",
            "pool_max",
        ),
        (
            DB_CLIENT_CONNECTION_IDLE_MIN,
            "{connection}",
            "The minimum number of idle open connections allowed.",
            "pool_min",
        ),
        (
            # NOTE(review): reuses pool_max, presumably because the pool has
            # no separate idle cap — confirm against the pool implementation.
            DB_CLIENT_CONNECTION_IDLE_MAX,
            "{connection}",
            "The maximum number of idle open connections allowed.",
            "pool_max",
        ),
        (
            DB_CLIENT_CONNECTION_PENDING_REQUESTS,
            "{request}",
            "The number of current pending requests for an open connection.",
            "requests_waiting",
        ),
    ):
        meter.create_observable_up_down_counter(
            name=name,
            unit=unit,
            description=description,
            callbacks=[_observe_stat(stats_key)],
        )
217
218
219def extract_operation_and_target(sql: str) -> tuple[str, str | None, str | None]:
220 """Extract operation, table name, and collection from SQL.
221
222 Returns: (operation, summary, collection_name)
223 """
224 sql_upper = sql.upper().strip()
225
226 # Strip leading parentheses (e.g. UNION queries: "(SELECT ... UNION ...)")
227 operation = sql_upper.lstrip("(").split()[0] if sql_upper else "UNKNOWN"
228
229 # Pattern to match quoted and unquoted identifiers
230 # Matches: "quoted" (PostgreSQL), unquoted.name
231 identifier_pattern = r'("([^"]+)"|([\w.]+))'
232
233 # Map operations to the SQL keyword that precedes the table name.
234 keyword_by_operation = {
235 "SELECT": "FROM",
236 "DELETE": "FROM",
237 "INSERT": "INTO",
238 "UPDATE": "UPDATE",
239 }
240
241 # Extract table/collection name based on operation
242 collection_name = None
243 summary = operation
244
245 keyword = keyword_by_operation.get(operation)
246 if keyword:
247 match = re.search(rf"{keyword}\s+{identifier_pattern}", sql, re.IGNORECASE)
248 if match:
249 collection_name = _clean_identifier(match.group(1))
250 summary = f"{operation} {collection_name}"
251
252 # Detect UNION queries
253 if " UNION " in sql_upper and summary:
254 summary = f"{summary} UNION"
255
256 return operation, summary, collection_name
257
258
def _clean_identifier(identifier: str) -> str:
    """Remove the surrounding double quotes from a quoted SQL identifier.

    Unquoted identifiers are returned unchanged. A lone ``"`` is not a quoted
    identifier (its opening and closing quote would be the same character),
    so it is returned as-is rather than collapsed to an empty string.
    """
    if len(identifier) >= 2 and identifier[0] == identifier[-1] == '"':
        return identifier[1:-1]
    return identifier
264
265
def _endpoint_attributes(cfg: Any) -> dict[str, Any]:
    """Return `server.*` and `network.peer.*` attributes from the HOST/PORT settings.

    `server.*` is the primary pair per current semconv; `network.peer.*` is
    recommended supplementary. A PORT that cannot be converted to int is
    omitted.
    """
    attrs: dict[str, Any] = {}
    if host := cfg.get("HOST"):
        attrs[SERVER_ADDRESS] = host
        attrs[NETWORK_PEER_ADDRESS] = host
    if port := cfg.get("PORT"):
        try:
            port_int = int(port)
        except (TypeError, ValueError):
            return attrs
        attrs[SERVER_PORT] = port_int
        attrs[NETWORK_PEER_PORT] = port_int
    return attrs


def _query_parameter_attributes(params: Any) -> dict[str, str]:
    """Return `db.query.parameter.<key>` attributes for a single query's params.

    Dict params are keyed by name. Positional params (list/tuple) and a lone
    scalar are keyed by their 0-based index, as the semantic conventions
    specify for parameters referenced only by index.
    """
    items: list[tuple[Any, Any]]
    if isinstance(params, dict):
        items = list(params.items())
    elif isinstance(params, list | tuple):
        items = list(enumerate(params))
    else:
        items = [(0, params)]
    return {f"{DB_QUERY_PARAMETER_TEMPLATE}.{key}": str(value) for key, value in items}


@contextmanager
def db_span(
    db: DatabaseConnection,
    sql: Any,
    *,
    many: bool = False,
    params: Any = None,
    row_count_provider: Callable[[], int] | None = None,
) -> Generator[Span | None]:
    """Open an OpenTelemetry CLIENT span for a database query.

    All common attributes (`db.*`, `network.*`, `server.*`, `code.*`) are set
    automatically, following the OpenTelemetry semantic conventions for
    database instrumentation. If tracing is suppressed via
    `suppress_db_tracing`, no span is created and ``None`` is yielded.

    The duration is recorded in `db.client.operation.duration` for both
    successful and failed operations. Failures are tagged with `error.type`
    on the span and the metric, and the exception is re-raised.

    In DEBUG mode, query parameters are captured as `db.query.parameter.<key>`
    (dict params by name, positional params by 0-based index). They are not
    captured for batch (`many=True`) executions.

    If `row_count_provider` is given, `db.client.response.returned_rows` is
    recorded for SELECT operations using its return value (callable so the
    final count is read after streaming consumers finish iterating).
    Negative counts are ignored.
    """

    # Fast-exit if instrumentation suppression flag set in context.
    if otel_context.get_value(_SUPPRESS_KEY):
        yield None
        return

    sql = str(sql)  # Ensure SQL is a string for span attributes.

    operation, summary, collection_name = extract_operation_and_target(sql)
    if many:
        summary = f"{summary} many"

    # Span name follows semantic conventions: {db.query.summary}, else the operation.
    span_name = summary[:255] if summary else operation

    # Single settings_dict read — the property delegates to source.config.
    cfg = db.settings_dict

    attrs: dict[str, Any] = {
        DB_SYSTEM_NAME: DB_SYSTEM,
        DB_QUERY_TEXT: sql,  # Already parameterized from Django/Plain
        DB_QUERY_SUMMARY: summary,
        DB_OPERATION_NAME: operation,
    }
    # OTel rejects None attribute values (with a warning per query), so only
    # set the namespace when a database name is configured.
    if namespace := cfg.get("DATABASE"):
        attrs[DB_NAMESPACE] = namespace
    if collection_name:
        attrs[DB_COLLECTION_NAME] = collection_name
    attrs.update(_get_code_attributes())
    attrs.update(_endpoint_attributes(cfg))

    # Semconv: query parameters SHOULD NOT be captured on batch operations.
    if settings.DEBUG and params is not None and not many:
        attrs.update(_query_parameter_attributes(params))

    metric_attrs: dict[str, str] = {
        DB_SYSTEM_NAME: DB_SYSTEM,
        DB_OPERATION_NAME: operation,
    }
    if collection_name:
        metric_attrs[DB_COLLECTION_NAME] = collection_name

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=attrs
    ) as span:
        start = time.perf_counter()
        try:
            yield span
        except Exception as exc:
            error_type = format_exception_type(exc)
            # record_exception + set_status(ERROR) handled by
            # start_as_current_span when the exception propagates out.
            span.set_attribute(ERROR_TYPE, error_type)
            # Failed operations are recorded too, distinguished by error.type.
            query_duration_histogram.record(
                time.perf_counter() - start,
                {**metric_attrs, ERROR_TYPE: error_type},
            )
            raise
        query_duration_histogram.record(time.perf_counter() - start, metric_attrs)

        # Scope returned_rows to SELECT; rowcount for INSERT/UPDATE/DELETE
        # is rows-affected, which is a different semantic.
        if row_count_provider is not None and operation == "SELECT":
            count = row_count_provider()
            if count >= 0:
                returned_rows_histogram.record(count, metric_attrs)
380
381
@contextmanager
def suppress_db_tracing() -> Generator[None]:
    """Disable `db_span` instrumentation for queries run inside this block.

    Sets the suppression flag in a new OpenTelemetry context and makes that
    context current. The previous context is restored on exit, even if the
    body raises.
    """
    suppressed = otel_context.set_value(_SUPPRESS_KEY, True)
    token = otel_context.attach(suppressed)
    try:
        yield
    finally:
        otel_context.detach(token)
389
390
def _is_internal_frame(frame: traceback.FrameSummary) -> bool:
    """Return True if the frame is internal to plain.postgres or the stdlib contextlib.

    Frames without a filename are also treated as internal. Path separators
    are normalized, so Windows-style paths are recognized too.
    """
    filepath = frame.filename
    if not filepath:
        return True
    normalized = filepath.replace("\\", "/")
    if "/plain/postgres/" in normalized:
        return True
    # Compare the basename so user modules such as "my_contextlib.py" are
    # not mistaken for the stdlib contextlib module.
    return normalized.rpartition("/")[2] == "contextlib.py"
401
402
def _get_code_attributes() -> dict[str, Any]:
    """Return OpenTelemetry `code.*` attributes for the code that issued the query.

    Walks the call stack from the innermost frame outward and reports the
    first frame that `_is_internal_frame` does not classify as internal, i.e.
    the user code closest to the query. In DEBUG mode, the stack with internal
    frames removed is also attached as `code.stacktrace` (expensive).

    Returns an empty dict if every frame is internal.
    """
    # Innermost-first, starting at this function's own frame (as
    # traceback.extract_stack() would). Source lines are not read eagerly:
    # only the DEBUG stack trace needs them, and FrameSummary loads them lazily.
    frames = traceback.StackSummary.extract(
        traceback.walk_stack(sys._getframe()), lookup_lines=False
    )

    user_frame = next((f for f in frames if not _is_internal_frame(f)), None)
    if user_frame is None:
        return {}

    attrs: dict[str, Any] = {CODE_FILE_PATH: user_frame.filename}
    if user_frame.lineno:
        attrs[CODE_LINE_NUMBER] = user_frame.lineno
    if user_frame.name:
        attrs[CODE_FUNCTION_NAME] = user_frame.name
    if user_frame.colno:
        attrs[CODE_COLUMN_NUMBER] = user_frame.colno

    if settings.DEBUG:
        # format_list expects outermost-first order, like a traceback.
        user_frames = [f for f in reversed(frames) if not _is_internal_frame(f)]
        attrs[CODE_STACKTRACE] = "".join(traceback.format_list(user_frames))

    return attrs