Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ env:
COLUMNS: "100"
FORCE_COLOR: "1"
CLICOLOR_FORCE: "1"
ALLOW_PRERELEASES: "false"
ALLOW_PRERELEASES: "true"

jobs:
build-sdist:
Expand All @@ -78,7 +78,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.9 - 3.14"
python-version: "3.9 - 3.15"
update-environment: true

- name: Upgrade pip
Expand Down Expand Up @@ -155,6 +155,8 @@ jobs:
- "3.13t"
- "3.14"
- "3.14t"
- "3.15"
- "3.15t"
- "pypy-3.11"
exclude:
# Exclude unsupported Python versions
Expand Down Expand Up @@ -184,6 +186,11 @@ jobs:
runner: macos-latest
platform: ios
archs: "arm64_iphoneos"
- python-version: "3.15"
target:
runner: macos-latest
platform: ios
archs: "arm64_iphoneos"
# iOS Simulator
- python-version: "3.13"
target:
Expand All @@ -195,6 +202,11 @@ jobs:
runner: macos-latest
platform: ios
archs: "arm64_iphonesimulator"
- python-version: "3.15"
target:
runner: macos-latest
platform: ios
archs: "arm64_iphonesimulator"
# Android
- python-version: "3.13"
target:
Expand All @@ -206,6 +218,11 @@ jobs:
runner: ubuntu-latest
platform: android
archs: "arm64_v8a"
- python-version: "3.15"
target:
runner: ubuntu-latest
platform: android
archs: "arm64_v8a"
# Pyodide
- python-version: "3.12"
target:
Expand Down Expand Up @@ -315,6 +332,21 @@ jobs:
platforms: all

- name: Build wheels
if: matrix.python-version != '3.13t'
uses: pypa/cibuildwheel@v4.0.0rc1
env:
CIBW_BUILD: ${{ env.CIBW_BUILD }}
CIBW_PLATFORM: ${{ matrix.target.platform }}
CIBW_ARCHS: ${{ matrix.target.archs }}
CIBW_ENABLE: pypy${{ env.ALLOW_PRERELEASES == 'true' && ' cpython-prerelease' || '' }}
CIBW_ALLOW_EMPTY: ${{ env.ALLOW_PRERELEASES == 'true' }}
with:
package-dir: .
output-dir: wheelhouse
config-file: "{package}/pyproject.toml"

- name: Build wheels (with Python 3.13t)
if: matrix.python-version == '3.13t'
uses: pypa/cibuildwheel@v3.4
env:
CIBW_BUILD: ${{ env.CIBW_BUILD }}
Expand Down Expand Up @@ -388,7 +420,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
uses: actions/setup-python@v6
with:
python-version: "3.9 - 3.14"
python-version: "3.9 - 3.15"
update-environment: true

