1from __future__ import annotations
2
3import base64
4import json
5from functools import cache
6from typing import TYPE_CHECKING, Any
7
8try:
9 from cryptography.fernet import Fernet, InvalidToken, MultiFernet
10 from cryptography.hazmat.primitives import hashes
11 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
12except ImportError:
13 Fernet = None # ty: ignore[invalid-assignment]
14 InvalidToken = None # ty: ignore[invalid-assignment]
15 MultiFernet = None # ty: ignore[invalid-assignment]
16 hashes = None # ty: ignore[invalid-assignment]
17 PBKDF2HMAC = None # ty: ignore[invalid-assignment]
18
19from plain import exceptions, preflight
20from plain.runtime import settings
21from plain.utils.encoding import force_bytes
22
23from .base import ColumnField
24
25if TYPE_CHECKING:
26 from collections.abc import Callable, Sequence
27
28 from plain.postgres.connection import DatabaseConnection
29 from plain.postgres.lookups import Lookup, Transform
30 from plain.preflight.results import PreflightResult
31
# Public names exported by this module: the two encrypted field classes.
__all__ = [
    "EncryptedTextField",
    "EncryptedJSONField",
]

# Fixed salt for key derivation — changing this would invalidate all encrypted data.
# This is not secret; it ensures the derived encryption key is distinct from
# keys derived for other purposes (e.g., signing) even from the same SECRET_KEY.
# It is passed as the PBKDF2 salt in _derive_fernet_key().
_KDF_SALT = b"plain.postgres.fields.encrypted"

# Prefix for encrypted values in the database.
# Makes encrypted data self-describing and distinguishable from plaintext:
# _encrypt() prepends it to every token it produces, and _decrypt() returns
# any value that lacks it unchanged.
_ENCRYPTED_PREFIX = "$fernet$"
45
46
def _derive_fernet_key(secret: str) -> bytes:
    """Turn an arbitrary secret string into a Fernet-compatible key.

    The secret is run through PBKDF2-HMAC-SHA256 with the module's fixed
    salt and 480,000 iterations (output length 32), and the result is
    URL-safe base64 encoded.

    Raises:
        ImportError: If the optional 'cryptography' package is not installed.
    """
    if PBKDF2HMAC is not None:
        stretcher = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=_KDF_SALT,
            iterations=480_000,
        )
        raw_key = stretcher.derive(force_bytes(secret))
        return base64.urlsafe_b64encode(raw_key)

    # The module-level import fell back to None placeholders.
    raise ImportError(
        "The 'cryptography' package is required to use encrypted fields. "
        "Install it with: pip install cryptography"
    )
61
62
@cache
def _get_fernet(secret_key: str, fallbacks: tuple[str, ...]) -> MultiFernet:
    """Return a MultiFernet built from the primary secret and its fallbacks.

    The key derived from ``secret_key`` comes first, so it is the one used
    for encryption; every key (primary and fallbacks) is available for
    decryption, which is what enables key rotation.

    The result is memoized on ``(secret_key, fallbacks)``, so a changed
    SECRET_KEY (e.g. in tests) yields a fresh MultiFernet automatically.
    ``fallbacks`` must be a tuple so the arguments are hashable.
    """
    # Derive every key up front, primary first, then wrap each in a Fernet.
    derived_keys = [
        _derive_fernet_key(secret) for secret in (secret_key, *fallbacks)
    ]
    return MultiFernet([Fernet(key) for key in derived_keys])
76
77
def _encrypt(value: str) -> str:
    """Encrypt *value* and return it in the prefixed database format.

    The empty string is returned as-is (it is never encrypted). Any other
    value is encrypted with the MultiFernet for the current SECRET_KEY /
    SECRET_KEY_FALLBACKS settings and prefixed with ``_ENCRYPTED_PREFIX``.
    """
    if value != "":
        fernet = _get_fernet(
            settings.SECRET_KEY, tuple(settings.SECRET_KEY_FALLBACKS)
        )
        ciphertext = fernet.encrypt(force_bytes(value)).decode("ascii")
        return f"{_ENCRYPTED_PREFIX}{ciphertext}"
    return value
85
86
def _decrypt(value: str) -> str:
    """Decrypt a self-describing database value back to a string.

    Gracefully handles unencrypted values — if the value doesn't have
    the encryption prefix, it's returned as-is. This supports gradual
    migration from plaintext to encrypted fields.

    Args:
        value: The raw column value, either ``_ENCRYPTED_PREFIX`` followed
            by a Fernet token, or legacy plaintext.

    Returns:
        The decrypted plaintext, or ``value`` unchanged when it carries no
        encryption prefix.

    Raises:
        ValueError: If the token cannot be decrypted with the key derived
            from SECRET_KEY or any of SECRET_KEY_FALLBACKS. The underlying
            ``InvalidToken`` is attached as ``__cause__``.
    """
    if not value.startswith(_ENCRYPTED_PREFIX):
        return value
    token = value.removeprefix(_ENCRYPTED_PREFIX)
    f = _get_fernet(settings.SECRET_KEY, tuple(settings.SECRET_KEY_FALLBACKS))
    try:
        return f.decrypt(token.encode("ascii")).decode("utf-8")
    except InvalidToken as exc:
        # Chain explicitly so the traceback shows the crypto failure as the
        # direct cause rather than an incidental error during handling.
        raise ValueError(
            "Could not decrypt field value. The SECRET_KEY (and SECRET_KEY_FALLBACKS) "
            "may have changed since this data was encrypted."
        ) from exc
105
106
# Lookups that encrypted fields still resolve; EncryptedFieldMixin.get_lookup()
# returns None for any name not listed here.
#
# isnull is needed for NULL checks. exact is required so that `filter(field=None)`
# works — the ORM resolves "exact" first and then rewrites None to isnull.
# Exact lookups on non-None values will silently return no results (since
# ciphertext is non-deterministic), but blocking exact entirely would break
# the None/isnull path.
_ALLOWED_LOOKUPS = {"isnull", "exact"}
113
114
class EncryptedFieldMixin:
    """Behavior common to every encrypted field.

    Encrypted values are non-deterministic, so:

    - only the lookups named in ``_ALLOWED_LOOKUPS`` are resolvable and no
      transforms are available;
    - ``_check_encrypted_constraints()`` reports a preflight error for each
      constraint or index that lists the field.

    Must be combined with a field base class that supplies ``name``,
    ``model`` and ``get_lookup``.
    """

    # Declared only for type checkers; the co-base field class provides them.
    name: str
    model: Any

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """Resolve *lookup_name* via the next class in the MRO if allowed, else None."""
        if lookup_name in _ALLOWED_LOOKUPS:
            # getattr() because this mixin itself declares no parent get_lookup.
            parent_get_lookup = getattr(super(), "get_lookup")
            return parent_get_lookup(lookup_name)
        return None

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """Always return None: no transform is offered on encrypted fields."""
        return None

    def _check_encrypted_constraints(self) -> list[PreflightResult]:
        """Return a preflight error per constraint/index that names this field.

        Returns an empty list when the field has no ``model`` attribute yet.
        Constraints and indexes without a ``fields`` attribute are treated
        as referencing no fields.
        """
        if not hasattr(self, "model"):
            return []

        field_name = self.name
        options = self.model.model_options
        errors: list[PreflightResult] = []

        for constraint in options.constraints:
            if field_name not in getattr(constraint, "fields", ()):
                continue
            errors.append(
                preflight.PreflightResult(
                    fix=(
                        f"'{self.model.__name__}.{field_name}' is an encrypted field "
                        f"and cannot be used in constraint '{constraint.name}'. "
                        "Encrypted values are non-deterministic."
                    ),
                    obj=self,
                    id="fields.encrypted_in_constraint",
                )
            )

        for index in options.indexes:
            # Index entries may carry a "-" ordering prefix (descending).
            bare_names = [
                entry.lstrip("-") for entry in getattr(index, "fields", ())
            ]
            if field_name not in bare_names:
                continue
            errors.append(
                preflight.PreflightResult(
                    fix=(
                        f"'{self.model.__name__}.{field_name}' is an encrypted field "
                        f"and cannot be used in index '{index.name}'. "
                        "Encrypted values are non-deterministic."
                    ),
                    obj=self,
                    id="fields.encrypted_in_index",
                )
            )

        return errors
