v0.145.1
  1from __future__ import annotations
  2
  3import base64
  4import json
  5from functools import cache
  6from typing import TYPE_CHECKING, Any
  7
  8try:
  9    from cryptography.fernet import Fernet, InvalidToken, MultiFernet
 10    from cryptography.hazmat.primitives import hashes
 11    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
 12except ImportError:
 13    Fernet = None  # ty: ignore[invalid-assignment]
 14    InvalidToken = None  # ty: ignore[invalid-assignment]
 15    MultiFernet = None  # ty: ignore[invalid-assignment]
 16    hashes = None  # ty: ignore[invalid-assignment]
 17    PBKDF2HMAC = None  # ty: ignore[invalid-assignment]
 18
 19from plain import exceptions, preflight
 20from plain.runtime import settings
 21from plain.utils.encoding import force_bytes
 22
 23from .base import ColumnField
 24
 25if TYPE_CHECKING:
 26    from collections.abc import Callable, Sequence
 27
 28    from plain.postgres.connection import DatabaseConnection
 29    from plain.postgres.lookups import Lookup, Transform
 30    from plain.preflight.results import PreflightResult
 31
# Public field classes exported by this module.
__all__ = [
    "EncryptedTextField",
    "EncryptedJSONField",
]

# Fixed salt for key derivation — changing this would invalidate all encrypted data.
# It is not secret. Its job is domain separation: it ensures the derived
# encryption key is distinct from keys derived for other purposes (e.g., signing)
# even when they start from the same SECRET_KEY.
_KDF_SALT = b"plain.postgres.fields.encrypted"

# Prefix for encrypted values in the database.
# Makes encrypted data self-describing and distinguishable from plaintext:
# _decrypt treats values starting with it as Fernet tokens and returns anything
# else unchanged. Changing it would make existing encrypted rows read back raw.
_ENCRYPTED_PREFIX = "$fernet$"
 45
 46
 47def _derive_fernet_key(secret: str) -> bytes:
 48    """Derive a Fernet-compatible key from an arbitrary secret string."""
 49    if PBKDF2HMAC is None:
 50        raise ImportError(
 51            "The 'cryptography' package is required to use encrypted fields. "
 52            "Install it with: pip install cryptography"
 53        )
 54    kdf = PBKDF2HMAC(
 55        algorithm=hashes.SHA256(),
 56        length=32,
 57        salt=_KDF_SALT,
 58        iterations=480_000,
 59    )
 60    return base64.urlsafe_b64encode(kdf.derive(force_bytes(secret)))
 61
 62
 63@cache
 64def _get_fernet(secret_key: str, fallbacks: tuple[str, ...]) -> MultiFernet:
 65    """Build a MultiFernet from the given secret key and fallbacks.
 66
 67    The first key is used for encryption.
 68    All keys are used for decryption, enabling key rotation.
 69    Results are cached by (secret_key, fallbacks) so changing SECRET_KEY
 70    (e.g. in tests) produces a new MultiFernet automatically.
 71    """
 72    keys = [_derive_fernet_key(secret_key)]
 73    for fallback in fallbacks:
 74        keys.append(_derive_fernet_key(fallback))
 75    return MultiFernet([Fernet(k) for k in keys])
 76
 77
 78def _encrypt(value: str) -> str:
 79    """Encrypt a string and return a self-describing database value."""
 80    if value == "":
 81        return value
 82    f = _get_fernet(settings.SECRET_KEY, tuple(settings.SECRET_KEY_FALLBACKS))
 83    token = f.encrypt(force_bytes(value))
 84    return _ENCRYPTED_PREFIX + token.decode("ascii")
 85
 86
 87def _decrypt(value: str) -> str:
 88    """Decrypt a self-describing database value back to a string.
 89
 90    Gracefully handles unencrypted values — if the value doesn't have
 91    the encryption prefix, it's returned as-is. This supports gradual
 92    migration from plaintext to encrypted fields.
 93    """
 94    if not value.startswith(_ENCRYPTED_PREFIX):
 95        return value
 96    token = value[len(_ENCRYPTED_PREFIX) :]
 97    f = _get_fernet(settings.SECRET_KEY, tuple(settings.SECRET_KEY_FALLBACKS))
 98    try:
 99        return f.decrypt(token.encode("ascii")).decode("utf-8")
