v0.145.1
  1from __future__ import annotations
  2
  3import inspect
  4from collections import defaultdict
  5from collections.abc import Callable
  6from typing import Any
  7
  8from plain.packages import packages_registry
  9from plain.postgres.constraints import UniqueConstraint
 10from plain.postgres.db import get_connection
 11from plain.postgres.expressions import F, OrderBy
 12from plain.postgres.fields.related import ForeignKeyField
 13from plain.postgres.registry import ModelsRegistry, models_registry
 14from plain.preflight import PreflightCheck, PreflightResult, register_check
 15from plain.runtime import settings
 16
 17
 18def _get_app_models() -> list[Any]:
 19    """Return models from the user's app packages only (not framework/third-party)."""
 20    app_models = []
 21    for package_config in packages_registry.get_package_configs():
 22        if package_config.name.startswith("app."):
 23            app_models.extend(
 24                models_registry.get_models(package_label=package_config.package_label)
 25            )
 26    return app_models
 27
 28
 29def _collect_model_indexes(model: Any) -> list[tuple[str, list[str], bool]]:
 30    """Collect (name, fields, is_unique) for non-partial indexes/constraints.
 31
 32    Partials are skipped for the same reason as in ``_fk_covered_field_names``.
 33    """
 34    all_indexes: list[tuple[str, list[str], bool]] = []
 35
 36    for index in model.model_options.indexes:
 37        if index.fields and not index.is_partial:
 38            fields = [f.lstrip("-") for f in index.fields]
 39            all_indexes.append((index.name, fields, False))
 40
 41    for constraint in model.model_options.constraints:
 42        if (
 43            isinstance(constraint, UniqueConstraint)
 44            and constraint.fields
 45            and not constraint.is_partial
 46        ):
 47            all_indexes.append((constraint.name, list(constraint.fields), True))
 48
 49    return all_indexes
 50
 51
 52def _bare_column_name(expr: Any) -> str | None:
 53    """Return the column name if `expr` resolves to a bare column, else `None`.
 54
 55    Postgres can range-scan the leading column of an index for `WHERE col = ?`
 56    only when that column is a real attribute, not an expression — so a
 57    compound leading expression like `Lower("email")` returns `None` here.
 58    Sort direction (`F("col").desc()` / `OrderBy(F)`) doesn't affect equality
 59    lookups, so we unwrap one layer of `OrderBy` around a bare `F`.
 60    """
 61    if isinstance(expr, OrderBy):
 62        expr = expr.expression
 63    if isinstance(expr, F):
 64        return expr.name
 65    return None
 66
 67
 68def _fk_covered_field_names(model: Any) -> set[str]:
 69    """Field names that appear as the leading column of a non-partial index
 70    or unique constraint — covering arbitrary FK lookups via the index's
 71    leading column.
 72
 73    Partial indexes/constraints (declared with ``condition=Q(...)``) are
 74    excluded: Postgres can only use them for queries whose predicate
 75    implies the partial-index predicate, so an FK lookup or cascade
 76    delete that doesn't filter by that condition still does a sequential
 77    scan. (The narrow ``WHERE fk IS NOT NULL`` case — which Postgres can
 78    match to ``WHERE fk = ?`` — is conservatively treated as not
 79    covering; users wanting guaranteed FK coverage should add a regular
 80    non-partial ``Index(fields=[...])``.) Includes expression-based
 81    indexes/constraints whose leading expression is a bare
 82    ``F(field_name)``.
 83    """
 84    covered: set[str] = set()
 85
 86    def _record_leading(
 87        fields: tuple[str, ...] | list[str], expressions: tuple
 88    ) -> None:
 89        if fields:
 90            covered.add(fields[0].lstrip("-"))
 91        elif expressions:
 92            name = _bare_column_name(expressions[0])
 93            if name is not None:
 94                covered.add(name)
 95
 96    for index in model.model_options.indexes:
 97        if not index.is_partial:
 98            _record_leading(index.fields, index.expressions)
 99
