Changes in this PR:
1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.
Resolves#75982. New tests are included in this PR.
- #75982
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
Changes in this PR:
1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.
Resolves#75982. New tests are included in this PR.
- #75982
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
Changes in this PR:
1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.
Resolves#75982. New tests are included in this PR.
- #75982
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
As title, this enables `nonstrict_trace`-ed function to take in object
whose type has been `pytree.register_constant`-ed, as long as the object
existed outside the `torch.compile` region. This also forces Dynamo to
emit a `EQUALS_MATCH` guard on the object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148007
Approved by: https://github.com/zou3519
ghstack dependencies: #148385
This PR:
- adds pytree.register_constant for registering a class to be treated as
a constant by torch.compile/torch.fx
- adds a very barebones flat_apply HOP. This should be sufficient to get
mark_traceable working. A lot more work is necessary to get the custom
operator case working (when make_fx sees a custom operator with PyTree
arg types, it needs to emit a call to the flat_apply HOP).
- I expect the flat_apply HOP to change a lot, I want to ship this in
the current state to unblock the mark_traceable and custom ops
work.
Test Plan:
- It's kind of difficult to test the barebones flat_apply HOP "works" so
I added a really simple test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
Resolves#126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:
TreeSpec {
type='collections.namedtuple',
context={class_name='Point', class_fields={'x', 'y'}},
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the is class is now a completely different class than the original one:
TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""
spec == reconstructed_spec # False
```
So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, "Point")
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
type='collections.namedtuple',
context='Point',
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
spec == reconstructed_spec # True
```
Test Plan: `python test/test_pytree.py`
Differential Revision: D55771058
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.
Alternative: wrap nested int SymNodes with something other than SymInt.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
Simplifies and optimizes dict construction using the `fromkeys` classmethod ctor. This also makes it really obvious when all the keys will have the same static value, which could be a bug if unintentional. It is also significantly faster than using a dict comprehension. The rule is in preview, but I am adding a forward fix for when it becomes stable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).
I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.
Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.
I'm sure there are places it would be useful.
Some design notes:
- I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see https://github.com/pytorch/pytorch/issues/113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.
Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116786
Approved by: https://github.com/voznesenskym
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.
Before:
```text
torch
├── utils
│ ├── _pytree.py
│ ├── _cxx_pytree.py
│ ...
...
```
After:
```text
torch
├── utils
│ ├── _pytree
│ │ ├── __init__.py
│ │ └── api
│ │ ├── __init__.py
│ │ ├── cxx.py
│ │ └── python.py
│ ...
...
```
The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.
Before:
```text
torch
├── utils
│ ├── _pytree.py
│ ├── _cxx_pytree.py
│ ...
...
```
After:
```text
torch
├── utils
│ ├── _pytree
│ │ ├── __init__.py
│ │ └── api
│ │ ├── __init__.py
│ │ ├── cxx.py
│ │ └── python.py
│ ...
...
```
The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
Previously we added a change which required users to pass in a serialized name if they want to serialize a pytree so that the serialized name does not depend on the python environment. However this is currently breaking AOTInductor benchmark tests as AOTInductor will serialize the pytree into the .so for flattening/unflattening the inputs. However, the registration for those pytree types in the AOTInductor benchmarks are in the huggingface repo, so I'm not sure what's a good fix for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112748
Approved by: https://github.com/zhxchen17, https://github.com/malfet
Custom classes that are serialized with pytree are serialized by default with `f”{class.__module__}.{class.__name__}”`. This is a dependency from our serialized program directly into the outer Python environment. If a user moves the class to a different directory, the serialized program will be unable to be loaded. So, we will require users to pass in an FQN if they want to serialize their custom treespec type.
Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo