1from __future__ import annotations
2
3from collections.abc import Generator
4from typing import TYPE_CHECKING, Any
5
6if TYPE_CHECKING:
7 from typing import Self
8
9
10from plain.logs import get_framework_logger
11from plain.postgres.ddl import compile_database_default_sql
12from plain.postgres.dialect import build_timeout_set_clauses, quote_name
13from plain.postgres.fields import Field
14from plain.postgres.fields.base import ColumnField
15from plain.postgres.fields.related import RelatedField
16from plain.postgres.fields.reverse_related import ManyToManyRel
17from plain.postgres.transaction import atomic
18from plain.runtime import settings as plain_settings
19
20if TYPE_CHECKING:
21 from plain.postgres.base import Model
22 from plain.postgres.connection import DatabaseConnection
23 from plain.postgres.fields import Field
24
# Framework logger; DatabaseSchemaEditor.execute() logs every schema statement
# it processes through it at DEBUG level.
logger = get_framework_logger()
26
27
class DatabaseSchemaEditor:
    """
    Responsible for emitting schema-changing statements to PostgreSQL: table
    creation/removal/renaming and column add/remove/rename/type changes.

    Constraints (including FKs), nullability and column DEFAULT reconciliation
    are not handled here; they are managed by convergence.

    Use as a context manager. When opened with ``atomic=True`` (and not in
    ``collect_sql`` mode) all statements run inside one ``atomic()``
    transaction. Every statement routed through :meth:`execute` is recorded in
    ``executed_sql`` for display in migration output.
    """

    sql_create_table = "CREATE TABLE %(table)s (%(definition)s)"
    sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
    sql_delete_table = "DROP TABLE %(table)s CASCADE"

    sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s"
    sql_alter_column = "ALTER TABLE %(table)s %(changes)s"
    sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
    sql_rename_column = (
        "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
    )

    def __init__(
        self,
        connection: DatabaseConnection,
        atomic: bool = True,
        collect_sql: bool = False,
    ):
        """
        Args:
            connection: Connection used to compose and execute SQL.
            atomic: Whether the editor's statements should run inside a
                transaction (and get the SET LOCAL timeout prelude).
            collect_sql: When True, statements are only recorded in
                ``executed_sql`` and never sent to the database.
        """
        self.connection = connection
        self.collect_sql = collect_sql
        self.atomic_migration = atomic and not collect_sql
        # `atomic_migration` goes False under collect_sql=True (we don't open
        # a real transaction for preview), but the collected SQL should still
        # reflect a real atomic run. Track the user's `atomic` intent
        # separately so the SET LOCAL prelude is emitted in the atomic=True
        # preview case, and skipped in the atomic=False case (where SET LOCAL
        # would be a no-op with WARNING outside a transaction block).
        self._atomic_intent = atomic
        # Created here (not only in __enter__) so execute() does not fail with
        # AttributeError when the editor is used outside a `with` block.
        self.executed_sql: list[str] = []

    # State-managing methods

    def __enter__(self) -> Self:
        """Reset ``executed_sql`` and open a transaction for atomic migrations."""
        self.executed_sql = []
        if self.atomic_migration:
            # `atomic` here is the module-level transaction helper; the
            # __init__ parameter of the same name is not in scope.
            self.atomic = atomic()
            self.atomic.__enter__()
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Close the transaction opened by ``__enter__``, forwarding any exception."""
        if self.atomic_migration:
            self.atomic.__exit__(exc_type, exc_value, traceback)

    # Core utility functions

    def execute(
        self,
        sql: str,
        params: tuple[Any, ...] | list[Any] | None = (),
        *,
        set_timeouts: bool = True,
    ) -> None:
        """Execute the given SQL statement, with optional parameters.

        When ``params`` is not None it is merged into ``sql`` client-side via
        ``connection.compose_sql``; pass None to send ``sql`` untouched (e.g.
        when it may contain a literal ``%``).

        When ``set_timeouts`` is True (default), ``SET LOCAL lock_timeout`` and
        ``SET LOCAL statement_timeout`` are prepended to the SQL so DDL fails
        fast if it can't acquire its lock or if a blocking statement runs
        longer than configured. Values come from ``POSTGRES_MIGRATION_*``
        settings. ``RunSQL(no_timeout=True)`` passes ``set_timeouts=False`` as
        an escape hatch for long-running data migrations.

        The final SQL is always appended to ``executed_sql``; it is only sent
        to the database when ``collect_sql`` is False.
        """
        sql_str = sql

        # Merge the query client-side, as PostgreSQL won't do it server-side.
        if params is not None:
            sql_str = self.connection.compose_sql(sql_str, params)
            params = None

        # SET LOCAL only works inside a transaction block. Skip the prelude
        # when the editor was opened with atomic=False (e.g. a migration that
        # needs to issue CONCURRENTLY via RunSQL) — otherwise Postgres would
        # silently WARN and ignore the timeouts. Users of non-atomic
        # migrations manage timeouts explicitly in their RunSQL if needed.
        if set_timeouts and self._atomic_intent:
            sql_str = (
                build_timeout_set_clauses(
                    lock_timeout=plain_settings.POSTGRES_MIGRATION_LOCK_TIMEOUT,
                    statement_timeout=plain_settings.POSTGRES_MIGRATION_STATEMENT_TIMEOUT,
                )
                + sql_str
            )

        # Log the command we're running, then run it
        logger.debug("Schema SQL executed", extra={"sql": sql_str, "params": params})

        # Track executed SQL for display in migration output
        self.executed_sql.append(sql_str)

        if self.collect_sql:
            return

        with self.connection.cursor() as cursor:
            cursor.execute(sql_str, params)

    def table_sql(self, model: type[Model]) -> tuple[str, list[Any]]:
        """Take a model and return its CREATE TABLE statement and params.

        Non-column fields (``column_sql`` returns None) are skipped. FK and
        other constraints are not emitted inline — convergence manages them.
        """
        column_sqls = []
        params = []
        for field in model._model_meta.local_fields:
            definition, extra_params = self.column_sql(
                model, field, include_default=field.has_persistent_column_default()
            )
            if definition is None:
                continue
            # Autoincrement SQL (e.g. GENERATED BY DEFAULT AS IDENTITY).
            col_type_suffix = field.db_type_suffix()
            if col_type_suffix:
                definition += f" {col_type_suffix}"
            if extra_params:
                params.extend(extra_params)
            column_sqls.append(f"{quote_name(field.column)} {definition}")
        sql = self.sql_create_table % {
            "table": quote_name(model.model_options.db_table),
            "definition": ", ".join(column_sqls),
        }
        return sql, params

    # Field <-> database mapping functions

    def _iter_column_sql(
        self,
        column_db_type: str,
        params: list[Any],
        model: type[Model],
        field: ColumnField,
        include_default: bool,
    ) -> Generator[str]:
        """Yield the space-separated fragments of a column definition.

        Literal defaults are emitted as a ``%s`` placeholder and their value is
        appended to ``params`` (mutated in place); expression defaults are
        inlined as parameter-free SQL.
        """
        yield column_db_type
        null = field.allow_null
        # Include a default value, if requested.
        if include_default:
            db_default_expr = field.get_db_default_expression()
            if db_default_expr is not None:
                # Expression defaults are inlined into the DDL — they render
                # as parameter-free SQL and become the column's persistent
                # DEFAULT.
                yield f"DEFAULT {self._compile_expression(db_default_expr)}"
            else:
                default_value = self.effective_default(field)
                if default_value is not None:
                    yield "DEFAULT %s"
                    params.append(default_value)

        if not null:
            yield "NOT NULL"
        else:
            yield "NULL"

        if field.primary_key:
            yield "PRIMARY KEY"

    def column_sql(
        self, model: type[Model], field: Field, include_default: bool = False
    ) -> tuple[str | None, list[Any] | None]:
        """
        Return ``(definition, params)`` for a field's column, or ``(None, None)``
        when the field has no column (``db_type()`` is None, e.g. M2M). The
        field must already have had set_attributes_from_name() called.
        """
        # Get the column's type and use that as the basis of the SQL.
        column_db_type = field.db_type()
        # Check for fields that aren't actually columns (e.g. M2M).
        if column_db_type is None:
            return None, None
        assert isinstance(field, ColumnField)
        params: list[Any] = []
        return (
            " ".join(
                # This appends to the params being returned.
                self._iter_column_sql(
                    column_db_type,
                    params,
                    model,
                    field,
                    include_default,
                )
            ),
            params,
        )

    def _compile_expression(self, expression: Any) -> str:
        """Compile a DB-default expression (Now, GenRandomUUID) for inlining into DDL."""
        return compile_database_default_sql(expression)

    def effective_default(self, field: Field) -> Any:
        """Return a field's declared literal DEFAULT value, prepared for the
        database. Returns None for fields without a user-declared default —
        expression defaults take the `get_db_default_expression()` path
        instead."""
        from plain.postgres.fields.base import DefaultableField

        if not isinstance(field, DefaultableField) or not field.has_default():
            return None
        return field.get_db_prep_save(field.get_default(), self.connection)

    # Actions

    def create_model(self, model: type[Model]) -> None:
        """Create a table for the given model."""
        sql, params = self.table_sql(model)
        # Prevent using [] as params, in the case a literal '%' is used in the
        # definition.
        self.execute(sql, params or None)

    def delete_model(self, model: type[Model]) -> None:
        """Drop the model's table (with CASCADE)."""
        self.execute(
            self.sql_delete_table
            % {
                "table": quote_name(model.model_options.db_table),
            }
        )

    def alter_db_table(
        self, model: type[Model], old_db_table: str, new_db_table: str
    ) -> None:
        """Rename the table a model points to; no-op if the name is unchanged."""
        if old_db_table == new_db_table:
            return
        self.execute(
            self.sql_rename_table
            % {
                "old_table": quote_name(old_db_table),
                "new_table": quote_name(new_db_table),
            }
        )

    def add_field(self, model: type[Model], field: Field) -> None:
        """
        Add the field's column to the model's table. Fields without a column
        (``column_sql`` returns None, e.g. M2M) are a no-op.
        """
        # Get the column's definition
        definition, params = self.column_sql(
            model,
            field,
            include_default=field.has_persistent_column_default(),
        )
        # It might not actually have a column behind it
        if definition is None:
            return
        if col_type_suffix := field.db_type_suffix():
            definition += f" {col_type_suffix}"
        # FK constraints are handled by convergence, not inline during add_field.
        sql = self.sql_create_column % {
            "table": quote_name(model.model_options.db_table),
            "column": quote_name(field.column),
            "definition": definition,
        }
        # Same guard as create_model: don't compose with an empty param list,
        # or a literal '%' in the definition would be read as a placeholder.
        self.execute(sql, params or None)

    def remove_field(self, model: type[Model], field: Field) -> None:
        """
        Drop the field's column (with CASCADE). Fields without a column
        (``db_type()`` is None, e.g. M2M) are a no-op.
        """
        # It might not actually have a column behind it
        if field.db_type() is None:
            return
        # FK constraints are dropped automatically by CASCADE on DROP COLUMN.
        sql = self.sql_delete_column % {
            "table": quote_name(model.model_options.db_table),
            "column": quote_name(field.column),
        }
        self.execute(sql)

    def alter_field(
        self,
        model: type[Model],
        old_field: Field,
        new_field: Field,
    ) -> None:
        """
        Allow a field's type, uniqueness, nullability, default, column,
        constraints, etc. to be modified.
        `old_field` is required to compute the necessary changes.

        Raises:
            ValueError: if either field lacks a db_type without being a
                related field, or when converting to/from an M2M field.
        """
        if not self._field_should_be_altered(old_field, new_field):
            return
        # Ensure this field is even column-based
        old_type = old_field.db_type()
        new_type = new_field.db_type()
        if (old_type is None and not isinstance(old_field, RelatedField)) or (
            new_type is None and not isinstance(new_field, RelatedField)
        ):
            raise ValueError(
                f"Cannot alter field {old_field} into {new_field} - they do not properly define "
                "db_type (are you using a badly-written custom field?)",
            )
        elif (
            old_type is None
            and new_type is None
            and isinstance(old_field, RelatedField)
            and isinstance(old_field.remote_field, ManyToManyRel)
            and isinstance(new_field, RelatedField)
            and isinstance(new_field.remote_field, ManyToManyRel)
        ):
            # Both sides have through models; this is a no-op.
            return
        elif old_type is None or new_type is None:
            raise ValueError(
                f"Cannot alter field {old_field} into {new_field} - they are not compatible types "
                "(you cannot alter to or from M2M fields, or add or remove "
                "through= on M2M fields)"
            )

        assert isinstance(old_field, ColumnField)
        assert isinstance(new_field, ColumnField)
        self._alter_field(
            model,
            old_field,
            new_field,
            old_type,
            new_type,
        )

    def _alter_field(
        self,
        model: type[Model],
        old_field: ColumnField,
        new_field: ColumnField,
        old_type: str,
        new_type: str,
    ) -> None:
        """Column rename + column type change.

        Also drops the existing expression DEFAULT when the column type is
        changing (Postgres rejects the cast otherwise). Nullability and
        column DEFAULT reconciliation are convergence-managed — see
        ``plain.postgres.convergence``.
        """
        table = model.model_options.db_table
        if old_field.column != new_field.column:
            self.execute(self._rename_field_sql(table, old_field, new_field))
        # Postgres rejects ALTER COLUMN TYPE when the existing expression DEFAULT
        # can't cast to the new type. Drop it first; convergence re-applies it.
        if old_field.db_returning and old_type != new_type:
            self.execute(
                self.sql_alter_column
                % {
                    "table": quote_name(table),
                    "changes": self.sql_alter_column_no_default
                    % {"column": quote_name(new_field.column)},
                }
            )
        if (
            old_type != new_type
            or old_field.db_type_suffix() != new_field.db_type_suffix()
        ):
            # NOTE(review): when only db_type_suffix() differs, the emitted
            # ALTER COLUMN ... TYPE re-states new_type but does not apply the
            # suffix itself — confirm the suffix change is handled elsewhere.
            type_sql, type_params = self._alter_column_type_sql(
                old_field, new_field, new_type
            )
            # Same '%' guard as create_model/add_field: never compose with [].
            self.execute(
                self.sql_alter_column
                % {
                    "table": quote_name(table),
                    "changes": type_sql,
                },
                type_params or None,
            )

    def _alter_column_type_sql(
        self,
        old_field: Field,
        new_field: Field,
        new_type: str,
    ) -> tuple[str, list[Any]]:
        """Return an ``(sql, params)`` ALTER COLUMN TYPE fragment."""
        sql = "ALTER COLUMN %(column)s TYPE %(type)s"
        # USING cast when the base data type changed (e.g. varchar → int),
        # not just a parameter like max_length.
        if old_field.unqualified_db_type() != new_field.unqualified_db_type():
            sql += " USING %(column)s::%(type)s"
        return (
            sql
            % {
                "column": quote_name(new_field.column),
                "type": new_type,
            },
            [],
        )

    def _field_should_be_altered(
        self, old_field: Field, new_field: Field, ignore: set[str] | None = None
    ) -> bool:
        """Return True if the column name or any schema-relevant deconstructed
        attribute differs between ``old_field`` and ``new_field``.

        Each field's ``non_migration_attrs`` and the caller-supplied ``ignore``
        names are removed from the deconstructed kwargs before comparing.
        """
        ignore = ignore or set()
        _, old_path, old_args, old_kwargs = old_field.deconstruct()
        _, new_path, new_args, new_kwargs = new_field.deconstruct()
        # Don't alter when:
        # - changing only a field name
        # - changing an attribute that doesn't affect the schema
        # - changing an attribute in the provided set of ignored attributes
        for attr in ignore.union(old_field.non_migration_attrs):
            old_kwargs.pop(attr, None)
        for attr in ignore.union(new_field.non_migration_attrs):
            new_kwargs.pop(attr, None)
        return quote_name(old_field.column) != quote_name(new_field.column) or (
            old_path,
            old_args,
            old_kwargs,
        ) != (new_path, new_args, new_kwargs)

    def _rename_field_sql(self, table: str, old_field: Field, new_field: Field) -> str:
        """Return the RENAME COLUMN statement from old_field's to new_field's column."""
        return self.sql_rename_column % {
            "table": quote_name(table),
            "old_column": quote_name(old_field.column),
            "new_column": quote_name(new_field.column),
        }