v0.145.1
  1from __future__ import annotations
  2
  3import importlib.metadata
  4import re
  5import time
  6import traceback
  7import weakref
  8from collections.abc import Callable, Generator
  9from contextlib import contextmanager
 10from typing import TYPE_CHECKING, Any
 11
 12from opentelemetry import context as otel_context
 13from opentelemetry import metrics, trace
 14from opentelemetry.metrics import CallbackOptions, Observation
 15from opentelemetry.semconv.metrics.db_metrics import DB_CLIENT_OPERATION_DURATION
 16
 17if TYPE_CHECKING:
 18    from opentelemetry.trace import Span
 19    from psycopg import Connection as PsycopgConnection
 20
 21    from plain.postgres.connection import DatabaseConnection
 22    from plain.postgres.sources import PoolSource
 23
 24from opentelemetry.semconv._incubating.attributes.db_attributes import (
 25    DB_CLIENT_CONNECTION_POOL_NAME,
 26    DB_CLIENT_CONNECTION_STATE,
 27    DB_QUERY_PARAMETER_TEMPLATE,
 28    DbClientConnectionStateValues,
 29)
 30from opentelemetry.semconv._incubating.metrics.db_metrics import (
 31    DB_CLIENT_CONNECTION_COUNT,
 32    DB_CLIENT_CONNECTION_IDLE_MAX,
 33    DB_CLIENT_CONNECTION_IDLE_MIN,
 34    DB_CLIENT_CONNECTION_MAX,
 35    DB_CLIENT_CONNECTION_PENDING_REQUESTS,
 36    DB_CLIENT_CONNECTION_TIMEOUTS,
 37    DB_CLIENT_CONNECTION_USE_TIME,
 38    DB_CLIENT_CONNECTION_WAIT_TIME,
 39    DB_CLIENT_RESPONSE_RETURNED_ROWS,
 40)
 41from opentelemetry.semconv.attributes.code_attributes import (
 42    CODE_COLUMN_NUMBER,
 43    CODE_FILE_PATH,
 44    CODE_FUNCTION_NAME,
 45    CODE_LINE_NUMBER,
 46    CODE_STACKTRACE,
 47)
 48from opentelemetry.semconv.attributes.db_attributes import (
 49    DB_COLLECTION_NAME,
 50    DB_NAMESPACE,
 51    DB_OPERATION_NAME,
 52    DB_QUERY_SUMMARY,
 53    DB_QUERY_TEXT,
 54    DB_SYSTEM_NAME,
 55    DbSystemNameValues,
 56)
 57from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
 58from opentelemetry.semconv.attributes.network_attributes import (
 59    NETWORK_PEER_ADDRESS,
 60    NETWORK_PEER_PORT,
 61)
 62from opentelemetry.semconv.attributes.server_attributes import (
 63    SERVER_ADDRESS,
 64    SERVER_PORT,
 65)
 66from opentelemetry.trace import SpanKind
 67
 68from plain.runtime import settings
 69from plain.utils.otel import format_exception_type
 70
 71# Use a stable string key so OpenTelemetry context APIs receive the expected type.
 72_SUPPRESS_KEY = "plain.postgres.suppress_db_tracing"
 73
 74try:
 75    _package_version = importlib.metadata.version("plain.postgres")
 76except importlib.metadata.PackageNotFoundError:
 77    _package_version = "dev"
 78
 79tracer = trace.get_tracer("plain.postgres", _package_version)
 80
 81meter = metrics.get_meter("plain.postgres", version=_package_version)
 82query_duration_histogram = meter.create_histogram(
 83    name=DB_CLIENT_OPERATION_DURATION,
 84    unit="s",
 85    description="Duration of database client operations.",
 86)
 87returned_rows_histogram = meter.create_histogram(
 88    name=DB_CLIENT_RESPONSE_RETURNED_ROWS,
 89    unit="{row}",
 90    description="Number of rows returned by the operation.",
 91)
 92connection_wait_time_histogram = meter.create_histogram(
 93    name=DB_CLIENT_CONNECTION_WAIT_TIME,
 94    unit="s",
 95    description="The time it took to obtain an open connection from the pool.",
 96)
 97connection_use_time_histogram = meter.create_histogram(
 98    name=DB_CLIENT_CONNECTION_USE_TIME,
 99    unit="s",
100    description="The time between borrowing a connection and returning it to the pool.",
101)
102connection_timeouts_counter = meter.create_counter(
103    name=DB_CLIENT_CONNECTION_TIMEOUTS,
104    unit="{timeout}",
105    description="The number of connection timeouts that have occurred trying to obtain a connection from the pool.",
106)
107
108# WeakKeyDictionary prevents leaks if a conn is GC'd without explicit release().
109_use_start: weakref.WeakKeyDictionary[PsycopgConnection[Any], float] = (
110    weakref.WeakKeyDictionary()
111)
112
113DB_SYSTEM = DbSystemNameValues.POSTGRESQL.value
114
115
def record_connection_acquire(
    pool_name: str,
    conn: PsycopgConnection[Any],
    wait_seconds: float,
    checkout_time: float,
) -> None:
    """Record how long obtaining `conn` took and remember when it was checked out.

    The wait time goes to the connection wait-time histogram, tagged with the
    pool name. `checkout_time` is stored per connection so that
    `record_connection_release` can later compute the use time.
    """
    pool_attributes = {DB_CLIENT_CONNECTION_POOL_NAME: pool_name}
    connection_wait_time_histogram.record(wait_seconds, pool_attributes)
    _use_start[conn] = checkout_time
