Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[pytree] Use OpTree for PyTree manipulation (#93139)
Split from #92679. Use the C++-based PyTree implementation.

## Highlights

1. High performance (20x speedup over the pure-Python implementation; 10%-20% overall speedup for `torch.fx`)
2. Multi-input tree-map support
3. Custom tree node registry with namespace isolation

Refs:

- #65761
- #91323
- #92679

From https://github.com/pytorch/pytorch/issues/65761#issuecomment-1334746366:

> ### 0. Out-of-the-box compatible with JAX's pytree: provides the same interfaces and functions (and more).
>
> ### 1. High performance: `optree` has comparably fast tree operations (~0.9x for `dict`s and ~2.5x for `OrderedDict`s) relative to JAX's pytree, and it is 20x faster than `torch.utils._pytree`.
>
> `optree` implements some common Python container types in C++ (e.g., `OrderedDict`) and achieves 2.5x the performance of JAX's pytree. Check out the sections [Built-in PyTree Node Types](https://github.com/metaopt/optree#built-in-pytree-node-types) and [Benchmark](https://github.com/metaopt/optree#benchmark) for more details.
>
> | Module    | Nodes | OpTree (μs) | JAX XLA (μs) | PyTorch (μs) | DM-Tree (μs) | Speedup (J / O) | Speedup (P / O) | Speedup (D / O) |
> | :-------- | ----: | ----------: | -----------: | -----------: | -----------: | --------------: | --------------: | --------------: |
> | TinyMLP   |    53 |       26.40 |        68.19 |       586.87 |        34.14 |            2.58 |           22.23 |            1.29 |
> | AlexNet   |   188 |       84.28 |       259.51 |      2182.07 |       125.12 |            3.08 |           25.89 |            1.48 |
> | ResNet18  |   698 |      288.57 |       807.27 |      7881.69 |       429.39 |            2.80 |           27.31 |            1.49 |
> | ResNet34  |  1242 |      580.75 |      1564.97 |     15082.84 |       819.02 |            2.69 |           25.97 |            1.41 |
> | ResNet50  |  1702 |      791.18 |      2081.17 |     20982.82 |      1104.62 |            2.63 |           26.52 |            1.40 |
> | ResNet101 |  3317 |     1603.93 |      3939.37 |     40382.14 |      2208.63 |            2.46 |           25.18 |            1.38 |
> | ResNet152 |  4932 |     2446.56 |      6267.98 |     56892.36 |      3139.17 |            2.56 |           23.25 |            1.28 |
> | ViT-H/14  |  3420 |     1681.48 |      4488.33 |     41703.16 |      2504.86 |            2.67 |           24.80 |            1.49 |
> | Swin-B    |  2881 |     1565.41 |      4091.10 |     34241.99 |      1936.75 |            2.61 |           21.87 |            1.24 |
> |           |       |             |              |              |  **Average** |        **2.68** |       **24.78** |        **1.38** |
>
> <div align="center">
> <img src="https://user-images.githubusercontent.com/16078332/200494435-fd5bb385-59f7-4811-b520-98bf5763ccf3.png" width="90%" />
> </div>
>
> ### 2. Namespace Isolation for the PyTree Type Registry
>
> In addition to JAX's pytree registry for custom node type registration, `optree` adds `namespace` isolation to the registry. Users can register the same type multiple times with different flatten/unflatten behavior. It also provides module-level isolation for safety reasons. For example, you can add a unique prefix to your namespace to isolate your registry from other modules (e.g., `torch.xxx`, `torch.functorch.xxx`):
>
> ```python
> # Register a Python type into a namespace
> import optree
> import torch
>
> optree.register_pytree_node(
>     torch.Tensor,
>     # (tensor) -> (children, metadata)
>     flatten_func=lambda tensor: (
>         (tensor.cpu().numpy(),),
>         dict(dtype=tensor.dtype, device=tensor.device, requires_grad=tensor.requires_grad),
>     ),
>     # (metadata, children) -> tensor
>     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
>     namespace='torch.torch2numpy',
> )
> ```
>
> ```python
> >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
> >>> tree
> {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
>
> # Flatten without specifying the namespace
> >>> optree.tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
> ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
>
> # Flatten with the namespace
> >>> leaves, treespec = optree.tree_flatten(tree, namespace='torch.torch2numpy')
> >>> leaves, treespec
> (
>     [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
>     PyTreeSpec(
>         {
>             'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
>             'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
>         },
>         namespace='torch.torch2numpy'
>     )
> )
>
> # `entries` are not defined and use `range(len(children))`
> >>> optree.tree_paths(tree, namespace='torch.torch2numpy')
> [('bias', 0), ('weight', 0)]
>
> # Unflatten back to a copy of the original object
> >>> optree.tree_unflatten(treespec, leaves)
> {'bias': tensor([0., 0.]), 'weight': tensor([[1., 1.]], device='cuda:0')}
> ```
>
> Check out the section [Registering a Container-like Custom Type as Non-leaf Nodes](https://github.com/metaopt/optree#notes-about-the-pytree-type-registry) for more details.
>
> ### 3. Support both `None` as Non-leaf Node and `None` as Leaf
>
> In JAX's implementation, `None` is always an internal non-leaf node with arity 0, much like an empty tuple. This limits the usability of JAX's pytree utilities for PyTorch. For example, `nn.Module` uses `_parameters` and `_buffers` (`OrderedDict[str, Optional[Tensor]]`) to hold the tensors, where a value can be either a tensor or `None`.
>
> `optree` supports both `None` as a non-leaf node (JAX's default) and `None` as a leaf (PyTorch's default). Check out the section [None is Non-leaf Node vs. None is Leaf](https://github.com/metaopt/optree#none-is-non-leaf-node-vs-none-is-leaf) for more details.
>
> ### 4. Some other improvements and bug fixes
>
> 1. Adds an in-place version of tree-map (`tree_map_`), which avoids a redundant unflatten operation for better performance.
> 2. Adds support for tree flatten and tree map with paths (useful for `functorch` module extraction).
> 3. Improves on JAX's pytree sorting support for `dict`s.
> 4. Better string representation `repr(PyTreeSpec)`.
> 5. Fixes several bugs in JAX's pytree around hashing, pickle serialization, a segmentation fault on infinite recursion, and tree-compose/tree-transpose.
From https://github.com/pytorch/pytorch/pull/92679#issuecomment-1398778481:

> ```python
> # pytree_make_fx_bench.py
> import time
>
> import torch
> from torch.fx.experimental.proxy_tensor import make_fx
>
> def f(x):
>     for _ in range(10000):
>         x = x + x
>     return x
>
> begin = time.time()
> out = make_fx(f, tracing_mode="real")(torch.randn(20))
> begin = time.time()  # note: the timer is reset only here, so the later prints report cumulative time
> print(f'tracing_mode="real" {time.time() - begin:.2f}')
>
> out = make_fx(f, tracing_mode="fake")(torch.randn(20))
> print(f'tracing_mode="fake" {time.time() - begin:.2f}')
>
> out = make_fx(f, tracing_mode="symbolic")(torch.randn(20))
> print(f'tracing_mode="symbolic" {time.time() - begin:.2f}')
> ```
>
> This seems to run around 10-20% faster with the optree implementation:
>
> ```
> # optree
> python pytree_make_fx_bench.py
> tracing_mode="real" 0.00
> tracing_mode="fake" 6.32
> tracing_mode="symbolic" 27.13
> ```
>
> ```
> # torch.utils._pytree
> python pytree_make_fx_bench.py
> tracing_mode="real" 0.00
> tracing_mode="fake" 7.66
> tracing_mode="symbolic" 31.07
> ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93139
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 8a567bb59d
Commit: 0bf30c140a
```diff
@@ -43,6 +43,7 @@ files =
    tools,
    torch/profiler/_memory_profiler.py,
    torch/utils/_pytree.py,
    torch/utils/pytree.py,
    torch/utils/benchmark/utils/common.py,
    torch/utils/benchmark/utils/timer.py,
    torch/utils/benchmark/utils/valgrind_wrapper
```