1from __future__ import annotations
2
3import inspect
4from collections import defaultdict
5from collections.abc import Callable
6from typing import Any
7
8from plain.packages import packages_registry
9from plain.postgres.constraints import UniqueConstraint
10from plain.postgres.db import get_connection
11from plain.postgres.expressions import F, OrderBy
12from plain.postgres.fields.related import ForeignKeyField
13from plain.postgres.registry import ModelsRegistry, models_registry
14from plain.preflight import PreflightCheck, PreflightResult, register_check
15from plain.runtime import settings
16
17
def _get_app_models() -> list[Any]:
    """Return models from the user's app packages only (not framework/third-party).

    A package counts as a user app package when its name starts with
    ``"app."``; models are gathered package by package, in the order
    ``packages_registry`` yields the package configs.
    """
    return [
        model
        for package_config in packages_registry.get_package_configs()
        if package_config.name.startswith("app.")
        for model in models_registry.get_models(
            package_label=package_config.package_label
        )
    ]
27
28
def _collect_model_indexes(model: Any) -> list[tuple[str, list[str], bool]]:
    """Describe the model's field-based, non-partial indexes and unique constraints.

    Returns ``(name, field_names, is_unique)`` triples. ``Index`` entries
    come first (``is_unique=False``, leading ``-`` stripped from each field
    name), followed by ``UniqueConstraint`` entries (``is_unique=True``).
    Entries without ``fields`` (expression-only) and partial entries are
    left out; partials are skipped for the same reason as in
    ``_fk_covered_field_names``.
    """
    index_entries = [
        (index.name, [field_name.lstrip("-") for field_name in index.fields], False)
        for index in model.model_options.indexes
        if index.fields and not index.is_partial
    ]
    unique_entries = [
        (constraint.name, list(constraint.fields), True)
        for constraint in model.model_options.constraints
        if isinstance(constraint, UniqueConstraint)
        and constraint.fields
        and not constraint.is_partial
    ]
    return index_entries + unique_entries
50
51
def _bare_column_name(expr: Any) -> str | None:
    """Return the column name when `expr` is a plain column reference, else `None`.

    An index's leading column can only serve `WHERE col = ?` when it is a
    real column rather than a computed expression, so something like
    `Lower("email")` yields `None`. Sort direction is irrelevant to equality
    lookups, so a single `OrderBy` wrapper (as produced by `F("col").desc()`)
    is looked through before testing for a bare `F`.
    """
    target = expr.expression if isinstance(expr, OrderBy) else expr
    return target.name if isinstance(target, F) else None
66
67
def _fk_covered_field_names(model: Any) -> set[str]:
    """Return the field names that lead some non-partial index or unique constraint.

    A foreign key whose field is the first column of such an index can have
    its lookups served by that index. For field-based entries the first
    field name (leading ``-`` stripped) is used; for expression-based
    entries the first expression counts only if it is a bare
    ``F(field_name)`` (optionally wrapped in one ``OrderBy``).

    Partial indexes/constraints (declared with ``condition=Q(...)``) are
    ignored: Postgres can only use them for queries whose predicate implies
    the partial-index predicate, so an FK lookup or cascade delete that
    doesn't filter on that condition still sequentially scans. The narrow
    ``WHERE fk IS NOT NULL`` case, which Postgres can match to
    ``WHERE fk = ?``, is conservatively treated as not covering either;
    guaranteed coverage needs a regular non-partial ``Index(fields=[...])``.
    """
    options = model.model_options

    # (fields, expressions) pairs for every eligible index and constraint,
    # indexes first, in declaration order.
    candidates = [
        (index.fields, index.expressions)
        for index in options.indexes
        if not index.is_partial
    ]
    candidates.extend(
        (constraint.fields, constraint.expressions)
        for constraint in options.constraints
        if isinstance(constraint, UniqueConstraint) and not constraint.is_partial
    )

    covered: set[str] = set()
    for fields, expressions in candidates:
        if fields:
            covered.add(fields[0].lstrip("-"))
            continue
        if not expressions:
            continue
        leading_column = _bare_column_name(expressions[0])
        if leading_column is not None:
            covered.add(leading_column)
    return covered