126
127
def record_connection_release(
    pool_name: str, conn: PsycopgConnection[Any], return_time: float
) -> None:
    """Record how long `conn` was borrowed, if its checkout was recorded.

    Consumes the checkout timestamp stored by `record_connection_acquire`.
    Connections with no stored timestamp are ignored.
    """
    checked_out_at = _use_start.pop(conn, None)
    if checked_out_at is not None:
        held_for = return_time - checked_out_at
        connection_use_time_histogram.record(
            held_for, {DB_CLIENT_CONNECTION_POOL_NAME: pool_name}
        )
137
138
def record_connection_timeout(pool_name: str) -> None:
    """Increment the connection-timeout counter for `pool_name` by one."""
    pool_attributes = {DB_CLIENT_CONNECTION_POOL_NAME: pool_name}
    connection_timeouts_counter.add(1, pool_attributes)
141
142
# Pools reported by the connection-pool instruments, keyed by pool name.
#
# The OTel SDK keeps one instrument per name. Creating an instrument whose
# name is already registered returns the existing instrument and drops the
# callbacks passed to that later call, after logging a duplicate-instrument
# warning. If each callback closed over a single pool, only the first pool
# ever registered would be reported. Instead, the instruments are created
# once and their callbacks read every pool in this registry.
_observed_pools: dict[str, PoolSource] = {}
_pool_observables_created = False


def _collect_pool_stats() -> list[tuple[str, Any]]:
    """Return `(pool_name, stats)` for each registered pool whose stats are available."""
    collected: list[tuple[str, Any]] = []
    # Iterate a snapshot: collection callbacks may run on another thread.
    for pool_name, source in list(_observed_pools.items()):
        stats = source.get_stats()
        if stats is not None:
            collected.append((pool_name, stats))
    return collected


def register_pool_observables(pool_source: PoolSource) -> None:
    """Report `pool_source.get_stats()` through the connection-pool instruments.

    Safe to call multiple times and for multiple pools. Each call registers
    `pool_source` under `pool_source.name`, replacing an earlier source with
    the same name. The observable instruments are created only on the first
    call.

    At collection time, every registered pool whose `get_stats()` is not None
    contributes observations tagged with its pool name. Missing stat keys are
    reported as 0.
    """
    global _pool_observables_created

    _observed_pools[pool_source.name] = pool_source
    if _pool_observables_created:
        return

    def _count(_options: CallbackOptions) -> list[Observation]:
        observations: list[Observation] = []
        for pool_name, stats in _collect_pool_stats():
            size = stats.get("pool_size", 0)
            available = stats.get("pool_available", 0)
            # Clamp so a transiently inconsistent snapshot never reports < 0.
            used = max(size - available, 0)
            observations.append(
                Observation(
                    used,
                    {
                        DB_CLIENT_CONNECTION_POOL_NAME: pool_name,
                        DB_CLIENT_CONNECTION_STATE: DbClientConnectionStateValues.USED.value,
                    },
                )
            )
            observations.append(
                Observation(
                    available,
                    {
                        DB_CLIENT_CONNECTION_POOL_NAME: pool_name,
                        DB_CLIENT_CONNECTION_STATE: DbClientConnectionStateValues.IDLE.value,
                    },
                )
            )
        return observations

    def _single(stats_key: str) -> Callable[[CallbackOptions], list[Observation]]:
        """Build a callback reporting one stats key for every registered pool."""

        def callback(_options: CallbackOptions) -> list[Observation]:
            return [
                Observation(
                    stats.get(stats_key, 0),
                    {DB_CLIENT_CONNECTION_POOL_NAME: pool_name},
                )
                for pool_name, stats in _collect_pool_stats()
            ]

        return callback

    meter.create_observable_up_down_counter(
        name=DB_CLIENT_CONNECTION_COUNT,
        unit="{connection}",
        description="The number of connections that are currently in state described by the state attribute.",
        callbacks=[_count],
    )
    for name, unit, description, stats_key in (
        (
            DB_CLIENT_CONNECTION_MAX,
            "{connection}",
            "The maximum number of open connections allowed.",
            "pool_max",
        ),
        (
            DB_CLIENT_CONNECTION_IDLE_MIN,
            "{connection}",
            "The minimum number of idle open connections allowed.",
            "pool_min",
        ),
        (
            # No separate idle-max stat is read. Idle connections can never
            # exceed the pool's maximum size, so pool_max is reported.
            DB_CLIENT_CONNECTION_IDLE_MAX,
            "{connection}",
            "The maximum number of idle open connections allowed.",
            "pool_max",
        ),
        (
            DB_CLIENT_CONNECTION_PENDING_REQUESTS,
            "{request}",
            "The number of current pending requests for an open connection.",
            "requests_waiting",
        ),
    ):
        meter.create_observable_up_down_counter(
            name=name,
            unit=unit,
            description=description,
            callbacks=[_single(stats_key)],
        )

    # Set only after every instrument exists, so a failure part-way through
    # is retried on the next call.
    _pool_observables_created = True
217
218
219def extract_operation_and_target(sql: str) -> tuple[str, str | None, str | None]:
220    """Extract operation, table name, and collection from SQL.
221
222    Returns: (operation, summary, collection_name)
223    """
224    sql_upper = sql.upper().strip()
225
226    # Strip leading parentheses (e.g. UNION queries: "(SELECT ... UNION ...)")
227    operation = sql_upper.lstrip("(").split()[0] if sql_upper else "UNKNOWN"
228
229    # Pattern to match quoted and unquoted identifiers
230    # Matches: "quoted" (PostgreSQL), unquoted.name
231    identifier_pattern = r'("([^"]+)"|([\w.]+))'
232
233    # Map operations to the SQL keyword that precedes the table name.
234    keyword_by_operation = {
235        "SELECT": "FROM",
236        "DELETE": "FROM",
237        "INSERT": "INTO",
238        "UPDATE": "UPDATE",
239    }
240
241    # Extract table/collection name based on operation
242    collection_name = None
243    summary = operation
244
245    keyword = keyword_by_operation.get(operation)
246    if keyword:
247        match = re.search(rf"{keyword}\s+{identifier_pattern}", sql, re.IGNORECASE)
248        if match:
249            collection_name = _clean_identifier(match.group(1))
250            summary = f"{operation} {collection_name}"
251
252    # Detect UNION queries
253    if " UNION " in sql_upper and summary:
254        summary = f"{summary} UNION"
255
256    return operation, summary, collection_name
257
258
def _clean_identifier(identifier: str) -> str:
    """Strip one pair of surrounding double quotes from a SQL identifier, if present.

    Identifiers that do not both start and end with a double quote are
    returned unchanged.
    """
    is_quoted = identifier[:1] == '"' and identifier[-1:] == '"'
    return identifier[1:-1] if is_quoted else identifier