- name: Upgrade pip
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/tests-with-pydebug.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ jobs:
- "3.12"
- "3.13"
- "3.14"
- "3.15"
python-abiflags: ["d", "td"]
exclude:
- python-version: "3.9"
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ jobs:
- "3.13t"
- "3.14"
- "3.14t"
- "3.15"
- "3.15t"
- "pypy-3.11"
fail-fast: false
timeout-minutes: 120
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ repos:
hooks:
- id: cpplint
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.12
rev: v0.15.13
hooks:
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -50,7 +50,7 @@ repos:
- id: codespell
additional_dependencies: [".[toml]"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v2.0.0
rev: v2.1.0
hooks:
- id: mypy
exclude: |
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ OpTree out-of-the-box supports the following Python container types in the globa
- [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict)
- [`collections.deque`](https://docs.python.org/3/library/collections.html#collections.deque)
- [`PyStructSequence`](https://docs.python.org/3/c-api/tuple.html#struct-sequence-objects) types created by C API [`PyStructSequence_NewType`](https://docs.python.org/3/c-api/tuple.html#c.PyStructSequence_NewType)
- [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+)
Comment thread
XuehaiPan marked this conversation as resolved.

These types are considered non-leaf nodes in the tree.
Python objects whose type is not registered are treated as leaf nodes.
Expand Down Expand Up @@ -357,7 +358,7 @@ There are several key attributes of the pytree type registry:
> [!WARNING]
> Any `PyTreeSpec` objects created before the unregistration still hold a reference to the old registration. Unflattening such a `PyTreeSpec` will use the **old** `unflatten_func`, not the newly registered one.

3. **Built-in types cannot be re-registered.** The behavior of the types listed in [Built-in PyTree Node Types](#built-in-pytree-node-types) (e.g., key-sorted traversal for `dict` and `collections.defaultdict`) is fixed.
3. **Built-in types cannot be re-registered.** The behavior of the types listed in [Built-in PyTree Node Types](#built-in-pytree-node-types) (e.g., key-sorted traversal for `dict`, `collections.defaultdict`, and `frozendict`) is fixed.

4. **Inherited subclasses are not implicitly registered.** The registry lookup uses `type(obj) is registered_type` rather than `isinstance(obj, registered_type)`. Users need to register the subclasses explicitly. To register all subclasses, it is easy to implement with [`metaclass`](https://docs.python.org/3/reference/datamodel.html#metaclasses) or [`__init_subclass__`](https://docs.python.org/3/reference/datamodel.html#customizing-class-creation), for example:

Expand Down Expand Up @@ -497,7 +498,7 @@ OrderedDict({
The built-in Python dictionary ([`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict)) is a mapping whose leaves are its values.
Since [Python 3.7](https://docs.python.org/3/whatsnew/3.7.html), `dict` is guaranteed to be insertion ordered, but the equality operator (`==`) ignores key order.
To ensure [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) — "equal `dict`" implies "equal ordering of leaves" — the leaves (values) are returned in key-sorted order.
The same applies to [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict).
The same applies to [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) and [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+).

```python
>>> optree.tree_flatten({'a': [1, 2], 'b': [3]})
Expand Down Expand Up @@ -562,7 +563,7 @@ False
([3, 1, 2], PyTreeSpec(OrderedDict({'b': [*], 'a': [*, *]})))
```

To flatten [`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict) and [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) objects with the insertion order preserved, use the `dict_insertion_ordered` context manager:
To flatten [`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict), and [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+) objects with the insertion order preserved, use the `dict_insertion_ordered` context manager:

```python
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
Expand Down
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ Tree Reduce Functions
treespec_defaultdict
treespec_deque
treespec_structseq
treespec_frozendict
treespec_from_collection

.. autofunction:: treespec_paths
Expand All @@ -175,4 +176,5 @@ Tree Reduce Functions
.. autofunction:: treespec_defaultdict
.. autofunction:: treespec_deque
.. autofunction:: treespec_structseq
.. autofunction:: treespec_frozendict
.. autofunction:: treespec_from_collection
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ eq
fillvalue
fmt
forwardref
frozendict
frozenset
func
functools
Expand Down
1 change: 1 addition & 0 deletions docs/source/treespec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ Check section :ref:`PyTreeSpec Functions` for more detailed documentation.
defaultdict
deque
structseq
frozendict
from_collection
6 changes: 6 additions & 0 deletions include/optree/pymacros.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ limitations under the License.
# undef OPTREE_HAS_SUBINTERPRETER_SUPPORT
#endif

#if PY_VERSION_HEX >= 0x030F00A7 // Python 3.15.0a7+
# define OPTREE_HAS_FROZENDICT 1
#else
# undef OPTREE_HAS_FROZENDICT
#endif

namespace py = pybind11;

#if !defined(Py_ALWAYS_INLINE)
Expand Down
25 changes: 22 additions & 3 deletions include/optree/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ constexpr py::ssize_t MAX_TYPE_CACHE_SIZE = 4096;
#define PyOrderedDict_Type (reinterpret_cast<PyTypeObject *>(PyOrderedDictTypeObject.ptr()))
#define PyDefaultDict_Type (reinterpret_cast<PyTypeObject *>(PyDefaultDictTypeObject.ptr()))
#define PyDeque_Type (reinterpret_cast<PyTypeObject *>(PyDequeTypeObject.ptr()))
#if defined(OPTREE_HAS_FROZENDICT)
# define PyFrozenDictTypeObject \
(py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject *>(&PyFrozenDict_Type)))
#endif

inline const py::object &ImportOrderedDict() {
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object> storage;
Expand Down Expand Up @@ -181,6 +185,14 @@ inline Py_ALWAYS_INLINE void AssertExactDict(const py::handle &object) {
}
}

#if defined(OPTREE_HAS_FROZENDICT)
inline Py_ALWAYS_INLINE void AssertExactFrozenDict(const py::handle &object) {
if (!PyFrozenDict_CheckExact(object.ptr())) [[unlikely]] {
throw py::value_error("Expected an instance of frozendict, got " + PyRepr(object) + ".");
}
}
#endif

inline Py_ALWAYS_INLINE void AssertExactOrderedDict(const py::handle &object) {
if (!py::type::handle_of(object).is(PyOrderedDictTypeObject)) [[unlikely]] {
throw py::value_error("Expected an instance of collections.OrderedDict, got " +
Expand All @@ -198,10 +210,17 @@ inline Py_ALWAYS_INLINE void AssertExactDefaultDict(const py::handle &object) {
inline Py_ALWAYS_INLINE void AssertExactStandardDict(const py::handle &object) {
if (!(PyDict_CheckExact(object.ptr()) ||
py::type::handle_of(object).is(PyOrderedDictTypeObject) ||
py::type::handle_of(object).is(PyDefaultDictTypeObject))) [[unlikely]] {
py::type::handle_of(object).is(PyDefaultDictTypeObject)
#if defined(OPTREE_HAS_FROZENDICT)
|| PyFrozenDict_CheckExact(object.ptr())
#endif
)) [[unlikely]] {
throw py::value_error(
"Expected an instance of dict, collections.OrderedDict, or collections.defaultdict, "
"got " +
"Expected an instance of dict, "
#if defined(OPTREE_HAS_FROZENDICT)
"frozendict, "
#endif
"collections.OrderedDict, or collections.defaultdict, got " +
PyRepr(object) + ".");
}
}
Expand Down
2 changes: 2 additions & 0 deletions include/optree/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ enum class PyTreeKind : std::uint8_t {
DefaultDict, // A collections.defaultdict
Deque, // A collections.deque
StructSequence, // A PyStructSequence
FrozenDict, // A frozendict (Python 3.15+)
NumKinds, // Number of kinds (placed at the end)
};

Expand All @@ -67,6 +68,7 @@ constexpr PyTreeKind kOrderedDict = PyTreeKind::OrderedDict;
constexpr PyTreeKind kDefaultDict = PyTreeKind::DefaultDict;
constexpr PyTreeKind kDeque = PyTreeKind::Deque;
constexpr PyTreeKind kStructSequence = PyTreeKind::StructSequence;
constexpr PyTreeKind kFrozenDict = PyTreeKind::FrozenDict;
constexpr PyTreeKind kNumPyTreeKinds = PyTreeKind::NumKinds;

// Registry of custom node types.
Expand Down
8 changes: 5 additions & 3 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,11 @@ class PyTreeSpec {

// Kind-specific metadata.
// For a NamedTuple/PyStructSequence, contains the tuple type object.
// For a Dict, contains a sorted list of keys.
// For a Dict or FrozenDict, contains a sorted list of keys by default and `original_keys`
// (below) records the original keys; the keys are kept in insertion order instead when
// the `dict_insertion_ordered` context manager is active.
// For a OrderedDict, contains a list of keys.
// For a DefaultDict, contains a tuple of (default_factory, sorted list of keys).
// For a DefaultDict, contains a tuple of (default_factory, keys in sorted/insertion order).
// For a Deque, contains the `maxlen` attribute.
// For a Custom type, contains the metadata returned by the `flatten_func` function.
py::object node_data{};
Expand All @@ -293,7 +295,7 @@ class PyTreeSpec {
// Number of leaf and interior nodes in the subtree rooted at this node.
ssize_t num_nodes = 0;

// For a Dict or DefaultDict, captures the keys in insertion order as `dict[Key, None]`.
// For a Dict/DefaultDict/FrozenDict, captures the insertion order as `dict[Key, None]`.
// Null-default for other node kinds. Used to preserve key order during unflattening.
py::object original_keys{};
};
Expand Down
2 changes: 2 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ GLIBCXX_USE_CXX11_ABI: Final[bool]
MSVC_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR: Final[bool]
OPTREE_HAS_SUBINTERPRETER_SUPPORT: Final[bool]
OPTREE_HAS_READ_WRITE_LOCK: Final[bool]
OPTREE_HAS_FROZENDICT: Final[bool]

@final
class InternalError(SystemError): ...
Expand All @@ -75,6 +76,7 @@ class PyTreeKind(enum.IntEnum):
DEFAULTDICT = enum.auto() # a collections.defaultdict
DEQUE = enum.auto() # a collections.deque
STRUCTSEQUENCE = enum.auto() # a PyStructSequence
FROZENDICT = enum.auto() # a frozendict (Python 3.15+)

NUM_KINDS: ClassVar[int]

Expand Down
3 changes: 3 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
treespec_entries,
treespec_entry,
treespec_from_collection,
treespec_frozendict,
treespec_is_leaf,
treespec_is_one_level,
treespec_is_prefix,
Expand Down Expand Up @@ -188,6 +189,7 @@
'treespec_defaultdict',
'treespec_deque',
'treespec_structseq',
'treespec_frozendict',
'treespec_from_collection',
# Accessor
'PyTreeEntry',
Expand Down Expand Up @@ -225,6 +227,7 @@
'structseq_fields',
]


Comment thread
XuehaiPan marked this conversation as resolved.
MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH
"""Maximum recursion depth for pytree traversal.

Expand Down
Loading
Loading