105
106
@register_check("postgres.all_models")
class CheckAllModels(PreflightCheck):
    """Validates all model definitions for common issues.

    For every registered model this collects the results of the model's own
    ``preflight()`` classmethod (or flags that it has been overridden by a
    non-method). It then reports ``db_table`` values shared by more than one
    model and index/constraint names that occur more than once.
    """

    def run(self) -> list[PreflightResult]:
        results = []
        labels_by_table = defaultdict(list)
        # Postgres keeps indexes and constraints in one shared namespace,
        # so their names are pooled to catch index-vs-constraint clashes too.
        labels_by_relation = defaultdict(list)

        for model in models_registry.get_models():
            options = model.model_options
            label = options.label
            labels_by_table[options.db_table].append(label)

            if inspect.ismethod(model.preflight):
                results.extend(model.preflight())
            else:
                results.append(
                    PreflightResult(
                        fix=f"The '{model.__name__}.preflight()' class method is currently overridden by {model.preflight!r}.",
                        obj=model,
                        id="postgres.preflight_method_overridden",
                    )
                )

            for relation in (*options.indexes, *options.constraints):
                labels_by_relation[relation.name].append(label)

        for db_table, labels in labels_by_table.items():
            if len(labels) == 1:
                continue
            results.append(
                PreflightResult(
                    fix=f"db_table '{db_table}' is used by multiple models: {', '.join(labels)}.",
                    obj=db_table,
                    id="postgres.duplicate_db_table",
                )
            )

        for relation_name, labels in labels_by_relation.items():
            if len(labels) < 2:
                continue
            distinct_labels = sorted(set(labels))
            if len(distinct_labels) == 1:
                scope = "for model"
                check_id = "postgres.relation_name_not_unique_single"
            else:
                scope = "among models:"
                check_id = "postgres.relation_name_not_unique_multiple"
            results.append(
                PreflightResult(
                    fix="index/constraint name '{}' is not unique {} {}.".format(
                        relation_name,
                        scope,
                        ", ".join(distinct_labels),
                    ),
                    id=check_id,
                ),
            )

        return results
163
164
def _check_lazy_references(
    models_registry: ModelsRegistry, packages_registry: Any
) -> list[PreflightResult]:
    """
    Report lazy (string) model references that were never resolved.

    Every key still present in ``models_registry._pending_operations`` is a
    ``(package_label, model_name)`` pair with callbacks waiting on it. Each
    callback is unwrapped to the function that originally registered it.
    References made by related fields (``resolve_related_class``) get a
    field-specific message; anything else gets a generic one.

    Returns the ``PreflightResult`` objects sorted by their ``fix`` text, or
    an empty list when nothing is pending.
    """
    pending = models_registry._pending_operations
    unresolved_keys = set(pending)

    # Nothing is waiting on a missing model, so there is nothing to report.
    if not unresolved_keys:
        return []

    def unwrap(
        callback: Any,
    ) -> tuple[Callable[..., Any], list[Any], dict[str, Any]]:
        """
        Strip ``func``-linked wrapper layers from a pending callback.

        Each layer exposing ``.func`` (a partial, or a callback annotated with
        a ``func`` attribute to imitate one) contributes its ``args`` and
        ``keywords``; the innermost callable is returned together with
        everything collected along the way.
        """
        gathered_args: list[Any] = []
        gathered_keywords: dict[str, Any] = {}
        target = callback
        while hasattr(target, "func"):
            gathered_args.extend(getattr(target, "args", []))
            gathered_keywords.update(getattr(target, "keywords", {}))
            target = target.func
        return target, gathered_args, gathered_keywords

    def explain_missing(model_key: tuple[str, str]) -> str:
        """Say whether the package is missing or merely lacks the model."""
        try:
            packages_registry.get_package_config(model_key[0])
            return "app '{}' doesn't provide model '{}'".format(*model_key)
        except LookupError:
            return f"app '{model_key[0]}' isn't installed"

    # Error builders share one signature: the awaited (package_label,
    # modelname) pair, the unwrapped lazy function, and its collected
    # positional and keyword arguments.

    def lazy_field_error(
        model_key: tuple[str, str],
        func: Callable[..., Any],
        args: list[Any],
        keywords: dict[str, Any],
    ) -> PreflightResult:
        """Error for a related field declared against a missing model."""
        params = {
            "model": ".".join(model_key),
            "field": keywords["field"],
            "model_error": explain_missing(model_key),
        }
        template = (
            "The field %(field)s was declared with a lazy reference "
            "to '%(model)s', but %(model_error)s."
        )
        return PreflightResult(
            fix=template % params,
            obj=params["field"],
            id="fields.lazy_reference_not_resolvable",
        )

    def generic_error(
        model_key: tuple[str, str],
        func: Callable[..., Any],
        args: list[Any],
        keywords: dict[str, Any],
    ) -> PreflightResult:
        """Fallback error for any other lazy operation."""
        params = {
            "op": func,
            "model": ".".join(model_key),
            "model_error": explain_missing(model_key),
        }
        template = (
            "%(op)s contains a lazy reference to %(model)s, but %(model_error)s."
        )
        return PreflightResult(
            fix=template % params,
            obj=func,
            id="postgres.lazy_reference_resolution_failed",
        )

    # (module, function name) of the unwrapped lazy callable -> error builder.
    # A builder of None would suppress the error; callables not listed here
    # fall back to generic_error().
    error_builders = {
        ("plain.postgres.fields.related", "resolve_related_class"): lazy_field_error,
    }

    def to_result(
        model_key: tuple[str, str],
        func: Callable[..., Any],
        args: list[Any],
        keywords: dict[str, Any],
    ) -> PreflightResult | None:
        lookup = (func.__module__, func.__name__)  # ty: ignore[unresolved-attribute]
        builder = error_builders.get(lookup, generic_error)
        if not builder:
            return None
        return builder(model_key, func, args, keywords)

    results: list[PreflightResult] = []
    for model_key in unresolved_keys:
        for callback in pending[model_key]:
            result = to_result(model_key, *unwrap(callback))
            if result:
                results.append(result)
    results.sort(key=lambda result: result.fix)
    return results
283
284
@register_check("postgres.lazy_references")
class CheckLazyReferences(PreflightCheck):
    """Ensures all lazy (string) model references have been resolved.

    Delegates to ``_check_lazy_references`` with the module-level models and
    packages registries.
    """

    def run(self) -> list[PreflightResult]:
        return _check_lazy_references(
            models_registry=models_registry,
            packages_registry=packages_registry,
        )
291
292
@register_check("postgres.middleware_installed")
class CheckMiddlewareInstalled(PreflightCheck):
    """Errors if `DatabaseConnectionMiddleware` isn't in `MIDDLEWARE`.

    Without the middleware, pooled connections are only released by garbage
    collection at the end of each request. Relying on refcount timing under
    load invites pool exhaustion when cyclic references or delayed
    finalization hold connections open.
    """

    # Dotted path that must appear in settings.MIDDLEWARE.
    REQUIRED = "plain.postgres.DatabaseConnectionMiddleware"

    def run(self) -> list[PreflightResult]:
        installed = self.REQUIRED in settings.MIDDLEWARE
        if installed:
            return []

        message = (
            f"Add '{self.REQUIRED}' to MIDDLEWARE so pooled "
            "database connections are returned at the end of each "
            "request. Place it first so its after_response runs "
            "after any middleware that queries the database."
        )
        return [PreflightResult(fix=message, id="postgres.middleware_not_installed")]
318
319
@register_check("postgres.postgres_version")
class CheckPostgresVersion(PreflightCheck):
    """Checks that the PostgreSQL server meets the minimum version requirement."""

    MINIMUM_VERSION = 16

    @staticmethod
    def _format_server_version(server_version: int) -> str:
        """Render the integer server version as a human-readable string.

        PostgreSQL 10+ encodes the version as ``major * 10000 + minor``
        (``160002`` -> ``"16.2"``). Older servers used
        ``major * 10000 + minor * 100 + patch`` (``90605`` -> ``"9.6.5"``).
        A plain ``divmod(version, 10000)`` would print the latter as
        ``"9.605"``.
        """
        if server_version >= 100000:
            major, minor = divmod(server_version, 10000)
            return f"{major}.{minor}"
        major, remainder = divmod(server_version, 10000)
        minor, patch = divmod(remainder, 100)
        return f"{major}.{minor}.{patch}"

    def run(self) -> list[PreflightResult]:
        conn = get_connection()
        conn.ensure_connection()
        # Narrows the optional connection for the type checker after
        # ensure_connection().
        assert conn.connection is not None
        server_version = conn.connection.info.server_version
        # The leading ``version // 10000`` component is comparable across
        # both numbering schemes (9.6.x -> 9, 16.x -> 16).
        if server_version // 10000 < self.MINIMUM_VERSION:
            found = self._format_server_version(server_version)
            return [
                PreflightResult(
                    fix=f"PostgreSQL {self.MINIMUM_VERSION} or later is required (found {found}).",
                    id="postgres.postgres_version_too_old",
                )
            ]
        return []
339
340
@register_check("postgres.database_tables")
class CheckDatabaseTables(PreflightCheck):
    """Checks for unknown tables in the database when plain.postgres is available.

    Produces a single warning listing every table returned by
    ``get_unknown_tables()``, or nothing when that result is empty.
    """

    def run(self) -> list[PreflightResult]:
        from .introspection import get_unknown_tables

        unknown_tables = get_unknown_tables()
        if unknown_tables:
            joined_names = ", ".join(unknown_tables)
            return [
                PreflightResult(
                    fix=f"Unknown tables in default database: {joined_names}. "
                    "Tables may be from packages/models that have been uninstalled. "
                    "Make sure you have a backup, then run `plain postgres drop-unknown-tables` to remove them.",
                    id="postgres.unknown_database_tables",
                    warning=True,
                )
            ]
        return []
363
364
@register_check("postgres.prunable_migrations")
class CheckPrunableMigrations(PreflightCheck):
    """Warns about stale migration records in the database.

    A record is stale when the recorder lists it as applied but the loader
    found no matching migration on disk. Stale records are split by whether
    their package still has migrations loaded ("existing") or not
    ("removed"), and reported together in one warning.
    """

    @staticmethod
    def _preview(migrations: list[tuple[str, str]]) -> str:
        """Join at most three ``package.name`` labels and note how many were cut."""
        shown = ", ".join(f"{package}.{name}" for package, name in migrations[:3])
        hidden_count = len(migrations) - 3
        if hidden_count > 0:
            shown += f" (and {hidden_count} more)"
        return shown

    def run(self) -> list[PreflightResult]:
        # Imported here to avoid circular imports.
        from plain.postgres.migrations.loader import MigrationLoader
        from plain.postgres.migrations.recorder import MigrationRecorder

        conn = get_connection()
        loader = MigrationLoader(conn, ignore_no_migrations=True)
        applied = MigrationRecorder(conn).applied_migrations()

        on_disk = loader.disk_migrations
        # Populated by MigrationLoader's initialization; the None guard exists
        # to satisfy the type checker.
        if on_disk is None:
            return []

        stale = [migration for migration in applied if migration not in on_disk]
        if not stale:
            return []

        known_packages = set(loader.migrated_packages)
        from_existing: list[tuple[str, str]] = []
        from_removed: list[tuple[str, str]] = []
        for migration in stale:
            package, _ = migration
            if package in known_packages:
                from_existing.append(migration)
            else:
                from_removed.append(migration)

        stale_count = len(stale)
        plural = "s" if stale_count != 1 else ""
        parts = [f"Found {stale_count} stale migration record{plural} in the database."]
        if from_existing:
            parts.append(f"From existing packages: {self._preview(from_existing)}.")
        if from_removed:
            parts.append(f"From removed packages: {self._preview(from_removed)}.")
        parts.append("Run 'plain migrations prune' to review and remove them.")

        return [
            PreflightResult(
                fix=" ".join(parts),
                id="postgres.prunable_migrations",
                warning=True,
            )
        ]
442
443
@register_check("postgres.missing_fk_indexes")
class CheckMissingFKIndexes(PreflightCheck):
    """Warns about foreign key fields without index coverage.

    Only models returned by ``_get_app_models()`` are inspected. A non-primary-key
    ``ForeignKeyField`` is reported when its name is not among the leading
    columns computed by ``_fk_covered_field_names``.
    """

    def run(self) -> list[PreflightResult]:
        warnings: list[PreflightResult] = []
        for model in _get_app_models():
            covered = _fk_covered_field_names(model)
            unindexed = [
                field
                for field in model._model_meta.local_fields
                if isinstance(field, ForeignKeyField)
                and not field.primary_key
                and field.name not in covered
            ]
            warnings.extend(
                PreflightResult(
                    fix=f"Foreign key '{field.name}' has no index coverage. "
                    f"Add an Index on [\"{field.name}\"] or a constraint with '{field.name}' as the first field.",
                    obj=f"{model.model_options.label}.{field.name}",
                    id="postgres.missing_fk_index",
                    warning=True,
                )
                for field in unindexed
            )
        return warnings
471
472
@register_check("postgres.duplicate_indexes")
class CheckDuplicateIndexes(PreflightCheck):
    """Warns about indexes redundant with other indexes or constraints.

    Catches both prefix-redundancy (a 1-column index shadowed by a wider
    composite) and exact-column duplicates (an `Index(fields=["x"])` that
    duplicates a same-column `UniqueConstraint`).
    """

    @staticmethod
    def _redundancy_fix(
        candidate: tuple[str, list[str], bool],
        reference: tuple[str, list[str], bool],
    ) -> str | None:
        """Return the warning text if `candidate` is made redundant by `reference`.

        Unique-backed candidates are never reported. A non-unique candidate
        is redundant when its columns are a strict prefix of the reference's
        columns, or when both have identical columns and the reference is
        unique. Two identical non-unique indexes are reported once: the one
        whose name sorts later is flagged.
        """
        name, fields, is_unique = candidate
        other_name, other_fields, other_is_unique = reference
        if is_unique:
            return None

        columns = ", ".join(fields)
        is_strict_prefix = (
            len(fields) < len(other_fields)
            and other_fields[: len(fields)] == fields
        )
        if is_strict_prefix:
            return (
                f"Index '{name}' on [{columns}] "
                f"is redundant with '{other_name}' on [{', '.join(other_fields)}]. "
                f"The longer index covers the same queries."
            )
        if fields != other_fields:
            return None
        if other_is_unique:
            return (
                f"Index '{name}' on [{columns}] "
                f"is redundant with '{other_name}' on the same columns. "
                f"The unique-backed index already covers these queries."
            )
        if name > other_name:
            return (
                f"Index '{name}' on [{columns}] "
                f"is an exact duplicate of '{other_name}'. "
                f"Drop one of them."
            )
        return None

    def run(self) -> list[PreflightResult]:
        results = []

        for model in _get_app_models():
            entries = _collect_model_indexes(model)

            # Names already reported for this model; each is reported once.
            flagged: set[str] = set()
            for position, first in enumerate(entries):
                for second in entries[position + 1 :]:
                    # Test the pair in both directions.
                    for candidate, reference in ((first, second), (second, first)):
                        candidate_name = candidate[0]
                        if candidate_name in flagged:
                            continue
                        fix = self._redundancy_fix(candidate, reference)
                        if fix is None:
                            continue
                        results.append(
                            PreflightResult(
                                fix=fix,
                                obj=model.model_options.label,
                                id="postgres.duplicate_index",
                                warning=True,
                            )
                        )
                        flagged.add(candidate_name)

        return results