v0.148.0
  1from __future__ import annotations
  2
  3from collections.abc import Callable, Generator
  4from contextlib import ContextDecorator, contextmanager
  5from typing import Any
  6
  7import psycopg
  8
  9from plain.postgres.db import get_connection
 10
 11
 12class TransactionManagementError(psycopg.ProgrammingError):
 13    """Transaction management is used improperly."""
 14
 15    pass
 16
 17
 18@contextmanager
 19def mark_for_rollback_on_error() -> Generator[None]:
 20    """
 21    Internal low-level utility to mark a transaction as "needs rollback" when
 22    an exception is raised while not enforcing the enclosed block to be in a
 23    transaction. This is needed by Model.save() and friends to avoid starting a
 24    transaction when in autocommit mode and a single query is executed.
 25
 26    It's equivalent to:
 27
 28        if get_connection().get_autocommit():
 29            yield
 30        else:
 31            with transaction.atomic(savepoint=False):
 32                yield
 33
 34    but it uses low-level utilities to avoid performance overhead.
 35    """
 36    try:
 37        yield
 38    except Exception as exc:
 39        conn = get_connection()
 40        if conn.in_atomic_block:
 41            conn.needs_rollback = True
 42            conn.rollback_exc = exc
 43        raise
 44
 45
 46def on_commit(func: Callable[[], Any], robust: bool = False) -> None:
 47    """
 48    Register `func` to be called when the current transaction is committed.
 49    If the current transaction is rolled back, `func` will not be called.
 50    """
 51    get_connection().on_commit(func, robust)
 52
 53
 54#################################
 55# Decorators / context managers #
 56#################################
 57
 58