264
265
def _query_parameter_attributes(params: Any) -> dict[str, str]:
    """Map query parameters to `db.query.parameter.<key>` span attributes.

    Dict params (named placeholders) are keyed by name. List/tuple params
    (positional placeholders) are keyed by 1-based position. Any other value
    is treated as a single positional parameter. Values are stringified.
    """
    if isinstance(params, dict):
        items = params.items()
    elif isinstance(params, list | tuple):
        items = enumerate(params, start=1)
    else:
        items = ((1, params),)
    return {
        f"{DB_QUERY_PARAMETER_TEMPLATE}.{key}": str(value) for key, value in items
    }


@contextmanager
def db_span(
    db: DatabaseConnection,
    sql: Any,
    *,
    many: bool = False,
    params: Any = None,
    row_count_provider: Callable[[], int] | None = None,
) -> Generator[Span | None]:
    """Open an OpenTelemetry CLIENT span for a database query.

    All common attributes (`db.*`, `network.*`, `server.*`, `code.*`) are set
    automatically. They follow the OpenTelemetry semantic conventions for
    database instrumentation. Query parameters are attached only when
    `settings.DEBUG` is on.

    `db.client.operation.duration` is recorded for every traced operation,
    including failed ones. An `Exception` raised inside the `with` block tags
    both the span and the duration metric with `error.type`, and is then
    re-raised.

    If `row_count_provider` is given, `db.client.response.returned_rows` is
    recorded for successful SELECT operations using its return value.
    It is a callable so the final count is read after streaming consumers
    finish iterating. Negative counts are ignored.

    Yields the span, or `None` (recording nothing) while
    `suppress_db_tracing()` is active.
    """
    # Fast-exit if instrumentation suppression flag set in context.
    if otel_context.get_value(_SUPPRESS_KEY):
        yield None
        return

    sql = str(sql)  # Ensure SQL is a string for span attributes.

    operation, summary, collection_name = extract_operation_and_target(sql)
    if many:
        summary = f"{summary} many"

    # Span name follows semantic conventions: the (bounded) query summary,
    # falling back to the operation name.
    span_name = summary[:255] if summary else operation

    # Single settings_dict read — the property delegates to source.config.
    cfg = db.settings_dict

    attrs: dict[str, Any] = {
        DB_SYSTEM_NAME: DB_SYSTEM,
        DB_QUERY_TEXT: sql,  # Placeholder SQL; values travel separately in `params`.
        DB_QUERY_SUMMARY: summary,
        DB_OPERATION_NAME: operation,
    }
    # None is not a valid attribute value (the SDK drops it with a warning),
    # so db.namespace is only set when a database name is configured.
    if database := cfg.get("DATABASE"):
        attrs[DB_NAMESPACE] = database

    attrs.update(_get_code_attributes())

    if collection_name:
        attrs[DB_COLLECTION_NAME] = collection_name

    # Server/network endpoint. `server.*` is the primary pair per current
    # semconv; `network.peer.*` is recommended supplementary.
    if host := cfg.get("HOST"):
        attrs[SERVER_ADDRESS] = host
        attrs[NETWORK_PEER_ADDRESS] = host

    if port := cfg.get("PORT"):
        try:
            port_int = int(port)
        except (TypeError, ValueError):
            pass  # Unparseable port: omit port attributes rather than fail the query.
        else:
            attrs[SERVER_PORT] = port_int
            attrs[NETWORK_PEER_PORT] = port_int

    if settings.DEBUG and params is not None:
        attrs.update(_query_parameter_attributes(params))

    metric_attrs: dict[str, str] = {
        DB_SYSTEM_NAME: DB_SYSTEM,
        DB_OPERATION_NAME: operation,
    }
    if collection_name:
        metric_attrs[DB_COLLECTION_NAME] = collection_name

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=attrs
    ) as span:
        start = time.perf_counter()
        try:
            yield span
        except Exception as exc:
            # record_exception + set_status(ERROR) handled by
            # start_as_current_span when the exception propagates out.
            error_type = format_exception_type(exc)
            span.set_attribute(ERROR_TYPE, error_type)
            # Failed operations still count toward the duration metric,
            # tagged with error.type as the semantic conventions require.
            query_duration_histogram.record(
                time.perf_counter() - start,
                {**metric_attrs, ERROR_TYPE: error_type},
            )
            raise
        query_duration_histogram.record(time.perf_counter() - start, metric_attrs)

        # Scope returned_rows to SELECT; rowcount for INSERT/UPDATE/DELETE
        # is rows-affected, which is a different semantic.
        if row_count_provider is not None and operation == "SELECT":
            count = row_count_provider()
            if count >= 0:
                returned_rows_histogram.record(count, metric_attrs)