100    except InvalidToken:
101        raise ValueError(
102            "Could not decrypt field value. The SECRET_KEY (and SECRET_KEY_FALLBACKS) "
103            "may have changed since this data was encrypted."
104        )
105
106
107# isnull is obviously needed. exact is required so that `filter(field=None)`
108# works — the ORM resolves "exact" first and then rewrites None to isnull.
109# Exact lookups on non-None values will silently return no results (since
110# ciphertext is non-deterministic), but blocking exact entirely would break
111# the None/isnull path.
112_ALLOWED_LOOKUPS = {"isnull", "exact"}
113
114
class EncryptedFieldMixin:
    """Behavior shared by every encrypted field type.

    Fernet ciphertext differs each time the same plaintext is encrypted, so
    the database cannot meaningfully compare encrypted columns. This mixin
    therefore:

    * refuses every lookup except those in ``_ALLOWED_LOOKUPS`` (isnull and
      exact), and every transform;
    * reports a preflight error when the field appears in one of the model's
      constraints or indexes.

    It relies on attributes and methods supplied by Field, so it must be
    combined with Field as a co-base class.
    """

    # Supplied by the Field co-base; declared here for type checkers only.
    name: str
    model: Any

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        # Only allow-listed lookups are delegated; everything else is
        # reported as "no such lookup".
        if lookup_name in _ALLOWED_LOOKUPS:
            parent_get_lookup = getattr(super(), "get_lookup")
            return parent_get_lookup(lookup_name)
        return None

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        # No transforms: they would operate on the ciphertext, not the value.
        return None

    def _check_encrypted_constraints(self) -> list[PreflightResult]:
        """Return preflight errors for constraints/indexes covering this field.

        Returns an empty list when the field is not yet attached to a model.
        Constraint errors come first, followed by index errors.
        """
        if not hasattr(self, "model"):
            return []

        field_name = self.name
        model = self.model

        problems: list[PreflightResult] = [
            preflight.PreflightResult(
                fix=(
                    f"'{model.__name__}.{field_name}' is an encrypted field "
                    f"and cannot be used in constraint '{constraint.name}'. "
                    "Encrypted values are non-deterministic."
                ),
                obj=self,
                id="fields.encrypted_in_constraint",
            )
            for constraint in model.model_options.constraints
            if field_name in getattr(constraint, "fields", ())
        ]

        for index in model.model_options.indexes:
            # Index entries may carry a "-" prefix for descending order.
            bare_names = [entry.lstrip("-") for entry in getattr(index, "fields", ())]
            if field_name not in bare_names:
                continue
            problems.append(
                preflight.PreflightResult(
                    fix=(
                        f"'{model.__name__}.{field_name}' is an encrypted field "
                        f"and cannot be used in index '{index.name}'. "
                        "Encrypted values are non-deterministic."
                    ),
                    obj=self,
                    id="fields.encrypted_in_index",
                )
            )

        return problems
