1from __future__ import annotations
2
3from collections.abc import Callable, Generator
4from contextlib import ContextDecorator, contextmanager
5from typing import Any
6
7import psycopg
8
9from plain.postgres.db import get_connection
10
11
class TransactionManagementError(psycopg.ProgrammingError):
    """
    Raised when transaction management is used improperly.

    Derives from psycopg.ProgrammingError, so handlers written for that
    exception type also catch this one.
    """
16
17
@contextmanager
def mark_for_rollback_on_error() -> Generator[None]:
    """
    Flag the surrounding transaction as "needs rollback" if the body raises.

    Internal low-level utility. Unlike an atomic block, it never opens a
    transaction or a savepoint on its own: the enclosed code simply runs.
    It exists for callers such as Model.save() that execute a single query
    and should not start a transaction when in autocommit mode.

    When the body raises an Exception:

    * inside an atomic block, the connection's `needs_rollback` flag is set
      and the exception is stored on `rollback_exc`, then it is re-raised;
    * outside an atomic block, the exception is re-raised untouched.

    Exceptions deriving only from BaseException (KeyboardInterrupt and the
    like) are not intercepted at all.
    """
    try:
        yield
    except Exception as error:
        connection = get_connection()
        if not connection.in_atomic_block:
            # No enclosing transaction to poison; just propagate.
            raise
        connection.needs_rollback = True
        connection.rollback_exc = error
        raise
44
45
def on_commit(func: Callable[[], Any], robust: bool = False) -> None:
    """
    Schedule `func` to run once the current transaction commits.

    If the current transaction is rolled back instead, `func` is never
    called. This is a thin module-level shortcut: both `func` and `robust`
    are handed, unchanged, to the `on_commit` method of the connection
    returned by `get_connection()`.
    """
    connection = get_connection()
    connection.on_commit(func, robust)
