diff --git a/CHANGELOG.md b/CHANGELOG.md index ca2e1ae..9839c5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Added + +- Custom encoders can now receive `_parent` and `_sort_keys` parameters to enable proper encoding of nested structures. ([#429](https://github.com/python-poetry/tomlkit/issues/429)) + ## [0.13.3] - 2025-06-05 ### Added diff --git a/tests/test_items.py b/tests/test_items.py index a24c18a..4101a41 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -986,6 +986,120 @@ def encode_decimal(obj): api.unregister_encoder(encode_decimal) +def test_custom_encoders_with_parent_and_sort_keys(): + """Test that custom encoders can receive _parent and _sort_keys parameters.""" + import decimal + + parent_captured = None + sort_keys_captured = None + + @api.register_encoder + def encode_decimal_with_context(obj, _parent=None, _sort_keys=False): + nonlocal parent_captured, sort_keys_captured + if isinstance(obj, decimal.Decimal): + parent_captured = _parent + sort_keys_captured = _sort_keys + return api.float_(str(obj)) + raise TypeError + + # Test with default parameters + result = api.item(decimal.Decimal("1.23")) + assert result.as_string() == "1.23" + assert parent_captured is None + assert sort_keys_captured is False + + # Test with custom parent and sort_keys + parent_captured = None + sort_keys_captured = None + table = api.table() + result = item(decimal.Decimal("4.56"), _parent=table, _sort_keys=True) + assert result.as_string() == "4.56" + assert parent_captured is table + assert sort_keys_captured is True + + api.unregister_encoder(encode_decimal_with_context) + + +def test_custom_encoders_backward_compatibility(): + """Test that old-style custom encoders still work without modification.""" + import decimal + + @api.register_encoder + def encode_decimal_old_style(obj): + # Old style encoder - only accepts obj parameter + if isinstance(obj, decimal.Decimal): + return api.float_(str(obj)) + raise TypeError + + # Should work exactly as before + result = api.item(decimal.Decimal("2.34")) + assert result.as_string() == "2.34" + + # Should work when called from item() with extra parameters + table = api.table() + result = item(decimal.Decimal("5.67"), _parent=table, _sort_keys=True) + assert result.as_string() == "5.67" + + api.unregister_encoder(encode_decimal_old_style) + + +def test_custom_encoders_with_kwargs(): + """Test that custom encoders can use **kwargs to accept additional parameters.""" + import decimal + + kwargs_captured = None + + @api.register_encoder + def encode_decimal_with_kwargs(obj, **kwargs): + nonlocal kwargs_captured + if isinstance(obj, decimal.Decimal): + kwargs_captured = kwargs + return api.float_(str(obj)) + raise TypeError + + # Test with parent and sort_keys passed as kwargs + table = api.table() + result = item(decimal.Decimal("7.89"), _parent=table, _sort_keys=True) + assert result.as_string() == "7.89" + assert kwargs_captured == {"_parent": table, "_sort_keys": True} + + api.unregister_encoder(encode_decimal_with_kwargs) + + +def test_custom_encoders_for_complex_objects(): + """Test custom encoders that need to encode nested structures.""" + + class CustomDict: + def __init__(self, data): + self.data = data + + @api.register_encoder + def encode_custom_dict(obj, _parent=None, _sort_keys=False): + if isinstance(obj, CustomDict): + # Create a table and use item() to convert nested values + table = api.table() + for key, value in obj.data.items(): + # Pass along _parent and _sort_keys when converting nested values + table[key] = item(value, _parent=table, _sort_keys=_sort_keys) + return table + raise TypeError + + # Test with nested structure + custom_obj = CustomDict({"a": 1, "b": {"c": 2, "d": 3}}) + result = item(custom_obj, _sort_keys=True) + + # Should properly format as a table with sorted keys + expected = """a = 1 + +[b] +c = 2 +d = 3 +""" + assert result.as_string() == expected + + api.unregister_encoder(encode_custom_dict) + + def test_no_extra_minus_sign(): doc = parse("a = -1") assert doc.as_string() == "a = -1" diff --git a/tomlkit/api.py b/tomlkit/api.py index 894c6ee..439b437 100644 --- a/tomlkit/api.py +++ b/tomlkit/api.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from typing import IO +from typing import TYPE_CHECKING from typing import Iterable from typing import TypeVar @@ -19,7 +20,6 @@ from tomlkit.items import Date from tomlkit.items import DateTime from tomlkit.items import DottedKey -from tomlkit.items import Encoder from tomlkit.items import Float from tomlkit.items import InlineTable from tomlkit.items import Integer @@ -37,6 +37,12 @@ from tomlkit.toml_document import TOMLDocument +if TYPE_CHECKING: + from tomlkit.items import Encoder + + E = TypeVar("E", bound=Encoder) + + def loads(string: str | bytes) -> TOMLDocument: """ Parses a string into a TOMLDocument. @@ -294,13 +300,22 @@ def comment(string: str) -> Comment: return Comment(Trivia(comment_ws=" ", comment="# " + string)) -E = TypeVar("E", bound=Encoder) - - def register_encoder(encoder: E) -> E: """Add a custom encoder, which should be a function that will be called - if the value can't otherwise be converted. It should takes a single value - and return a TOMLKit item or raise a ``ConvertError``. + if the value can't otherwise be converted. + + The encoder should return a TOMLKit item or raise a ``ConvertError``. + + Example: + @register_encoder + def encode_custom_dict(obj, _parent=None, _sort_keys=False): + if isinstance(obj, CustomDict): + tbl = table() + for key, value in obj.items(): + # Pass along parameters when encoding nested values + tbl[key] = item(value, _parent=tbl, _sort_keys=_sort_keys) + return tbl + raise ConvertError("Not a CustomDict") """ CUSTOM_ENCODERS.append(encoder) return encoder diff --git a/tomlkit/items.py b/tomlkit/items.py index 456dc83..d326d49 100644 --- a/tomlkit/items.py +++ b/tomlkit/items.py @@ -3,6 +3,7 @@ import abc import copy import dataclasses +import inspect import math import re import string @@ -15,7 +16,6 @@ from enum import Enum from typing import TYPE_CHECKING from typing import Any -from typing import Callable from typing import Collection from typing import Iterable from typing import Iterator @@ -38,11 +38,17 @@ if TYPE_CHECKING: + from typing import Protocol + from tomlkit import container + class Encoder(Protocol): + def __call__( + self, __value: Any, _parent: Item | None = None, _sort_keys: bool = False + ) -> Item: ... + ItemT = TypeVar("ItemT", bound="Item") -Encoder = Callable[[Any], "Item"] CUSTOM_ENCODERS: list[Encoder] = [] AT = TypeVar("AT", bound="AbstractTable") @@ -199,7 +205,16 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I else: for encoder in CUSTOM_ENCODERS: try: - rv = encoder(value) + # Check if encoder accepts keyword arguments for backward compatibility + sig = inspect.signature(encoder) + if "_parent" in sig.parameters or any( + p.kind == p.VAR_KEYWORD for p in sig.parameters.values() + ): + # New style encoder that can accept additional parameters + rv = encoder(value, _parent=_parent, _sort_keys=_sort_keys) + else: + # Old style encoder that only accepts value + rv = encoder(value) except ConvertError: pass else: