Compare commits

1 Commits

Author SHA1 Message Date
882e5d5107 Add pytree docs 2025-07-16 22:15:07 -07:00
7 changed files with 157 additions and 69 deletions

@ -59,3 +59,4 @@ sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-nb

@ -62,7 +62,7 @@ extensions = [
"sphinxcontrib.katex",
"sphinx_copybutton",
"sphinx_design",
"myst_parser",
"myst_nb",
"sphinx.ext.linkcode",
"sphinxcontrib.mermaid",
"sphinx_sitemap",

@ -421,67 +421,3 @@ FakeTensor:
```python
FakeTensor(dtype=torch.int, size=[2,], device=CPU)
```
### Pytree-able Types
We define a type “Pytree-able”, if it is either a leaf type or a container type
that contains other Pytree-able types.
Note:
> The concept of pytree is the same as the one documented
> [here](https://jax.readthedocs.io/en/latest/pytrees.html) for JAX:
The following types are defined as **leaf type**:
```{eval-rst}
.. list-table::
   :widths: 50 50
   :header-rows: 1

   * - Type
     - Definition
   * - Tensor
     - :class:`torch.Tensor`
   * - Scalar
     - Any numerical types from Python, including integral types, floating point types, and zero dimensional tensors.
   * - int
     - Python int (bound as int64_t in C++)
   * - float
     - Python float (bound as double in C++)
   * - bool
     - Python bool
   * - str
     - Python string
   * - ScalarType
     - :class:`torch.dtype`
   * - Layout
     - :class:`torch.layout`
   * - MemoryFormat
     - :class:`torch.memory_format`
   * - Device
     - :class:`torch.device`
```
The following types are defined as **container type**:
```{eval-rst}
.. list-table::
   :widths: 50 50
   :header-rows: 1

   * - Type
     - Definition
   * - Tuple
     - Python tuple
   * - List
     - Python list
   * - Dict
     - Python dict with Scalar keys
   * - NamedTuple
     - Python namedtuple
   * - Dataclass
     - Must be registered through `register_dataclass <https://github.com/pytorch/pytorch/blob/901aa85b58e8f490631ce1db44e6555869a31893/torch/export/__init__.py#L693>`__
   * - Custom class
     - Any custom class defined with `_register_pytree_node <https://github.com/pytorch/pytorch/blob/901aa85b58e8f490631ce1db44e6555869a31893/torch/utils/_pytree.py#L72>`__
```

@ -132,7 +132,7 @@ Whether a value is static or dynamic depends on its type:
sequence for `dict` and `namedtuple` values) is static.
- The contained elements have these rules applied to them recursively
(basically the
[PyTree](https://jax.readthedocs.io/en/latest/pytrees.html) scheme)
{ref}`PyTree <pytree>` scheme)
with leaves that are either Tensor or primitive types.
- Other *classes* (including data classes) can be registered with PyTree
@ -158,7 +158,7 @@ By default, the types of inputs you can use for your program are:
### Custom Input Types (PyTree)
In addition, you can also define your own (custom) class and use it as an
input type, but you will need to register such a class as a PyTree.
input type, but you will need to register such a class as a {ref}`PyTree <pytree>`.
Here's an example of using a utility to register a dataclass that is used as
an input type.

@ -55,6 +55,7 @@ nn.init
nn.attention
onnx
optim
pytree <pytree>
complex_numbers
ddp_comm_hooks
quantization

150
docs/source/pytree.md Normal file
@ -0,0 +1,150 @@
(pytree)=
# PyTrees
```{warning}
The main PyTree functionality is available through `torch.utils._pytree`. Note
that this is currently a private API and may change in future versions.
```
## Overview
A *pytree* is a nested data structure composed of:
* **Leaves**: non-container objects that cannot be further decomposed (e.g. tensors, ints, floats, etc.)
* **Containers**: collections like list, tuple, dict, namedtuple, etc. that contain other pytrees or leaves
Specifically, the following types are defined as **leaf types**:
* {class}`torch.Tensor` and tensor metadata ({class}`torch.dtype`, {class}`torch.layout`, {class}`torch.device`, {class}`torch.memory_format`)
* Symbolic values: `torch.SymInt`, `torch.SymFloat`, `torch.SymBool`
* Scalars: `int`, `float`, `bool`, `str`
* Python classes that have not been registered as containers
* Anything registered through {func}`torch.utils._pytree.register_constant`
The following types are defined as **container types**:
* Python `list`, `tuple`, `namedtuple`
* Python `dict` with scalar keys
* Python `dataclass` (must be registered through {func}`torch.utils._pytree.register_dataclass`)
* Python classes (must be registered through {func}`torch.utils._pytree.register_dataclass` or {func}`torch.utils._pytree.register_pytree_node`)
## Viewing the PyTree Structure
We can use {func}`torch.utils._pytree.tree_structure` to view the
structure of a pytree. The structure is represented as a
{class}`torch.utils._pytree.TreeSpec` object, which contains the type of the
container, a list of its child `TreeSpec`s, and any context needed to
represent the pytree.
```{code-cell}
import torch
import torch.utils._pytree as pytree
simple_list = [torch.tensor([1, 2]), torch.tensor([3, 4])]
print("simple_list treespec:", pytree.tree_structure(simple_list))
list_with_dict = [
    {'a': torch.tensor([1]), 'b': torch.tensor([2])},
    torch.tensor([3])
]
print("list_with_dict treespec:", pytree.tree_structure(list_with_dict))
```
## Manipulating PyTrees
We can use {func}`torch.utils._pytree.tree_flatten` to flatten a pytree into a
list of leaves. This function also returns a `TreeSpec` representing the
structure of the pytree. This can be used along with
{func}`torch.utils._pytree.tree_unflatten` to then reconstruct the original pytree.
```{code-cell}
tree = {
    'a': torch.tensor([1]),
    'b': [2, 3.0],
}
flattened_tree, treespec = pytree.tree_flatten(tree)
print("flattened_tree:", flattened_tree)
print("treespec:", treespec)
```
```{code-cell}
manipulated_tree = [x + 1 for x in flattened_tree]
unflattened_tree = pytree.tree_unflatten(manipulated_tree, treespec)
print("unflattened_tree:", unflattened_tree)
```
We can also use {func}`torch.utils._pytree.tree_map` to apply a function
to all leaves in a pytree, or {func}`torch.utils._pytree.tree_map_only` to
apply it only to leaves of specific types.
```{code-cell}
print("map over all:", pytree.tree_map(lambda x: x + 1, tree))
print("map over tensors:", pytree.tree_map_only(torch.Tensor, lambda x: x + 1, tree))
```
## Custom PyTree Registration
You can register custom types to be treated as PyTree containers. To do so, you
must specify a flatten function that returns a flattened representation of the
pytree, and an unflatten function that takes a flattened representation and
returns the original pytree.
```{code-cell}
class Data1:
    def __init__(self, a: torch.Tensor, b: tuple[str]):
        self.a = a
        self.b = b
data = Data1(torch.tensor(3), ("moo",))
print("TreeSpec without registration:", pytree.tree_structure(data))
pytree.register_pytree_node(
    Data1,
    # flatten_fn returns (children, context); Data1 needs no extra context.
    flatten_fn=lambda x: ([x.a, x.b], None),
    # unflatten_fn receives the children and the context back.
    unflatten_fn=lambda children, context: Data1(*children),
)
print("TreeSpec after registration:", pytree.tree_structure(data))
```
If the class is a dataclass, or has the semantics of a dataclass, a simpler
approach is to use {func}`torch.utils._pytree.register_dataclass`.
```{code-cell}
class Data2:
    def __init__(self, a: torch.Tensor, b: tuple[str]):
        self.a = a
        self.b = b
data = Data2(torch.tensor(3), ("moo",))
print("TreeSpec without registration:", pytree.tree_structure(data))
pytree.register_dataclass(Data2, field_names=["a", "b"])
print("TreeSpec after registration:", pytree.tree_structure(data))
```
## API Reference
```{eval-rst}
.. autofunction:: torch.utils._pytree.tree_flatten
.. autofunction:: torch.utils._pytree.tree_flatten_with_path
.. autofunction:: torch.utils._pytree.tree_unflatten
.. autofunction:: torch.utils._pytree.tree_map
.. autofunction:: torch.utils._pytree.tree_map_
.. autofunction:: torch.utils._pytree.tree_map_only
.. autofunction:: torch.utils._pytree.tree_map_with_path
.. autofunction:: torch.utils._pytree.register_pytree_node
.. autofunction:: torch.utils._pytree.register_dataclass
.. autofunction:: torch.utils._pytree.register_constant
.. autofunction:: torch.utils._pytree.tree_structure
.. autoclass:: torch.utils._pytree.TreeSpec
```

@ -290,8 +290,8 @@ def register_dataclass(
Args:
cls: The python type to register. The class must have the semantics of a
dataclass; in particular, it must be constructed by passing the fields
in.
dataclass; in particular, it must be constructed by passing the
fields in.
field_names (Optional[List[str]]): A list of field names that correspond
to the **non-constant data** in this class. This list must contain
all the fields that are used to initialize the class. This argument