179
180
class EncryptedTextField(EncryptedFieldMixin, ColumnField[str]):
    """Text field whose value is encrypted on write and decrypted on read.

    Encryption uses Fernet (AES-128-CBC + HMAC-SHA256) with a key derived
    from SECRET_KEY. The column type is always ``text`` regardless of
    ``max_length``, since ciphertext length is unpredictable.

    ``max_length`` limits the plaintext during validation; it says nothing
    about the size of the ciphertext stored in the database.
    """

    db_type_sql = "text"

    def __init__(
        self,
        *,
        max_length: int | None = None,
        required: bool = True,
        allow_null: bool = False,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        # There is deliberately no `default` parameter: Fernet output is
        # non-deterministic, so a literal column DEFAULT cannot be expressed.
        self.max_length = max_length
        super().__init__(
            required=required,
            allow_null=allow_null,
            validators=validators,
        )

    def to_python(self, value: Any) -> str | None:
        """Return None and strings unchanged; coerce anything else with ``str()``."""
        if value is None or isinstance(value, str):
            return value
        return str(value)

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run base validation, then enforce ``max_length`` on the plaintext."""
        super().validate(value, model_instance)
        if self.max_length is None or value is None:
            return
        if len(value) > self.max_length:
            raise exceptions.ValidationError(
                f"Ensure this value has at most {self.max_length} characters (it has {len(value)}).",
                code="max_length",
            )

    def get_prep_value(self, value: Any) -> Any:
        """Apply the base preparation, then coerce non-None results to ``str``."""
        prepared = super().get_prep_value(value)
        return prepared if prepared is None else self.to_python(prepared)

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Return the encrypted form of the prepared value (None stays None)."""
        plaintext = super().get_db_prep_value(value, connection, prepared)
        if plaintext is not None:
            return _encrypt(plaintext)
        return plaintext

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> str | None:
        """Decrypt a value loaded from the database (None stays None)."""
        return value if value is None else _decrypt(value)

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the base deconstruction with ``max_length`` when it is set."""
        name, path, args, kwargs = super().deconstruct()
        if (limit := self.max_length) is not None:
            kwargs["max_length"] = limit
        return name, path, args, kwargs

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Add the encrypted-field constraint/index checks to the base results."""
        results = super().preflight(**kwargs)
        results.extend(self._check_encrypted_constraints())
        return results
259
260
class EncryptedJSONField(EncryptedFieldMixin, ColumnField):
    """A JSONField that encrypts its serialized value before storing in the database.

    The JSON value is serialized to a string, encrypted, and stored as text.
    On read, it's decrypted and deserialized back to a Python object.

    Args:
        encoder: Optional ``json.JSONEncoder`` subclass used for serialization.
        decoder: Optional ``json.JSONDecoder`` subclass used for deserialization.
        required, allow_null, validators: Forwarded unchanged to the base field.
    """

    db_type_sql = "text"
    empty_strings_allowed = False

    def __init__(
        self,
        *,
        encoder: type[json.JSONEncoder] | None = None,
        decoder: type[json.JSONDecoder] | None = None,
        required: bool = True,
        allow_null: bool = False,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        # `default` is intentionally not accepted: Fernet encryption is
        # non-deterministic, so a literal column DEFAULT cannot be expressed.
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(
            required=required,
            allow_null=allow_null,
            validators=validators,
        )

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the base deconstruction with ``encoder``/``decoder`` when set."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run base validation, then check that *value* can be serialized to JSON.

        Raises:
            ValidationError: With code ``"invalid"`` if ``json.dumps`` rejects
                the value.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError) as exc:
            # TypeError: unserializable object. ValueError: e.g. a circular
            # reference. Both mean the value is not representable as JSON.
            raise exceptions.ValidationError(
                "Value must be valid JSON.",
                code="invalid",
                params={"value": value},
            ) from exc

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Serialize the prepared value to JSON and encrypt it (None stays None)."""
        value = super().get_db_prep_value(value, connection, prepared)
        if value is None:
            return value
        json_str = json.dumps(value, cls=self.encoder)
        return _encrypt(json_str)

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """Decrypt a stored value and parse it as JSON (None stays None).

        Raises:
            ValueError: If the decrypted text is not valid JSON; the original
                ``JSONDecodeError`` is attached as ``__cause__``.
        """
        if value is None:
            return value
        decrypted = _decrypt(value)
        try:
            return json.loads(decrypted, cls=self.decoder)
        except json.JSONDecodeError as exc:
            raise ValueError(
                "Encrypted field contains data that is not valid JSON. "
                "The stored value may be corrupt."
            ) from exc

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Add the encrypted-field constraint/index checks to the base results."""
        errors = super().preflight(**kwargs)
        errors.extend(self._check_encrypted_constraints())
        return errors