380
381
@contextmanager
def suppress_db_tracing() -> Generator[None]:
    """Disable `db_span` instrumentation inside the `with` block.

    Attaches an OpenTelemetry context carrying the suppression flag. The
    previous context is restored on exit, even if the block raises.
    """
    suppressed = otel_context.set_value(_SUPPRESS_KEY, True)
    restore_token = otel_context.attach(suppressed)
    try:
        yield
    finally:
        otel_context.detach(restore_token)
389
390
def _is_internal_frame(frame: traceback.FrameSummary) -> bool:
    """Return True if the frame is internal to plain.postgres or contextlib.

    Frames without a filename count as internal. Path separators are
    normalised to forward slashes first, so Windows paths that use
    backslashes are recognised as well.
    """
    filepath = frame.filename
    if not filepath:
        return True
    filepath = filepath.replace("\\", "/")
    return "/plain/postgres/" in filepath or filepath.endswith("contextlib.py")
401
402
def _get_code_attributes() -> dict[str, Any]:
    """Describe the user code that issued the current database query.

    Scans the call stack from the innermost frame outwards. It picks the first
    frame that `_is_internal_frame` does not reject, i.e. the user code closest
    to the query, and returns that frame's `code.*` attributes. When
    `settings.DEBUG` is on, a stack trace of all non-internal frames is added.
    Returns an empty dict when every frame is internal.
    """
    stack = traceback.extract_stack()

    # extract_stack() lists frames outermost-first, so scan it from the end.
    caller = next(
        (frame for frame in reversed(stack) if not _is_internal_frame(frame)),
        None,
    )
    if caller is None:
        return {}

    attrs: dict[str, Any] = {CODE_FILE_PATH: caller.filename}
    for key, value in (
        (CODE_LINE_NUMBER, caller.lineno),
        (CODE_FUNCTION_NAME, caller.name),
        (CODE_COLUMN_NUMBER, caller.colno),
    ):
        # Only truthy values are attached; a missing/zero line or column is omitted.
        if value:
            attrs[key] = value

    # Formatting the full stack is expensive, so it is only done in DEBUG.
    if settings.DEBUG:
        user_frames = [frame for frame in stack if not _is_internal_frame(frame)]
        attrs[CODE_STACKTRACE] = "".join(traceback.format_list(user_frames))

    return attrs