class Atomic(ContextDecorator):
    """
    Guarantee the atomic execution of a given block.

    An instance can be used either as a decorator or as a context manager.

    When it's used as a decorator, __call__ (inherited from ContextDecorator)
    wraps the execution of the decorated function in the instance itself, used
    as a context manager.

    When it's used as a context manager, __enter__ creates a transaction or a
    savepoint, depending on whether a transaction is already in progress, and
    __exit__ commits the transaction or releases the savepoint on normal exit,
    and rolls back the transaction or to the savepoint on exceptions.

    It's possible to disable the creation of savepoints if the goal is to
    ensure that some code runs within a transaction without creating overhead.

    A stack of savepoint identifiers is maintained as an attribute of the
    connection (``savepoint_ids``). None denotes the absence of a savepoint.

    This allows reentrancy even if the same Atomic instance is reused. For
    example, it's possible to define `oa = atomic(savepoint=False)` and use
    `@oa` or `with oa:` multiple times. No per-entry state is stored on the
    instance itself; it all lives on the connection.

    Since database connections are stored per-context (ContextVar), this is thread-safe.

    An atomic block can be tagged as durable. In this case, raise a
    RuntimeError if it's nested within another atomic block. This guarantees
    that database changes in a durable block are committed to the database when
    the block exits without error.

    This is a private API.
    """

    def __init__(self, savepoint: bool, durable: bool) -> None:
        # Whether a nested entry should create a real savepoint.
        self.savepoint = savepoint
        # Whether nesting inside another atomic block is forbidden.
        self.durable = durable
        # When True on the innermost enclosing block, a durable block may nest
        # inside it (see the durable check in __enter__). Presumably set by
        # test-case harnesses that wrap each test in a transaction; confirm.
        self._from_testcase = False

    def __enter__(self) -> None:
        """Start a transaction, or a savepoint if one is already open.

        Raises:
            RuntimeError: if this block is durable and is nested inside another
                atomic block that is not flagged ``_from_testcase``.
        """
        conn = get_connection()
        if (
            self.durable
            and conn.atomic_blocks
            and not conn.atomic_blocks[-1]._from_testcase
        ):
            raise RuntimeError(
                "A durable atomic block cannot be nested within another atomic block."
            )
        if not conn.in_atomic_block:
            # Reset state when entering an outermost atomic block.
            conn.needs_rollback = False
        if conn.in_atomic_block:
            # We're already in a transaction; create a savepoint, unless we
            # were told not to or we're already waiting for a rollback. The
            # second condition avoids creating useless savepoints and prevents
            # overwriting needs_rollback until the rollback is performed.
            if self.savepoint and not conn.needs_rollback:
                sid = conn.savepoint()
                conn.savepoint_ids.append(sid)
            else:
                # Push a placeholder so __exit__ pops exactly one entry per
                # nested level, savepoint or not.
                conn.savepoint_ids.append(None)
        else:
            # Outermost block: turning autocommit off opens the transaction.
            conn.set_autocommit(False)
            conn.in_atomic_block = True

        # NOTE(review): presumably always true at this point (both branches
        # above leave in_atomic_block set), unless conn.savepoint() changes it.
        if conn.in_atomic_block:
            conn.atomic_blocks.append(self)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any,
    ) -> None:
        """Commit/release on success; roll back (or mark for rollback) otherwise.

        Returns None, so any exception raised in the block propagates. A
        database error raised while committing or releasing a savepoint is
        re-raised after cleanup.
        """
        conn = get_connection()
        if conn.in_atomic_block:
            conn.atomic_blocks.pop()

        # `sid` is bound only when savepoint_ids is non-empty. The branches
        # below that read it are guarded by `conn.in_atomic_block`, which the
        # else-branch clears when savepoint_ids is empty. So `sid` is always
        # bound when read.
        if conn.savepoint_ids:
            sid = conn.savepoint_ids.pop()
        else:
            # Prematurely unset this flag to allow using commit or rollback.
            conn.in_atomic_block = False

        try:
            if exc_type is None and not conn.needs_rollback:
                if conn.in_atomic_block:
                    # Release savepoint if there is one
                    if sid is not None:
                        try:
                            conn.savepoint_commit(sid)
                        except psycopg.DatabaseError:
                            try:
                                conn.savepoint_rollback(sid)
                                # The savepoint won't be reused. Release it to
                                # minimize overhead for the database server.
                                conn.savepoint_commit(sid)
                            except psycopg.Error:
                                # If rolling back to a savepoint fails, mark for
                                # rollback at a higher level and avoid shadowing
                                # the original exception.
                                conn.needs_rollback = True
                            # Re-raise the original DatabaseError from the
                            # savepoint release.
                            raise
                else:
                    # Commit transaction
                    try:
                        conn.commit()
                    except psycopg.DatabaseError:
                        try:
                            conn.rollback()
                        except psycopg.Error:
                            # An error during rollback means that something
                            # went wrong with the connection. Drop it.
                            conn.close()
                        raise
            else:
                # This flag will be set to True again if there isn't a savepoint
                # allowing to perform the rollback at this level.
                conn.needs_rollback = False
                if conn.in_atomic_block:
                    # Roll back to savepoint if there is one, mark for rollback
                    # otherwise.
                    if sid is None:
                        conn.needs_rollback = True
                    else:
                        try:
                            conn.savepoint_rollback(sid)
                            # The savepoint won't be reused. Release it to
                            # minimize overhead for the database server.
                            conn.savepoint_commit(sid)
                        except psycopg.Error:
                            # If rolling back to a savepoint fails, mark for
                            # rollback at a higher level and avoid shadowing
                            # the original exception.
                            conn.needs_rollback = True
                else:
                    # Roll back transaction
                    try:
                        conn.rollback()
                    except psycopg.Error:
                        # An error during rollback means that something
                        # went wrong with the connection. Drop it.
                        conn.close()

        finally:
            # Outermost block exit when autocommit was enabled. Skip when
            # the connection was dropped during rollback/commit failure —
            # ensure_connection() would otherwise acquire a fresh pool
            # connection just to flip autocommit while the original error
            # is still propagating.
            if not conn.in_atomic_block and conn.connection is not None:
                conn.set_autocommit(True)
212
213
214def atomic[F: Callable[..., Any]](
215    func: F | None = None, *, savepoint: bool = True, durable: bool = False
216) -> F | Atomic:
217    """Create an atomic transaction context or decorator."""
218    if callable(func):
219        return Atomic(savepoint, durable)(func)
220    return Atomic(savepoint, durable)