Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Added

- Custom encoders can now receive `_parent` and `_sort_keys` parameters to enable proper encoding of nested structures. ([#429](https://github.com/python-poetry/tomlkit/issues/429))

## [0.13.3] - 2025-06-05

### Added
Expand Down
114 changes: 114 additions & 0 deletions tests/test_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,120 @@ def encode_decimal(obj):
api.unregister_encoder(encode_decimal)


def test_custom_encoders_with_parent_and_sort_keys():
"""Test that custom encoders can receive _parent and _sort_keys parameters."""
import decimal

parent_captured = None
sort_keys_captured = None

@api.register_encoder
def encode_decimal_with_context(obj, _parent=None, _sort_keys=False):
nonlocal parent_captured, sort_keys_captured
if isinstance(obj, decimal.Decimal):
parent_captured = _parent
sort_keys_captured = _sort_keys
return api.float_(str(obj))
raise TypeError

# Test with default parameters
result = api.item(decimal.Decimal("1.23"))
assert result.as_string() == "1.23"
assert parent_captured is None
assert sort_keys_captured is False

# Test with custom parent and sort_keys
parent_captured = None
sort_keys_captured = None
table = api.table()
result = item(decimal.Decimal("4.56"), _parent=table, _sort_keys=True)
assert result.as_string() == "4.56"
assert parent_captured is table
assert sort_keys_captured is True

api.unregister_encoder(encode_decimal_with_context)


def test_custom_encoders_backward_compatibility():
"""Test that old-style custom encoders still work without modification."""
import decimal

@api.register_encoder
def encode_decimal_old_style(obj):
# Old style encoder - only accepts obj parameter
if isinstance(obj, decimal.Decimal):
return api.float_(str(obj))
raise TypeError

# Should work exactly as before
result = api.item(decimal.Decimal("2.34"))
assert result.as_string() == "2.34"

# Should work when called from item() with extra parameters
table = api.table()
result = item(decimal.Decimal("5.67"), _parent=table, _sort_keys=True)
assert result.as_string() == "5.67"

api.unregister_encoder(encode_decimal_old_style)


def test_custom_encoders_with_kwargs():
"""Test that custom encoders can use **kwargs to accept additional parameters."""
import decimal

kwargs_captured = None

@api.register_encoder
def encode_decimal_with_kwargs(obj, **kwargs):
nonlocal kwargs_captured
if isinstance(obj, decimal.Decimal):
kwargs_captured = kwargs
return api.float_(str(obj))
raise TypeError

# Test with parent and sort_keys passed as kwargs
table = api.table()
result = item(decimal.Decimal("7.89"), _parent=table, _sort_keys=True)
assert result.as_string() == "7.89"
assert kwargs_captured == {"_parent": table, "_sort_keys": True}

api.unregister_encoder(encode_decimal_with_kwargs)


def test_custom_encoders_for_complex_objects():
"""Test custom encoders that need to encode nested structures."""

class CustomDict:
def __init__(self, data):
self.data = data

@api.register_encoder
def encode_custom_dict(obj, _parent=None, _sort_keys=False):
if isinstance(obj, CustomDict):
# Create a table and use item() to convert nested values
table = api.table()
for key, value in obj.data.items():
# Pass along _parent and _sort_keys when converting nested values
table[key] = item(value, _parent=table, _sort_keys=_sort_keys)
return table
raise TypeError

# Test with nested structure
custom_obj = CustomDict({"a": 1, "b": {"c": 2, "d": 3}})
result = item(custom_obj, _sort_keys=True)

# Should properly format as a table with sorted keys
expected = """a = 1

[b]
c = 2
d = 3
"""
assert result.as_string() == expected

api.unregister_encoder(encode_custom_dict)


def test_no_extra_minus_sign():
doc = parse("a = -1")
assert doc.as_string() == "a = -1"
Expand Down
27 changes: 21 additions & 6 deletions tomlkit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from collections.abc import Mapping
from typing import IO
from typing import TYPE_CHECKING
from typing import Iterable
from typing import TypeVar

Expand All @@ -19,7 +20,6 @@
from tomlkit.items import Date
from tomlkit.items import DateTime
from tomlkit.items import DottedKey
from tomlkit.items import Encoder
from tomlkit.items import Float
from tomlkit.items import InlineTable
from tomlkit.items import Integer
Expand All @@ -37,6 +37,12 @@
from tomlkit.toml_document import TOMLDocument


if TYPE_CHECKING:
from tomlkit.items import Encoder

E = TypeVar("E", bound=Encoder)


def loads(string: str | bytes) -> TOMLDocument:
"""
Parses a string into a TOMLDocument.
Expand Down Expand Up @@ -294,13 +300,22 @@ def comment(string: str) -> Comment:
return Comment(Trivia(comment_ws=" ", comment="# " + string))


E = TypeVar("E", bound=Encoder)


def register_encoder(encoder: E) -> E:
"""Add a custom encoder, which should be a function that will be called
if the value can't otherwise be converted. It should takes a single value
and return a TOMLKit item or raise a ``ConvertError``.
if the value can't otherwise be converted.

The encoder should return a TOMLKit item or raise a ``ConvertError``.

Example:
@register_encoder
def encode_custom_dict(obj, _parent=None, _sort_keys=False):
if isinstance(obj, CustomDict):
tbl = table()
for key, value in obj.items():
# Pass along parameters when encoding nested values
tbl[key] = item(value, _parent=tbl, _sort_keys=_sort_keys)
return tbl
raise ConvertError("Not a CustomDict")
"""
CUSTOM_ENCODERS.append(encoder)
return encoder
Expand Down
21 changes: 18 additions & 3 deletions tomlkit/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc
import copy
import dataclasses
import inspect
import math
import re
import string
Expand All @@ -15,7 +16,6 @@
from enum import Enum
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Collection
from typing import Iterable
from typing import Iterator
Expand All @@ -38,11 +38,17 @@


if TYPE_CHECKING:
from typing import Protocol

from tomlkit import container

class Encoder(Protocol):
def __call__(
self, __value: Any, _parent: Item | None = None, _sort_keys: bool = False
) -> Item: ...


ItemT = TypeVar("ItemT", bound="Item")
Encoder = Callable[[Any], "Item"]
CUSTOM_ENCODERS: list[Encoder] = []
AT = TypeVar("AT", bound="AbstractTable")

Expand Down Expand Up @@ -199,7 +205,16 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I
else:
for encoder in CUSTOM_ENCODERS:
try:
rv = encoder(value)
# Check if encoder accepts keyword arguments for backward compatibility
sig = inspect.signature(encoder)
if "_parent" in sig.parameters or any(
p.kind == p.VAR_KEYWORD for p in sig.parameters.values()
):
# New style encoder that can accept additional parameters
rv = encoder(value, _parent=_parent, _sort_keys=_sort_keys)
else:
# Old style encoder that only accepts value
rv = encoder(value)
except ConvertError:
pass
else:
Expand Down
Loading