1from __future__ import annotations
2
3from collections.abc import Callable
4from contextlib import nullcontext
5from typing import TYPE_CHECKING, Any
6
7from ..transaction import atomic
8from .loader import MigrationLoader
9from .migration import Migration
10from .recorder import MigrationRecorder
11from .state import ProjectState
12
13if TYPE_CHECKING:
14 from plain.postgres.connection import DatabaseConnection
15
16
class MigrationExecutor:
    """
    Run migrations forwards: load them, work out which ones still have to be
    applied to reach a set of targets, apply those, and record the result.
    """

    def __init__(
        self,
        connection: DatabaseConnection,
        progress_callback: Callable[..., Any] | None = None,
    ) -> None:
        """
        Args:
            connection: Connection shared by the loader, the recorder and the
                schema editors opened while applying migrations.
            progress_callback: Optional callable invoked as
                ``progress_callback(event, migration=..., fake=...)`` with the
                ``"apply_start"`` and ``"apply_success"`` events; it is also
                passed to ``Migration.apply`` as ``operation_callback``.
        """
        self.connection = connection
        self.loader = MigrationLoader(connection)
        self.recorder = MigrationRecorder(connection)
        self.progress_callback = progress_callback

    def migration_plan(
        self, targets: list[tuple[str, str]], clean_start: bool = False
    ) -> list[Migration]:
        """
        Return the ordered Migration objects needed to reach ``targets``.

        The graph's forwards plan of each target is walked in turn. A node is
        emitted the first time it is encountered, unless it is already applied.
        With ``clean_start=True`` nothing counts as applied, as if the database
        were empty.
        """
        seen = {} if clean_start else dict(self.loader.applied_migrations or {})
        graph = self.loader.graph
        ordered: list[Migration] = []
        for target in targets:
            for key in graph.forwards_plan(target):
                if key in seen:
                    continue
                node = graph.nodes[key]
                ordered.append(node)
                seen[key] = node
        return ordered

    def _create_project_state(
        self, with_applied_migrations: bool = False
    ) -> ProjectState:
        """
        Return a ProjectState seeded with the loader's unmigrated packages.

        When ``with_applied_migrations`` is true, every applied migration that
        exists in the graph is also replayed onto the state with
        ``mutate_state(..., preserve=False)``. The replay follows the order of
        the forwards plan from an empty database.
        """
        state = ProjectState(real_packages=self.loader.unmigrated_packages)
        if not with_applied_migrations:
            return state

        # Replay order: the plan that would be followed on an empty database.
        full_plan = self.migration_plan(
            self.loader.graph.leaf_nodes(), clean_start=True
        )
        applied_keys = self.loader.applied_migrations or {}
        nodes = self.loader.graph.nodes
        already_applied = {nodes[key] for key in applied_keys if key in nodes}
        for migration in (m for m in full_plan if m in already_applied):
            migration.mutate_state(state, preserve=False)
        return state

    def migrate(
        self,
        targets: list[tuple[str, str]],
        plan: list[Migration] | None = None,
        state: ProjectState | None = None,
        fake: bool = False,
        atomic_batch: bool = False,
    ) -> ProjectState:
        """
        Apply migrations forwards until ``targets`` are reached.

        Planned migrations are applied one at a time. They run in the order of
        the full forwards plan, and each one mutates the running ProjectState.
        Afterwards, ``check_replacements`` is always run.

        Args:
            targets: Graph node keys to migrate up to.
            plan: Precomputed migrations to apply. When None, it is computed
                from ``targets``. An explicitly empty list with no recorder
                table returns a bare project state without creating the table.
            state: Starting ProjectState. When None, it is built from the
                unmigrated packages plus all already-applied migrations.
            fake: Record migrations as applied without running their
                operations.
            atomic_batch: Wrap the whole run in one ``atomic()`` block. This
                only takes effect when more than one migration is planned.

        Returns:
            The ProjectState after all planned migrations have been applied.
        """
        # Recording needs the plain_migrations table, but an explicitly empty
        # plan must not create it as a side effect.
        if plan != []:
            self.recorder.ensure_schema()
        elif not self.recorder.has_table():
            return self._create_project_state(with_applied_migrations=False)

        if plan is None:
            plan = self.migration_plan(targets)
        # The order in which planned migrations are applied: the forwards plan
        # that would be followed on an empty database.
        full_plan = self.migration_plan(
            self.loader.graph.leaf_nodes(), clean_start=True
        )

        if state is None:
            # Whether or not anything is pending, the resulting state has to
            # include the migrations that are already applied.
            state = self._create_project_state(with_applied_migrations=True)

        if plan:
            pending = set(plan)
            batch = atomic() if atomic_batch and len(plan) > 1 else nullcontext()
            with batch:
                for migration in full_plan:
                    if not pending:
                        # Every requested migration has run, so there is no
                        # need to walk the rest of the graph.
                        break
                    if migration not in pending:
                        continue
                    if "models_registry" not in state.__dict__:
                        # Render all models once, before the first apply.
                        # This is performance critical. (models_registry is
                        # presumably a cached attribute -- hence the __dict__
                        # check.)
                        state.models_registry
                    state = self.apply_migration(state, migration, fake=fake)
                    pending.discard(migration)

        self.check_replacements()

        assert state is not None
        return state

    def apply_migration(
        self, state: ProjectState, migration: Migration, fake: bool = False
    ) -> ProjectState:
        """
        Apply a single migration forwards and record it as applied.

        With ``fake``, the operations are skipped and only the record is
        written. Otherwise the operations run inside a schema editor, whose
        atomicity follows ``migration.atomic``.

        Returns:
            The state produced by ``migration.apply``, or ``state`` unchanged
            when faking.
        """
        if self.progress_callback:
            self.progress_callback("apply_start", migration=migration, fake=fake)

        if fake:
            self.record_migration(migration)
        else:
            with self.connection.schema_editor(atomic=migration.atomic) as editor:
                state = migration.apply(
                    state, editor, operation_callback=self.progress_callback
                )
                # The record is written before the editor context exits.
                # NOTE(review): if the editor executes deferred SQL on exit, a
                # failure there could leave this migration recorded -- confirm
                # against the schema editor implementation.
                self.record_migration(migration)

        if self.progress_callback:
            self.progress_callback("apply_success", migration=migration, fake=fake)
        return state

    def record_migration(self, migration: Migration) -> None:
        """
        Mark ``migration`` as applied in the recorder.

        A replacing (squashed) migration is recorded as each of the migrations
        it replaces. ``check_replacements`` is what records the replacing
        migration itself, once all of its replaced migrations are applied.
        """
        keys = migration.replaces or [(migration.package_label, migration.name)]
        for package_label, name in keys:
            self.recorder.record_applied(package_label, name)

    def check_replacements(self) -> None:
        """
        Record each replacement migration as applied once every migration it
        replaces is applied.

        This runs on every ``migrate`` call, not only when something was
        applied. That covers deploying a new squash migration to a database
        that already has all of its replaced migrations: nothing new runs,
        yet the squash migration still has to be marked applied.
        """
        applied = self.recorder.applied_migrations()
        for key, replacement in self.loader.replacements.items():
            if key in applied:
                continue
            if all(replaced in applied for replaced in replacement.replaces):
                self.recorder.record_applied(*key)