Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:
TreeSpec {
    type='collections.namedtuple',
    context={class_name='Point', class_fields={'x', 'y'}},
    children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but this class is now a completely different class than the original one:
TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""
spec == reconstructed_spec # False
```
So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, "Point")
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
    type='collections.namedtuple',
    context='Point',
    children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
spec == reconstructed_spec # True
```
Test Plan: `python test/test_pytree.py`
Differential Revision: D55771058
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.
Alternative: wrap nested int SymNodes with something other than SymInt.
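A rough sketch of the proposed usage, assuming `tree_map_only` accepts a predicate callable in addition to a type or tuple of types (the `is_symbolic` helper and the sample tree below are made up for illustration):
```python
import torch
import torch.utils._pytree as pytree

def is_symbolic(x):
    # Hypothetical predicate: a real check would also distinguish genuinely
    # symbolic values from SymInts backed by plain constants (nested ints).
    return isinstance(x, (torch.SymInt, torch.SymBool, torch.SymFloat))

def foo(x):
    return x  # placeholder transformation applied to matching leaves

tree = {"sizes": [3, 4], "flag": True}

# Type-based dispatch, the pattern currently used throughout the codebase:
out = pytree.tree_map_only((torch.SymInt, torch.SymBool, torch.SymFloat), foo, tree)

# Predicate-based dispatch, the pattern proposed above:
out = pytree.tree_map_only(is_symbolic, foo, tree)
```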
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
Simplifies and optimizes dict construction using the `fromkeys` classmethod ctor. This also makes it really obvious when all the keys will have the same static value, which could be a bug if unintentional. It is also significantly faster than using a dict comprehension. The rule is in preview, but I am adding a forward fix for when it becomes stable.
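For illustration (this snippet is not taken from the diff), the rewrite looks like:
```python
keys = ["a", "b", "c"]

# Before: a comprehension that repeats the same static value for every key.
d1 = {k: 0 for k in keys}

# After: `dict.fromkeys` makes the shared static value explicit and is faster.
d2 = dict.fromkeys(keys, 0)

assert d1 == d2
# Note: fromkeys stores the same value object under every key, so it is only
# appropriate when the value is immutable or intentionally shared.
```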
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).
I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.
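Roughly, usage looks like the sketch below (names mirror the JAX-style API; exact signatures in the Python impl may differ slightly):
```python
import torch.utils._pytree as pytree

tree = {"weights": [1.0, 2.0], "bias": 3.0}

# Flatten, returning a key path alongside each leaf.
leaves_with_paths, spec = pytree.tree_flatten_with_path(tree)
for keypath, leaf in leaves_with_paths:
    print(pytree.keystr(keypath), leaf)  # e.g. "['weights'][0]" 1.0

# Map over leaves while seeing where each one lives in the original tree.
labeled = pytree.tree_map_with_path(lambda kp, x: (pytree.keystr(kp), x), tree)
```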
Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this; https://fburl.com/code/06p7zrvr is a hand-rolled pass doing basically the same thing, but probably more fragile.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.
I'm sure there are places it would be useful.
Some design notes:
- I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see https://github.com/pytorch/pytorch/issues/113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn differently from JAX: I raise an error, while JAX creates a fallback key entry.
Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116786
Approved by: https://github.com/voznesenskym
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node types in both the C++ and Python implementations (a registration sketch follows this list).
6. Add tests to ensure a warning will be raised when the old private function is called.
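A minimal registration sketch with the new public API (the `Pair` container and its flatten/unflatten fns are made up for illustration):
```python
import torch.utils._pytree as pytree

class Pair:
    def __init__(self, first, second):
        self.first = first
        self.second = second

def _pair_flatten(pair):
    # Returns (children, context); Pair needs no context.
    return [pair.first, pair.second], None

def _pair_unflatten(children, context):
    return Pair(*children)

# Registers Pair as a pytree node in both the C++ and Python implementations.
pytree.register_pytree_node(Pair, _pair_flatten, _pair_unflatten)

flat, spec = pytree.tree_flatten(Pair(1, 2))
print(flat)  # [1, 2]
```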
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this will make it easier to implement the abstract `PyTreeAPI` class with two implementations, and it will be much easier for the user to switch between the two implementations.
Before:
```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```
After:
```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```
The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
Previously we added a change which required users to pass in a serialized name if they want to serialize a pytree, so that the serialized name does not depend on the python environment. However, this is currently breaking AOTInductor benchmark tests, as AOTInductor will serialize the pytree into the .so for flattening/unflattening the inputs, and the registration for those pytree types in the AOTInductor benchmarks lives in the huggingface repo, so I'm not sure what a good fix is for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112748
Approved by: https://github.com/zhxchen17, https://github.com/malfet
Custom classes that are serialized with pytree are serialized by default with `f"{class.__module__}.{class.__name__}"`. This is a dependency from our serialized program directly into the outer Python environment. If a user moves the class to a different directory, the serialized program can no longer be loaded. So, we will require users to pass in an FQN if they want to serialize their custom treespec type.
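An illustrative sketch only; the exact keyword name is an assumption here, but the idea is that the serialized name is pinned explicitly rather than derived from `__module__`:
```python
import torch.utils._pytree as pytree

class Box:
    def __init__(self, value):
        self.value = value

pytree.register_pytree_node(
    Box,
    lambda box: ([box.value], None),
    lambda children, _: Box(children[0]),
    # Assumed kwarg name: an explicit FQN that stays stable if Box moves.
    serialized_type_name="mypkg.containers.Box",
)
```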
Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo
Part of #109684
Changes:
- Add new functions `tree_structure`, `tree_leaves`, `tree_map_` and `tree_map_only_` to Python pytree (a usage sketch follows this list).
- Extract reusable tests for pytree to `TestGenericPytree`.
- Change `treespec_dumps` and `treespec_loads` in C++ pytree to call Python pytree and use JSON string as serialization type.
- Rename `torch.utils.pytree` -> `torch.utils._cxx_pytree`.
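A quick usage sketch of the new Python pytree helpers (outputs are indicative):
```python
import torch.utils._pytree as pytree

tree = {"a": [1, 2], "b": 3}

print(pytree.tree_leaves(tree))     # [1, 2, 3]
print(pytree.tree_structure(tree))  # the TreeSpec of `tree`, with every leaf as *

# Underscore variants apply the fn for its side effects and return the input tree.
pytree.tree_map_(print, tree)
pytree.tree_map_only_(int, print, tree)
```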
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110395
Approved by: https://github.com/zou3519
Fix: #107315
This PR enables dynamo to trace through the `pytree` API by inlining its functions. In
order to do so, a few details of `pytree` had to be changed.
In summary, this PR:
- Introduces `TreeSpecVariable` for representing `TreeSpec` instances
- Specializes `<type>.__bases__` call, returning a `TupleVariable`
- Enables the call to `id` builtin function for every variable that implements
`as_python_constant` method
- Specializes `ConstantVariable.call_method` for its (un)flatten functions
- Implements `UserDefinedObjectVariable.as_python_constant`
- Modifies `pytree` by:
- Making `SUPPORTED_NODES` a map of ids (instead of types) to `NodeDef`
- Removing the `functools.wraps` wrapper, since it can't be inlined
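A small usage sketch of what the inlining enables; with this change `torch.compile` can trace through pytree calls inside a compiled function instead of graph-breaking on them:
```python
import torch
import torch.utils._pytree as pytree

@torch.compile
def double_leaves(inputs):
    leaves, spec = pytree.tree_flatten(inputs)
    return pytree.tree_unflatten([x * 2 for x in leaves], spec)

print(double_leaves({"a": torch.ones(2), "b": [torch.ones(3)]}))
```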
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108533
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
ghstack dependencies: #109201
Fixes https://github.com/pytorch/pytorch/pull/102577#issuecomment-1650905536
Serializing to JSON is more stable, and the API has been renamed:
```
# Takes in a treespec and returns the serialized treespec as a string. Also optionally takes in a protocol version number.
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
# Takes in a serialized treespec and outputs a TreeSpec
def treespec_loads(data: str) -> TreeSpec:
```
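A round-trip sketch of the renamed API:
```python
import torch.utils._pytree as pytree

_, spec = pytree.tree_flatten({"a": [1, 2], "b": (3, 4)})
data = pytree.treespec_dumps(spec)  # JSON string; optionally versioned via `protocol`
assert pytree.treespec_loads(data) == spec
```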
If users want to register their own serialization format for a given pytree, they can go through the `_register_treespec_serializer` API which optionally takes in a `getstate` and `setstate` function.
```
_register_treespec_serializer(type_, *, getstate, setstate)
# Takes in the context, and outputs a json-dumpable context
def getstate(context: Context) -> DumpableContext:
# Takes in a json-dumpable context, and reconstructs the original context
def setstate(dumpable_context: DumpableContext) -> Context:
```
We will serialize to the following dataclass and then `json.dump` it to a string.
```
@dataclass
class TreeSpec:
    type: Optional[str]  # a string name of the type; None for a LeafSpec
    context: Optional[Any]  # optional, a JSON-dumpable form of the context
    children_specs: List["TreeSpec"]
```
If no getstate/setstate function is registered, we will by default serialize the context using `json.dumps/loads`. We will also serialize the type through `f"{typ.__module__}.{typ.__name__}"`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106116
Approved by: https://github.com/zou3519
The current pretty-printer indents based on the length of the previous line, which is totally unreadable if, e.g., the treespec is a dict with a lot of keys, since all the keys will go on a ginormous line and everything after will be super indented.
Fix the indentation at 2 spaces, which is much more compact.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103945
Approved by: https://github.com/zou3519
1. Made TreeSpec into a dataclass.
2. In `__repr__`, recursively transformed TreeSpec into dictionaries and then pretty-printed it.
Fixes #46538. Hi @ezyang, this PR is for the TreeSpec `__repr__` refactor we discussed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86546
Approved by: https://github.com/ezyang
All of the `torch.return_types.*` types are special "structseq" objects that subclass tuple but have a different constructor from tuple :(.
This PR iterates through all of torch.return_types.* and adds a pytree
registration for them.
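A small sketch of what the registration enables:
```python
import torch
import torch.utils._pytree as pytree

out = torch.max(torch.tensor([[1.0, 3.0], [2.0, 0.0]]), dim=1)  # torch.return_types.max
flat, spec = pytree.tree_flatten(out)
print(flat)  # [values tensor, indices tensor]
print(pytree.tree_unflatten(flat, spec))  # reconstructed torch.return_types.max
```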
Test Plan:
- add tests for max and min which return torch.return_types.max, and
torch.return_types.min, respectively. There's not an easy way to
"get all torch ops that return a return_types object".
Fixes https://github.com/pytorch/pytorch/issues/75218
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75915
Approved by: https://github.com/ezyang, https://github.com/kshitij12345
Summary:
Following triage review discussion, it would be best for these tests to not be triaged high priority by automation, but by the triagers in the oncall.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74555
Reviewed By: albanD
Differential Revision: D35099202
Pulled By: janeyx99
fbshipit-source-id: 657a0317141de3a598476a6f601ec26cc26231b1
(cherry picked from commit 057519cb2494d0f9a0b169f359ac87ba9e89f088)
Summary:
Action following https://github.com/pytorch/pytorch/issues/66232
This change does require some context: there were several suggestions regarding what to do about this group of tests: tests that are core and crucial to all of PyTorch and are too broad to be owned by one team.
1. Let's add a "module: core" and put people behind it! This idea sounds appealing unless you are one of the people backing the label. From talking to albanD among others, this idea of putting all these core tests on the shoulder of a few people or one team isn't super fair and I have not yet found anyone willing to take on this job.
2. Taking advantage of the fact that we already have a triaging oncall that takes turns triaging issues, we can leave these tests essentially unlabeled and allow the oncall to triage these tests. Since these tests are crucial to PyTorch, we'll add the "high priority" label to mark them different from other unowned tests (see https://github.com/pytorch/pytorch/issues/67552).
3. I _could_ still create an unbacked label "module: core" and attribute these tests there, but I don't like the idea of creating a facade that the tests are "triaged" to a label when no one is actually taking a look.
Now we could potentially break these tests down into smaller files so that each piece _could_ be owned by a team, but 1. I don't know if this is currently feasible and 2. This approach does not prevent that from happening in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67553
Reviewed By: albanD
Differential Revision: D32025004
Pulled By: janeyx99
fbshipit-source-id: 1fb1aa4c27e305695ab6e80ae3d02f90519939c0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292
This PR adds pytree support for namedtuples. The challenge about namedtuple
is that each namedtuple class is actually different. This PR does the
following:
- it adds a namedtuple flatten/unflatten. The flatten function returns
a context that is the actual type of the namedtuple subclass. The
unflatten function uses that type to reconstruct the namedtuple
- Special cases all pytree logic to consider all namedtuples the same.
This is done by creating a `_get_node_type(pytree)` helper function that
returns `namedtuple` if `pytree` is any namedtuple subclass. The effect
of this is that all namedtuple subclasses will go through the namedtuple
flatten/unflatten functions
- Adds a `_namedtuple_flatten_spec` function for FX pytrees. This function
flattens the namedtuple based on the spec and is equivalent to the
`_tuple_flatten_spec`.
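A short sketch of the resulting behavior:
```python
from collections import namedtuple
import torch.utils._pytree as pytree

Point = namedtuple("Point", ["x", "y"])

flat, spec = pytree.tree_flatten(Point(1, 2))
print(flat)  # [1, 2]; the spec's context is the Point class itself
print(pytree.tree_unflatten(flat, spec))  # Point(x=1, y=2)
```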
Test Plan
- new tests in test/test_pytree.py and test/test_fx.py
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D29947302
Pulled By: zou3519
fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
Summary:
```
import torch
import torch.fx as fx
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        for k in x:
            for v in x[k]:
                v += y
        return x
example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})
fx.symbolic_trace(new_f, concrete_args=example_dict)
```
prints out
```
def forward(self, y, x):
    y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
    add = tree_2 + y
    add_1 = tree_3 + y
    add_2 = tree_4 + y; y = None
    return {'a': [tree_2], 'z': [tree_3, tree_4]}
```
Currently, I store `in_spec` as an extra attribute on `fx.Graph`, and then include it when we do the codegen. I'm not sure if this is the right approach - it introduces a divergence between what's in `fx.Graph` and what's in the python code.
Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit ... more verbose.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888
Reviewed By: jamesr66a
Differential Revision: D27884694
Pulled By: Chillee
fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46288
This "broadcasts" `pytree` to have the same structure as `spec`
and then flattens it.
I find it hard to describe what that does in words, so here's an example (a code sketch follows the list):
- Broadcasting 1 to have the same structure as [0, [0, 0]] would
return [1, [1, 1]]. Further flattening it gives us [1, 1, 1].
- Broadcasting [1, 2] to have the same structure as [0, [0, 0]] would
return [1, [2, 2]]. Further flattening it gives us [1, 2, 2].
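The same two examples in code, using the helper's private name in `torch.utils._pytree` (treat the exact name as an implementation detail):
```python
import torch.utils._pytree as pytree

_, spec = pytree.tree_flatten([0, [0, 0]])

print(pytree._broadcast_to_and_flatten(1, spec))       # [1, 1, 1]
print(pytree._broadcast_to_and_flatten([1, 2], spec))  # [1, 2, 2]
```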
What is this used for?
----------------------
The next PR up in the stack uses this helper function to allow vmap to
accept nested data structures. `vmap(fn, in_dims)(*inputs)` allows the
user to specify in_dims with a tree structure that is a sub-graph of
that of `inputs` (where both contain the root of the tree).
For example, one can do `vmap(fn, in_dims=0)(x, y, z)`. `in_dims` is 0
and inputs is (x, y, z). We would like to broadcast in_dims up to the
structure of inputs to get (0, 0, 0).
Another example, is `vmap(fn, in_dims=(0, 1))(x, [y, z])`. `in_dims` is
(0, 1) and inputs is (x, [y, z]). We would like to broadcast in_dims up
to the structure of inputs to get (0, [1, 1]); this value of in_dims is
used to say "let's vmap over dim 0 for x and dim 1 for y and z".
Test Plan
---------
New tests.
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D24392891
Pulled By: zou3519
fbshipit-source-id: 6f494d8b6359582f1b4ab6b8dd6a956d8bfe8ed4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46287
This adds a lightweight `pytree` implementation that is similar to and
inspired by JAX pytrees, tensorflow.nest, deepmind/tree,
TorchBeast's TensorNest, etc.
A *pytree* is a nested Python data structure. It is a tree in the sense
that nodes are Python collections (e.g., list, tuple, dict) and the leaves
are Python values. Furthermore, a pytree should not contain reference
cycles.
This PR:
- adds support for flattening and unflattening nested Python list/dict/tuples, as sketched below
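A minimal round-trip sketch of the added support:
```python
import torch.utils._pytree as pytree

tree = {"a": [1, 2], "b": (3, {"c": 4})}

leaves, spec = pytree.tree_flatten(tree)
print(leaves)  # [1, 2, 3, 4]
assert pytree.tree_unflatten(leaves, spec) == tree
```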
Context: nested Tensor inputs for vmap
--------------------------------------
Right now, vmap is restricted to taking in flat lists of tensors. This
is because vmap needs to be able to convert every tensor in the input
that is being vmapped over into a BatchedTensor.
With a pytree library, we can simply flatten the input data structure
(returning the leaves), map all of the Tensors in the flat input to
BatchedTensors, and unflatten the flat list of BatchedTensors into a new
input. Or equivalently, with a `tree_map` function, we can map a nested
python data structure containing Tensors into one containing
BatchedTensors.
Future work
-----------
In some future PRs, we'll add nested input support for vmap. The
prerequisites for that are:
- a `broadcast_to(small, big)` that broadcasts `small` up to `big`.
This is for handling the in_dims to vmap: the in_dims structure must
be compatible with the structure of the inputs.
Test Plan
---------
- New tests in test/test_pytree.py
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D24392890
Pulled By: zou3519
fbshipit-source-id: 7daf7430c5a38354e7d203a72882bd7a9b24cfb1