52
53
54#################################
55# Decorators / context managers #
56#################################
57
58
class Atomic(ContextDecorator):
    """
    Guarantee the atomic execution of a given block.

    An instance can be used either as a decorator or as a context manager.

    When it's used as a decorator, the __call__ inherited from
    ContextDecorator wraps the execution of the decorated function in the
    instance itself, used as a context manager.

    When it's used as a context manager, __enter__ creates a transaction or a
    savepoint, depending on whether a transaction is already in progress, and
    __exit__ commits the transaction or releases the savepoint on normal exit,
    and rolls back the transaction or to the savepoint on exceptions.

    It's possible to disable the creation of savepoints if the goal is to
    ensure that some code runs within a transaction without creating overhead.

    A stack of savepoint identifiers is maintained as an attribute of the
    connection. None denotes the absence of a savepoint.

    This allows reentrancy even if the same Atomic instance is reused. For
    example, it's possible to define `tx = atomic()` and use `@tx` or
    `with tx:` multiple times.

    Since database connections are stored per-context (ContextVar), this is
    thread-safe.

    An atomic block can be tagged as durable. In this case, raise a
    RuntimeError if it's nested within another atomic block (unless the
    enclosing block is flagged `_from_testcase`). This guarantees that
    database changes in a durable block are committed to the database when
    the block exits without error.

    This is a private API.
    """

    def __init__(self, savepoint: bool, durable: bool) -> None:
        """
        Store the block's configuration.

        savepoint: whether a nested block should create a savepoint.
        durable: whether the block refuses to be nested in another block.
        """
        self.savepoint = savepoint
        self.durable = durable
        # Blocks with this flag set may legitimately enclose a durable block;
        # see the nesting check in __enter__.
        self._from_testcase = False

    def __enter__(self) -> None:
        """
        Open a transaction (outermost block) or a savepoint (nested block).

        Raises RuntimeError when a durable block is entered while another
        atomic block, not flagged `_from_testcase`, is already active.
        """
        conn = get_connection()
        if (
            self.durable
            and conn.atomic_blocks
            and not conn.atomic_blocks[-1]._from_testcase
        ):
            raise RuntimeError(
                "A durable atomic block cannot be nested within another atomic block."
            )

        if conn.in_atomic_block:
            # We're already in a transaction; create a savepoint, unless we
            # were told not to or we're already waiting for a rollback. The
            # second condition avoids creating useless savepoints and prevents
            # overwriting needs_rollback until the rollback is performed.
            if self.savepoint and not conn.needs_rollback:
                sid = conn.savepoint()
                conn.savepoint_ids.append(sid)
            else:
                conn.savepoint_ids.append(None)
        else:
            # Outermost block: reset state, then leave autocommit mode so
            # that the following statements run inside one transaction.
            conn.needs_rollback = False
            conn.set_autocommit(False)
            conn.in_atomic_block = True

        # Both branches above leave the connection inside an atomic block.
        conn.atomic_blocks.append(self)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any,
    ) -> None:
        """
        Commit/release on success, roll back on failure.

        Always returns None, so an exception raised inside the block keeps
        propagating after the rollback bookkeeping is done. A database error
        raised while committing or releasing a savepoint is re-raised after a
        rollback attempt.
        """
        conn = get_connection()
        if conn.in_atomic_block:
            conn.atomic_blocks.pop()

        # The outermost block pushes nothing onto savepoint_ids in __enter__,
        # so an empty stack means this is the outermost block exiting.
        sid = None
        if conn.savepoint_ids:
            sid = conn.savepoint_ids.pop()
        else:
            # Prematurely unset this flag to allow using commit or rollback.
            conn.in_atomic_block = False

        try:
            if exc_type is None and not conn.needs_rollback:
                if conn.in_atomic_block:
                    # Release savepoint if there is one
                    if sid is not None:
                        try:
                            conn.savepoint_commit(sid)
                        except psycopg.DatabaseError:
                            self._rollback_to_savepoint(conn, sid)
                            raise
                else:
                    # Commit transaction
                    try:
                        conn.commit()
                    except psycopg.DatabaseError:
                        try:
                            conn.rollback()
                        except psycopg.Error:
                            # An error during rollback means that something
                            # went wrong with the connection. Drop it.
                            conn.close()
                        raise
            else:
                # This flag will be set to True again if there isn't a savepoint
                # allowing to perform the rollback at this level.
                conn.needs_rollback = False
                if conn.in_atomic_block:
                    # Roll back to savepoint if there is one, mark for rollback
                    # otherwise.
                    if sid is None:
                        conn.needs_rollback = True
                    else:
                        self._rollback_to_savepoint(conn, sid)
                else:
                    # Roll back transaction
                    try:
                        conn.rollback()
                    except psycopg.Error:
                        # An error during rollback means that something
                        # went wrong with the connection. Drop it.
                        conn.close()

        finally:
            # Outermost block exit when autocommit was enabled. Skip when
            # the connection was dropped during rollback/commit failure —
            # ensure_connection() would otherwise acquire a fresh pool
            # connection just to flip autocommit while the original error
            # is still propagating.
            if not conn.in_atomic_block and conn.connection is not None:
                conn.set_autocommit(True)

    @staticmethod
    def _rollback_to_savepoint(conn: Any, sid: Any) -> None:
        """Roll back to savepoint `sid` and release it, flagging failures."""
        try:
            conn.savepoint_rollback(sid)
            # The savepoint won't be reused. Release it to
            # minimize overhead for the database server.
            conn.savepoint_commit(sid)
        except psycopg.Error:
            # If rolling back to a savepoint fails, mark for
            # rollback at a higher level and avoid shadowing
            # the original exception.
            conn.needs_rollback = True
212
213
214def atomic[F: Callable[..., Any]](
215 func: F | None = None, *, savepoint: bool = True, durable: bool = False
216) -> F | Atomic:
217 """Create an atomic transaction context or decorator."""
218 if callable(func):
219 return Atomic(savepoint, durable)(func)
220 return Atomic(savepoint, durable)