100    for constraint in model.model_options.constraints:
101        if isinstance(constraint, UniqueConstraint) and not constraint.is_partial:
102            _record_leading(constraint.fields, constraint.expressions)
103
104    return covered
105
106
@register_check("postgres.all_models")
class CheckAllModels(PreflightCheck):
    """Validates all model definitions for common issues.

    Runs each model's own ``preflight()`` and reports models that share a
    ``db_table`` as well as index/constraint names that are not unique.
    """

    def run(self) -> list[PreflightResult]:
        results = []
        labels_by_table = defaultdict(list)
        # Postgres keeps indexes and constraints in one namespace, so names
        # from both are pooled here to catch cross-type collisions too.
        labels_by_relation = defaultdict(list)

        for model in models_registry.get_models():
            label = model.model_options.label
            labels_by_table[model.model_options.db_table].append(label)

            # A plain attribute/function assigned over the classmethod would
            # not be a bound method; report it instead of calling it.
            if inspect.ismethod(model.preflight):
                results.extend(model.preflight())
            else:
                results.append(
                    PreflightResult(
                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
                        obj=model,
                        id="postgres.preflight_method_overridden",
                    )
                )

            relations = [
                *model.model_options.indexes,
                *model.model_options.constraints,
            ]
            for relation in relations:
                labels_by_relation[relation.name].append(label)

        for db_table, labels in labels_by_table.items():
            if len(labels) == 1:
                continue
            results.append(
                PreflightResult(
                    fix=f"db_table '{db_table}' is used by multiple models: {', '.join(labels)}.",
                    obj=db_table,
                    id="postgres.duplicate_db_table",
                )
            )

        for relation_name, labels in labels_by_relation.items():
            if len(labels) <= 1:
                continue
            distinct_labels = sorted(set(labels))
            if len(distinct_labels) == 1:
                scope = "for model"
                check_id = "postgres.relation_name_not_unique_single"
            else:
                scope = "among models:"
                check_id = "postgres.relation_name_not_unique_multiple"
            results.append(
                PreflightResult(
                    fix=f"index/constraint name '{relation_name}' is not unique {scope} {', '.join(distinct_labels)}.",
                    id=check_id,
                )
            )

        return results
