v0.148.0
  1from __future__ import annotations
  2
  3from collections.abc import Generator
  4from typing import TYPE_CHECKING, Any
  5
  6if TYPE_CHECKING:
  7    from typing import Self
  8
  9
 10from plain.logs import get_framework_logger
 11from plain.postgres.ddl import compile_database_default_sql
 12from plain.postgres.dialect import build_timeout_set_clauses, quote_name
 13from plain.postgres.fields import Field
 14from plain.postgres.fields.base import ColumnField
 15from plain.postgres.fields.related import RelatedField
 16from plain.postgres.fields.reverse_related import ManyToManyRel
 17from plain.postgres.transaction import atomic
 18from plain.runtime import settings as plain_settings
 19
 20if TYPE_CHECKING:
 21    from plain.postgres.base import Model
 22    from plain.postgres.connection import DatabaseConnection
 23    from plain.postgres.fields import Field
 24
 25logger = get_framework_logger()
 26
 27
class DatabaseSchemaEditor:
    """
    Responsible for emitting schema-changing statements to PostgreSQL: creating,
    dropping and renaming tables, and adding, removing, renaming and retyping
    columns. Constraints, nullability and column DEFAULT reconciliation are not
    handled here; they are left to convergence.

    Intended to be used as a context manager: ``__enter__`` resets
    ``executed_sql`` and, for atomic migrations, enters an ``atomic()`` block.
    """

    # %-style SQL templates. The methods that fill them in pass identifiers
    # through quote_name(); the templates themselves do no quoting.
    sql_create_table = "CREATE TABLE %(table)s (%(definition)s)"
    sql_rename_table = "ALTER TABLE %(old_table)s RENAME TO %(new_table)s"
    # CASCADE: also drop objects that depend on the table.
    sql_delete_table = "DROP TABLE %(table)s CASCADE"

    sql_create_column = "ALTER TABLE %(table)s ADD COLUMN %(column)s %(definition)s"
    # `changes` is one ALTER COLUMN fragment (e.g. sql_alter_column_no_default
    # or the TYPE fragment built by _alter_column_type_sql).
    sql_alter_column = "ALTER TABLE %(table)s %(changes)s"
    sql_alter_column_no_default = "ALTER COLUMN %(column)s DROP DEFAULT"
    # CASCADE: dependent objects (e.g. FK constraints) go away with the column.
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
    sql_rename_column = (
        "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
    )
 45
 46    def __init__(
 47        self,
 48        connection: DatabaseConnection,
 49        atomic: bool = True,
 50        collect_sql: bool = False,
 51    ):
 52        self.connection = connection
 53        self.collect_sql = collect_sql
 54        self.atomic_migration = atomic and not collect_sql
 55        # `atomic_migration` goes False under collect_sql=True (we don't open
 56        # a real transaction for preview), but the collected SQL should still
 57        # reflect a real atomic run. Track the user's `atomic` intent
 58        # separately so the SET LOCAL prelude is emitted in the atomic=True
 59        # preview case, and skipped in the atomic=False case (where SET LOCAL
 60        # would be a no-op with WARNING outside a transaction block).
 61        self._atomic_intent = atomic
 62
 63    # State-managing methods
 64
 65    def __enter__(self) -> Self:
 66        self.executed_sql: list[str] = []
 67        if self.atomic_migration:
 68            self.atomic = atomic()
 69            self.atomic.__enter__()
 70        return self
 71
 72    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 73        if self.atomic_migration:
 74            self.atomic.__exit__(exc_type, exc_value, traceback)
 75
 76    # Core utility functions
 77
 78    def execute(
 79        self,
 80        sql: str,
 81        params: tuple[Any, ...] | list[Any] | None = (),
 82        *,
 83        set_timeouts: bool = True,
 84    ) -> None:
 85        """Execute the given SQL statement, with optional parameters.
 86
 87        When ``set_timeouts`` is True (default), ``SET LOCAL lock_timeout`` and
 88        ``SET LOCAL statement_timeout`` are prepended to the SQL so DDL fails
 89        fast if it can't acquire its lock or if a blocking statement runs
 90        longer than configured. Values come from ``POSTGRES_MIGRATION_*``
 91        settings. ``RunSQL(no_timeout=True)`` passes ``set_timeouts=False`` as
 92        an escape hatch for long-running data migrations.
 93        """
 94        sql_str = sql
 95
 96        # Merge the query client-side, as PostgreSQL won't do it server-side.
 97        if params is not None:
 98            sql_str = self.connection.compose_sql(sql_str, params)
 99            params = None
100
101        # SET LOCAL only works inside a transaction block. Skip the prelude
102        # when the editor was opened with atomic=False (e.g. a migration that
103        # needs to issue CONCURRENTLY via RunSQL) — otherwise Postgres would
104        # silently WARN and ignore the timeouts. Users of non-atomic
105        # migrations manage timeouts explicitly in their RunSQL if needed.
106        if set_timeouts and self._atomic_intent:
107            sql_str = (
108                build_timeout_set_clauses(
109                    lock_timeout=plain_settings.POSTGRES_MIGRATION_LOCK_TIMEOUT,
110                    statement_timeout=plain_settings.POSTGRES_MIGRATION_STATEMENT_TIMEOUT,
111                )
112                + sql_str
113            )
114
115        # Log the command we're running, then run it
116        logger.debug("Schema SQL executed", extra={"sql": sql_str, "params": params})
117
118        # Track executed SQL for display in migration output
119        self.executed_sql.append(sql_str)
120
121        if self.collect_sql:
122            return
123
124        with self.connection.cursor() as cursor:
125            cursor.execute(sql_str, params)
126
127    def table_sql(self, model: type[Model]) -> tuple[str, list[Any]]:
128        """Take a model and return its table definition."""
129        column_sqls = []
130        params = []
131        for field in model._model_meta.local_fields:
132            definition, extra_params = self.column_sql(
133                model, field, include_default=field.has_persistent_column_default()
134            )
135            if definition is None:
136                continue
137            # Autoincrement SQL (e.g. GENERATED BY DEFAULT AS IDENTITY).
138            col_type_suffix = field.db_type_suffix()
139            if col_type_suffix:
140                definition += f" {col_type_suffix}"
141            if extra_params:
142                params.extend(extra_params)
143            # FK constraints are handled by convergence, not during table creation.
144            # Add the SQL to our big list.
145            column_sqls.append(f"{quote_name(field.column)} {definition}")
146        # Constraints are not created inline — they're managed by convergence.
147        sql = self.sql_create_table % {
148            "table": quote_name(model.model_options.db_table),
149            "definition": ", ".join(col for col in column_sqls if col),
150        }
151        return sql, params
152
153    # Field <-> database mapping functions
154
155    def _iter_column_sql(
156        self,
157        column_db_type: str,
158        params: list[Any],
159        model: type[Model],
160        field: ColumnField,
161        include_default: bool,
162    ) -> Generator[str]:
163        yield column_db_type
164        null = field.allow_null
165        # Include a default value, if requested.
166        if include_default:
167            db_default_expr = field.get_db_default_expression()
168            if db_default_expr is not None:
169                # Expression defaults are inlined into the DDL — they render
170                # as parameter-free SQL and become the column's persistent
171                # DEFAULT.
172                yield f"DEFAULT {self._compile_expression(db_default_expr)}"
173            else:
174                default_value = self.effective_default(field)
175                if default_value is not None:
176                    yield "DEFAULT %s"
177                    params.append(default_value)
178
179        if not null:
180            yield "NOT NULL"
181        else:
182            yield "NULL"
183
184        if field.primary_key:
185            yield "PRIMARY KEY"
186
187    def column_sql(
188        self, model: type[Model], field: Field, include_default: bool = False
189    ) -> tuple[str | None, list[Any] | None]:
190        """
191        Return the column definition for a field. The field must already have
192        had set_attributes_from_name() called.
193        """
194        # Get the column's type and use that as the basis of the SQL.
195        column_db_type = field.db_type()
196        # Check for fields that aren't actually columns (e.g. M2M).
197        if column_db_type is None:
198            return None, None
199        assert isinstance(field, ColumnField)
200        params: list[Any] = []
201        return (
202            " ".join(
203                # This appends to the params being returned.
204                self._iter_column_sql(
205                    column_db_type,
206                    params,
207                    model,
208                    field,
209                    include_default,
210                )
211            ),
212            params,
213        )
214
215    def _compile_expression(self, expression: Any) -> str:
216        """Compile a DB-default expression (Now, GenRandomUUID) for inlining into DDL."""
217        return compile_database_default_sql(expression)
218
219    def effective_default(self, field: Field) -> Any:
220        """Return a field's declared literal DEFAULT value, prepared for the
221        database. Returns None for fields without a user-declared default —
222        expression defaults take the `get_db_default_expression()` path
223        instead."""
224        from plain.postgres.fields.base import DefaultableField
225
226        if not isinstance(field, DefaultableField) or not field.has_default():
227            return None
228        return field.get_db_prep_save(field.get_default(), self.connection)
229
230    # Actions
231
232    def create_model(self, model: type[Model]) -> None:
233        """Create a table for the given model."""
234        sql, params = self.table_sql(model)
235        # Prevent using [] as params, in the case a literal '%' is used in the
236        # definition.
237        self.execute(sql, params or None)
238
239    def delete_model(self, model: type[Model]) -> None:
240        """Delete a model from the database."""
241        self.execute(
242            self.sql_delete_table
243            % {
244                "table": quote_name(model.model_options.db_table),
245            }
246        )
247
248    def alter_db_table(
249        self, model: type[Model], old_db_table: str, new_db_table: str
250    ) -> None:
251        """Rename the table a model points to."""
252        if old_db_table == new_db_table:
253            return
254        self.execute(
255            self.sql_rename_table
256            % {
257                "old_table": quote_name(old_db_table),
258                "new_table": quote_name(new_db_table),
259            }
260        )
261
262    def add_field(self, model: type[Model], field: Field) -> None:
263        """
264        Create a field on a model. Usually involves adding a column, but may
265        involve adding a table instead (for M2M fields).
266        """
267        # Get the column's definition
268        definition, params = self.column_sql(
269            model,
270            field,
271            include_default=field.has_persistent_column_default(),
272        )
273        # It might not actually have a column behind it
274        if definition is None:
275            return
276        if col_type_suffix := field.db_type_suffix():
277            definition += f" {col_type_suffix}"
278        # FK constraints are handled by convergence, not inline during add_field.
279        # Build the SQL and run it
280        sql = self.sql_create_column % {
281            "table": quote_name(model.model_options.db_table),
282            "column": quote_name(field.column),
283            "definition": definition,
284        }
285        self.execute(sql, params)
286
287    def remove_field(self, model: type[Model], field: Field) -> None:
288        """
289        Remove a field from a model. Usually involves deleting a column,
290        but for M2Ms may involve deleting a table.
291        """
292        # It might not actually have a column behind it
293        if field.db_type() is None:
294            return
295        # FK constraints are dropped automatically by CASCADE on DROP COLUMN.
296        # Delete the column
297        sql = self.sql_delete_column % {
298            "table": quote_name(model.model_options.db_table),
299            "column": quote_name(field.column),
300        }
301        self.execute(sql)
302
303    def alter_field(
304        self,
305        model: type[Model],
306        old_field: Field,
307        new_field: Field,
308    ) -> None:
309        """
310        Allow a field's type, uniqueness, nullability, default, column,
311        constraints, etc. to be modified.
312        `old_field` is required to compute the necessary changes.
313        """
314        if not self._field_should_be_altered(old_field, new_field):
315            return
316        # Ensure this field is even column-based
317        old_type = old_field.db_type()
318        new_type = new_field.db_type()
319        if (old_type is None and not isinstance(old_field, RelatedField)) or (
320            new_type is None and not isinstance(new_field, RelatedField)
321        ):
322            raise ValueError(
323                f"Cannot alter field {old_field} into {new_field} - they do not properly define "
324                "db_type (are you using a badly-written custom field?)",
325            )
326        elif (
327            old_type is None
328            and new_type is None
329            and isinstance(old_field, RelatedField)
330            and isinstance(old_field.remote_field, ManyToManyRel)
331            and isinstance(new_field, RelatedField)
332            and isinstance(new_field.remote_field, ManyToManyRel)
333        ):
334            # Both sides have through models; this is a no-op.
335            return
336        elif old_type is None or new_type is None:
337            raise ValueError(
338                f"Cannot alter field {old_field} into {new_field} - they are not compatible types "
339                "(you cannot alter to or from M2M fields, or add or remove "
340                "through= on M2M fields)"
341            )
342
343        assert isinstance(old_field, ColumnField)
344        assert isinstance(new_field, ColumnField)
345        self._alter_field(
346            model,
347            old_field,
348            new_field,
349            old_type,
350            new_type,
351        )
352
353    def _alter_field(
354        self,
355        model: type[Model],
356        old_field: ColumnField,
357        new_field: ColumnField,
358        old_type: str,
359        new_type: str,
360    ) -> None:
361        """Column rename + column type change.
362
363        Also drops the existing expression DEFAULT when the column type is
364        changing (Postgres rejects the cast otherwise). Nullability and
365        column DEFAULT reconciliation are convergence-managed — see
366        ``plain.postgres.convergence``.
367        """
368        if old_field.column != new_field.column:
369            self.execute(
370                self._rename_field_sql(
371                    model.model_options.db_table, old_field, new_field
372                )
373            )
374        # Postgres rejects ALTER COLUMN TYPE when the existing expression DEFAULT
375        # can't cast to the new type. Drop it first; convergence re-applies it.
376        if old_field.db_returning and old_type != new_type:
377            self.execute(
378                self.sql_alter_column
379                % {
380                    "table": quote_name(model.model_options.db_table),
381                    "changes": self.sql_alter_column_no_default
382                    % {"column": quote_name(new_field.column)},
383                }
384            )
385        if (
386            old_type != new_type
387            or old_field.db_type_suffix() != new_field.db_type_suffix()
388        ):
389            type_sql, type_params = self._alter_column_type_sql(
390                old_field, new_field, new_type
391            )
392            self.execute(
393                self.sql_alter_column
394                % {
395                    "table": quote_name(model.model_options.db_table),
396                    "changes": type_sql,
397                },
398                type_params,
399            )
400
401    def _alter_column_type_sql(
402        self,
403        old_field: Field,
404        new_field: Field,
405        new_type: str,
406    ) -> tuple[str, list[Any]]:
407        """Return an ``(sql, params)`` ALTER COLUMN TYPE fragment."""
408        sql = "ALTER COLUMN %(column)s TYPE %(type)s"
409        # USING cast when the base data type changed (e.g. varchar → int),
410        # not just a parameter like max_length.
411        if old_field.unqualified_db_type() != new_field.unqualified_db_type():
412            sql += " USING %(column)s::%(type)s"
413        return (
414            sql
415            % {
416                "column": quote_name(new_field.column),
417                "type": new_type,
418            },
419            [],
420        )
421
422    def _field_should_be_altered(
423        self, old_field: Field, new_field: Field, ignore: set[str] | None = None
424    ) -> bool:
425        ignore = ignore or set()
426        _, old_path, old_args, old_kwargs = old_field.deconstruct()
427        _, new_path, new_args, new_kwargs = new_field.deconstruct()
428        # Don't alter when:
429        # - changing only a field name
430        # - changing an attribute that doesn't affect the schema
431        # - changing an attribute in the provided set of ignored attributes
432        for attr in ignore.union(old_field.non_migration_attrs):
433            old_kwargs.pop(attr, None)
434        for attr in ignore.union(new_field.non_migration_attrs):
435            new_kwargs.pop(attr, None)
436        return quote_name(old_field.column) != quote_name(new_field.column) or (
437            old_path,
438            old_args,
439            old_kwargs,
440        ) != (new_path, new_args, new_kwargs)
441
442    def _rename_field_sql(self, table: str, old_field: Field, new_field: Field) -> str:
443        return self.sql_rename_column % {
444            "table": quote_name(table),
445            "old_column": quote_name(old_field.column),
446            "new_column": quote_name(new_field.column),
447        }