179
180
class EncryptedTextField(EncryptedFieldMixin, ColumnField[str]):
    """Text field whose value is encrypted before it is written to the database.

    Encryption is Fernet (AES-128-CBC + HMAC-SHA256) with a key derived from
    SECRET_KEY. The column is always ``text``; ``max_length`` does not affect
    the column type, because ciphertext length is unpredictable.

    ``max_length`` is enforced on the plaintext value (validation) only; the
    ciphertext stored in the database is not length-checked.
    """

    db_type_sql = "text"

    def __init__(
        self,
        *,
        max_length: int | None = None,
        required: bool = True,
        allow_null: bool = False,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        # There is deliberately no `default` argument: Fernet output differs
        # on every call, so no literal column DEFAULT could represent it.
        self.max_length = max_length
        super().__init__(
            required=required,
            allow_null=allow_null,
            validators=validators,
        )

    def to_python(self, value: Any) -> str | None:
        """Coerce ``value`` to ``str``; strings and ``None`` pass through."""
        if value is not None and not isinstance(value, str):
            value = str(value)
        return value

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run base validation, then enforce ``max_length`` on the plaintext."""
        super().validate(value, model_instance)
        limit = self.max_length
        if limit is None or value is None:
            return
        actual = len(value)
        if actual > limit:
            raise exceptions.ValidationError(
                f"Ensure this value has at most {limit} characters (it has {actual}).",
                code="max_length",
            )

    def get_prep_value(self, value: Any) -> Any:
        """Apply base preparation, then coerce any non-None result to ``str``."""
        prepared = super().get_prep_value(value)
        return prepared if prepared is None else self.to_python(prepared)

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Prepare ``value`` via the base class, then encrypt it; None stays None."""
        db_value = super().get_db_prep_value(value, connection, prepared)
        return db_value if db_value is None else _encrypt(db_value)

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> str | None:
        """Decrypt a value loaded from the database; None stays None."""
        return value if value is None else _decrypt(value)

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the base deconstruction with ``max_length`` when it is set."""
        field_name, import_path, positional, keywords = super().deconstruct()
        if self.max_length is not None:
            keywords["max_length"] = self.max_length
        return field_name, import_path, positional, keywords

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Base preflight results plus encrypted constraint/index errors."""
        results = super().preflight(**kwargs)
        results.extend(self._check_encrypted_constraints())
        return results
259
260
class EncryptedJSONField(EncryptedFieldMixin, ColumnField):
    """A JSON field that encrypts its serialized value before storing it.

    On write, the value is serialized with ``json.dumps`` (using ``encoder``
    when given), encrypted, and stored as ``text``. On read, the stored text
    is decrypted and deserialized with ``json.loads`` (using ``decoder``).
    ``None`` is passed through in both directions rather than encrypted.
    """

    db_type_sql = "text"
    empty_strings_allowed = False

    def __init__(
        self,
        *,
        encoder: type[json.JSONEncoder] | None = None,
        decoder: type[json.JSONDecoder] | None = None,
        required: bool = True,
        allow_null: bool = False,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        """Configure JSON (de)serialization plus the base field options.

        Raises:
            ValueError: ``encoder`` or ``decoder`` is given but is not callable.
        """
        # `default` is intentionally not accepted: Fernet encryption is
        # non-deterministic, so a literal column DEFAULT cannot be expressed.
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(
            required=required,
            allow_null=allow_null,
            validators=validators,
        )

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the base deconstruction with any custom encoder/decoder."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run base validation, then check that ``value`` is JSON-serializable.

        Raises:
            exceptions.ValidationError: (code ``"invalid"``) when ``json.dumps``
                rejects the value, either because a type is not serializable
                (TypeError) or because the structure cannot be encoded, such
                as a circular reference (ValueError).
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError) as err:
            # ValueError is caught too so that a circular structure is reported
            # here as a validation error instead of crashing later on save.
            raise exceptions.ValidationError(
                "Value must be valid JSON.",
                code="invalid",
                params={"value": value},
            ) from err

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Serialize ``value`` to JSON and encrypt it; None stays None."""
        value = super().get_db_prep_value(value, connection, prepared)
        if value is None:
            return value
        json_str = json.dumps(value, cls=self.encoder)
        return _encrypt(json_str)

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """Decrypt a stored value and parse it back into a Python object.

        Raises:
            ValueError: the value cannot be decrypted, or the decrypted text is
                not valid JSON (the ``JSONDecodeError`` is chained as the cause).
        """
        if value is None:
            return value
        decrypted = _decrypt(value)
        try:
            return json.loads(decrypted, cls=self.decoder)
        except json.JSONDecodeError as err:
            raise ValueError(
                "Encrypted field contains data that is not valid JSON. "
                "The stored value may be corrupt."
            ) from err

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Base preflight results plus encrypted constraint/index errors."""
        errors = super().preflight(**kwargs)
        errors.extend(self._check_encrypted_constraints())
        return errors
339        return errors