163
164
165def _check_lazy_references(
166    models_registry: ModelsRegistry, packages_registry: Any
167) -> list[PreflightResult]:
168    """
169    Ensure all lazy (i.e. string) model references have been resolved.
170
171    Lazy references are used in various places throughout Plain, primarily in
172    related fields and model signals. Identify those common cases and provide
173    more helpful error messages for them.
174    """
175    pending_models = set(models_registry._pending_operations)
176
177    # Short circuit if there aren't any errors.
178    if not pending_models:
179        return []
180
181    def extract_operation(
182        obj: Any,
183    ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
184        """
185        Take a callable found in Packages._pending_operations and identify the
186        original callable passed to Packages.lazy_model_operation(). If that
187        callable was a partial, return the inner, non-partial function and
188        any arguments and keyword arguments that were supplied with it.
189
190        obj is a callback defined locally in Packages.lazy_model_operation() and
191        annotated there with a `func` attribute so as to imitate a partial.
192        """
193        operation, args, keywords = obj, [], {}
194        while hasattr(operation, "func"):
195            args.extend(getattr(operation, "args", []))
196            keywords.update(getattr(operation, "keywords", {}))
197            operation = operation.func
198        return operation, args, keywords
199
200    def app_model_error(model_key: tuple[str, str]) -> str:
201        try:
202            packages_registry.get_package_config(model_key[0])
203            model_error = "app '{}' doesn't provide model '{}'".format(*model_key)
204        except LookupError:
205            model_error = f"app '{model_key[0]}' isn't installed"
206        return model_error
207
208    # Here are several functions which return CheckMessage instances for the
209    # most common usages of lazy operations throughout Plain. These functions
210    # take the model that was being waited on as an (package_label, modelname)
211    # pair, the original lazy function, and its positional and keyword args as
212    # determined by extract_operation().
213
214    def field_error(
215        model_key: tuple[str, str],
216        func: Callable[..., Any],
217        args: list[Any],
218        keywords: dict[str, Any],
219    ) -> PreflightResult:
220        error_msg = (
221            "The field %(field)s was declared with a lazy reference "
222            "to '%(model)s', but %(model_error)s."
223        )
224        params = {
225            "model": ".".join(model_key),
226            "field": keywords["field"],
227            "model_error": app_model_error(model_key),
228        }
229        return PreflightResult(
230            fix=error_msg % params,
231            obj=keywords["field"],
232            id="fields.lazy_reference_not_resolvable",
233        )
234
235    def default_error(
236        model_key: tuple[str, str],
237        func: Callable[..., Any],
238        args: list[Any],
239        keywords: dict[str, Any],
240    ) -> PreflightResult:
241        error_msg = (
242            "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
243        )
244        params = {
245            "op": func,
246            "model": ".".join(model_key),
247            "model_error": app_model_error(model_key),
248        }
249        return PreflightResult(
250            fix=error_msg % params,
251            obj=func,
252            id="postgres.lazy_reference_resolution_failed",
253        )
254
255    # Maps common uses of lazy operations to corresponding error functions
256    # defined above. If a key maps to None, no error will be produced.
257    # default_error() will be used for usages that don't appear in this dict.
258    known_lazy = {
259        ("plain.postgres.fields.related", "resolve_related_class"): field_error,
260    }
261
262    def build_error(
263        model_key: tuple[str, str],
264        func: Callable[..., Any],
265        args: list[Any],
266        keywords: dict[str, Any],
267    ) -> PreflightResult | None:
268        key = (func.__module__, func.__name__)  # ty: ignore[unresolved-attribute]
269        error_fn = known_lazy.get(key, default_error)
270        return error_fn(model_key, func, args, keywords) if error_fn else None
271
272    return sorted(
273        filter(
274            None,
275            (
276                build_error(model_key, *extract_operation(func))
277                for model_key in pending_models
278                for func in models_registry._pending_operations[model_key]
279            ),
280        ),
281        key=lambda error: error.fix,
282    )
283
284
@register_check("postgres.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """Ensures all lazy (string) model references have been resolved."""

    def run(self) -> list[PreflightResult]:
        # The real work lives in a module-level helper that takes the
        # registries explicitly; hand it the global ones.
        return _check_lazy_references(
            models_registry=models_registry,
            packages_registry=packages_registry,
        )
291
292
@register_check("postgres.middleware_installed")
class CheckMiddlewareInstalled(PreflightCheck):
    """Errors if `DatabaseConnectionMiddleware` isn't in `MIDDLEWARE`.

    Without it, pooled connections are only released by GC at the end of
    each request. Relying on refcount timing under load is a recipe for
    pool exhaustion under cyclic refs or delayed finalization.
    """

    REQUIRED = "plain.postgres.DatabaseConnectionMiddleware"

    def run(self) -> list[PreflightResult]:
        if self.REQUIRED not in settings.MIDDLEWARE:
            message = (
                f"Add '{self.REQUIRED}' to MIDDLEWARE so pooled "
                "database connections are returned at the end of each "
                "request. Place it first so its after_response runs "
                "after any middleware that queries the database."
            )
            return [
                PreflightResult(fix=message, id="postgres.middleware_not_installed")
            ]
        return []
318
319
@register_check("postgres.postgres_version")
class CheckPostgresVersion(PreflightCheck):
    """Checks that the PostgreSQL server meets the minimum version requirement."""

    MINIMUM_VERSION = 16

    def run(self) -> list[PreflightResult]:
        conn = get_connection()
        conn.ensure_connection()
        assert conn.connection is not None
        server_version = conn.connection.info.server_version
        # Integer-dividing by 10000 yields the first version component in
        # both the pre-10 and the 10+ encodings, which is all the
        # minimum-version comparison needs.
        major = server_version // 10000
        if major < self.MINIMUM_VERSION:
            found = self._format_server_version(server_version)
            return [
                PreflightResult(
                    fix=f"PostgreSQL {self.MINIMUM_VERSION} or later is required (found {found}).",
                    id="postgres.postgres_version_too_old",
                )
            ]
        return []

    @staticmethod
    def _format_server_version(server_version: int) -> str:
        """Render a libpq integer server version as a dotted version string.

        From PostgreSQL 10 onward the integer is ``major * 10000 + minor``
        (``160002`` -> ``"16.2"``). Before 10 it packs three two-digit parts
        as ``XXYYZZ`` (``90605`` -> ``"9.6.5"``). A plain ``divmod`` by
        10000 would misreport that as ``"9.605"``.
        """
        major, rest = divmod(server_version, 10000)
        if major >= 10:
            return f"{major}.{rest}"
        minor, patch = divmod(rest, 100)
        return f"{major}.{minor}.{patch}"
339
340
@register_check("postgres.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Warns about tables that `introspection.get_unknown_tables()` reports as unknown."""

    def run(self) -> list[PreflightResult]:
        # Imported inside run(), as the original does.
        from .introspection import get_unknown_tables

        unknown = get_unknown_tables()
        if unknown:
            listed = ", ".join(unknown)
            return [
                PreflightResult(
                    fix=(
                        f"Unknown tables in default database: {listed}. "
                        "Tables may be from packages/models that have been uninstalled. "
                        "Make sure you have a backup, then run `plain postgres drop-unknown-tables` to remove them."
                    ),
                    id="postgres.unknown_database_tables",
                    warning=True,
                )
            ]
        return []
363
364
@register_check("postgres.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warns about migrations recorded as applied that no longer exist on disk."""

    def run(self) -> list[PreflightResult]:
        # Imported lazily to avoid circular imports at module load.
        from plain.postgres.migrations.loader import MigrationLoader
        from plain.postgres.migrations.recorder import MigrationRecorder

        conn = get_connection()
        loader = MigrationLoader(conn, ignore_no_migrations=True)
        applied = MigrationRecorder(conn).applied_migrations()

        # Loader initialization should populate this; the guard satisfies
        # the type checker.
        on_disk = loader.disk_migrations
        if on_disk is None:
            return []

        stale = [key for key in applied if key not in on_disk]
        if not stale:
            return []

        # Split stale records by whether their package still has migrations.
        known_packages = set(loader.migrated_packages)
        from_existing: list[tuple[str, str]] = []
        from_removed: list[tuple[str, str]] = []
        for migration in stale:
            package_label, _migration_name = migration
            if package_label in known_packages:
                from_existing.append(migration)
            else:
                from_removed.append(migration)

        def summarize(migrations: list[tuple[str, str]]) -> str:
            # Show at most three "package.name" entries, then a count of the rest.
            shown = ", ".join(f"{pkg}.{name}" for pkg, name in migrations[:3])
            hidden = len(migrations) - 3
            if hidden > 0:
                shown += f" (and {hidden} more)"
            return shown

        count = len(stale)
        plural = "s" if count != 1 else ""
        parts = [f"Found {count} stale migration record{plural} in the database."]
        if from_existing:
            parts.append(f"From existing packages: {summarize(from_existing)}.")
        if from_removed:
            parts.append(f"From removed packages: {summarize(from_removed)}.")
        parts.append("Run 'plain migrations prune' to review and remove them.")

        return [
            PreflightResult(
                fix=" ".join(parts),
                id="postgres.prunable_migrations",
                warning=True,
            )
        ]
442
443
@register_check("postgres.missing_fk_indexes")
class CheckMissingFKIndexes(PreflightCheck):
    """Warns about foreign key fields in app models that no index leads with.

    Primary-key FKs are exempt. Coverage is computed by
    `_fk_covered_field_names()`.
    """

    def run(self) -> list[PreflightResult]:
        warnings = []

        for model in _get_app_models():
            covered = _fk_covered_field_names(model)
            uncovered = (
                field
                for field in model._model_meta.local_fields
                if isinstance(field, ForeignKeyField)
                and not field.primary_key
                and field.name not in covered
            )
            for field in uncovered:
                warnings.append(
                    PreflightResult(
                        fix=f"Foreign key '{field.name}' has no index coverage. "
                        f"Add an Index on [\"{field.name}\"] or a constraint with '{field.name}' as the first field.",
                        obj=f"{model.model_options.label}.{field.name}",
                        id="postgres.missing_fk_index",
                        warning=True,
                    )
                )

        return warnings
471
472
@register_check("postgres.duplicate_indexes")
class CheckDuplicateIndexes(PreflightCheck):
    """Warns about indexes redundant with other indexes or constraints.

    Catches both prefix-redundancy (a 1-column index shadowed by a wider
    composite) and exact-column duplicates (an `Index(fields=["x"])` that
    duplicates a same-column `UniqueConstraint`).
    """

    @staticmethod
    def _redundancy_fix(
        candidate: tuple[str, list[str], bool],
        other: tuple[str, list[str], bool],
    ) -> str | None:
        """Return a fix message if `candidate` is made redundant by `other`.

        Both arguments are ``(name, fields, is_unique)`` entries from
        `_collect_model_indexes()`. Returns ``None`` when `candidate` should
        be kept.
        """
        name, columns, unique = candidate
        other_name, other_columns, other_unique = other

        # Unique-backed entries are never reported as droppable.
        if unique:
            return None

        width = len(columns)
        if width < len(other_columns) and other_columns[:width] == columns:
            return (
                f"Index '{name}' on [{', '.join(columns)}] "
                f"is redundant with '{other_name}' on [{', '.join(other_columns)}]. "
                f"The longer index covers the same queries."
            )

        if columns != other_columns:
            return None

        if other_unique:
            return (
                f"Index '{name}' on [{', '.join(columns)}] "
                f"is redundant with '{other_name}' on the same columns. "
                f"The unique-backed index already covers these queries."
            )

        # Two plain indexes on identical columns: report only the one with
        # the greater name so the pair yields a single warning.
        if name > other_name:
            return (
                f"Index '{name}' on [{', '.join(columns)}] "
                f"is an exact duplicate of '{other_name}'. "
                f"Drop one of them."
            )
        return None

    def run(self) -> list[PreflightResult]:
        results = []

        for model in _get_app_models():
            entries = _collect_model_indexes(model)
            reported: set[str] = set()

            # Visit every unordered pair once, testing both directions.
            for position, first in enumerate(entries):
                for second in entries[position + 1 :]:
                    for candidate, other in ((first, second), (second, first)):
                        if candidate[0] in reported:
                            continue
                        fix = self._redundancy_fix(candidate, other)
                        if fix is None:
                            continue
                        results.append(
                            PreflightResult(
                                fix=fix,
                                obj=model.model_options.label,
                                id="postgres.duplicate_index",
                                warning=True,
                            )
                        )
                        reported.add(candidate[0])

        return results
541        return results