73 Commits

Author SHA1 Message Date
2fa0520a64 [BE][pytree] cleanup parameterized pytree tests (#160842)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160842
Approved by: https://github.com/Skylion007
2025-09-05 20:15:29 +00:00
30293b8b5e Preserve Enum types during torch.export serialization and deserialization (#154821)
Fixes #154674

Addresses an issue where `torch.export` does not correctly preserve Python `Enum` types during the save/load round-trip. Previously, Enum inputs were serialized by value only, causing their type to be lost after deserialization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154821
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007, https://github.com/yushangdi, https://github.com/angelayi
2025-06-08 17:30:31 +00:00
60fe0922f6 [pytree] Register normal class to register_dataclass (#147752)
Fixes https://github.com/pytorch/pytorch/pull/147532#discussion_r1964365330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147752
Approved by: https://github.com/zou3519
2025-04-01 23:28:20 +00:00
a10b765bf1 [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-04-01 10:40:43 +00:00
f9b4856989 Revert "[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)"
This reverts commit c95a6b416b4d1b830535f82e2719c055d077cbad.

Reverted https://github.com/pytorch/pytorch/pull/113257 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @zou3519 can you please help land this internally? See the sigmoid tests in D71198793 for details. To validate the fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/113257#issuecomment-2725982539))
2025-03-14 23:13:34 +00:00
c95a6b416b [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-03-14 08:50:30 +00:00
ebd087e4b5 Revert "[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)"
This reverts commit f08146b67bab331f7bdc9fa247f526f6e60a7190.

Reverted https://github.com/pytorch/pytorch/pull/113257 on behalf of https://github.com/jovianjaison due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/113257#issuecomment-2711299830))
2025-03-10 17:19:21 +00:00
098494e9cb [dynamo] allow global import from collections import deque in user code (#148676)
See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676
Approved by: https://github.com/jansel
2025-03-10 13:14:05 +00:00
19a39a7a06 Revert "[dynamo] allow global import from collections import deque in user code (#148676)"
This reverts commit 685fb377131cc684633dc5471e77038988db53f6.

Reverted https://github.com/pytorch/pytorch/pull/148676 on behalf of https://github.com/malfet due to Looks like it broke ROCM, see f1444f006c/1(default%2C%201&mergeLF=true ([comment](https://github.com/pytorch/pytorch/pull/148676#issuecomment-2709057326))
2025-03-09 20:42:03 +00:00
685fb37713 [dynamo] allow global import from collections import deque in user code (#148676)
See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676
Approved by: https://github.com/jansel
2025-03-09 09:35:29 +00:00
f08146b67b [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-03-06 18:59:02 +00:00
097b0d372a [pytree] fix previously failed dynamo tests (#148669)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148669
Approved by: https://github.com/zou3519
2025-03-06 17:59:29 +00:00
ad9a10aff0 [dynamo] Make nonstrict_trace work with some pytree.register_constant-ed instances (#148007)
As title, this enables `nonstrict_trace`-ed function to take in object
whose type has been `pytree.register_constant`-ed, as long as the object
existed outside the `torch.compile` region. This also forces Dynamo to
emit a `EQUALS_MATCH` guard on the object.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148007
Approved by: https://github.com/zou3519
ghstack dependencies: #148385
2025-03-05 21:28:26 +00:00
9abaaad6a8 [pytree][Easy] preserve dict keys in insertion order in CXX pytree (#130140)
`optree` and JAX pytree traversal the `dict` in sorted key ordering (see [Key Ordering for Dictionaries](https://github.com/metaopt/optree#key-ordering-for-dictionaries)). While in PyTorch Python pytree, we traversal the `dict` in insertion order. See also:

- #114392

This aligns the behavior of CXX pytree with Python pytree.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130140
Approved by: https://github.com/zou3519
2025-02-12 16:41:49 +00:00
0f768c7866 Barebones flat_apply HOP (#146060)
This PR:
- adds pytree.register_constant for registering a class to be treated as
  a constant by torch.compile/torch.fx
- adds a very barebones flat_apply HOP. This should be sufficient to get
  mark_traceable working. A lot more work is necessary to get the custom
  operator case working (when make_fx sees a custom operator with PyTree
  arg types, it needs to emit a call to the flat_apply HOP).
- I expect the flat_apply HOP to change a lot, I want to ship this in
  the current state to unblock the mark_traceable and custom ops
  work.

Test Plan:
- It's kind of difficult to test the barebones flat_apply HOP "works" so
  I added a really simple test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059
2025-02-01 16:17:48 +00:00
373606928b Add torch.utils._pytree.register_dataclass (#146059)
This is an API that registers a dataclass as a pytree node.
It directly calls torch.export.register_dataclass, but we should
eventually inline that implementation here. I want to use this API for
something in compile and feel weird calling
torch.export.register_dataclass.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146059
Approved by: https://github.com/StrongerXi, https://github.com/angelayi, https://github.com/yanboliang
2025-02-01 16:17:48 +00:00
f013cfee38 [TreeSpec] Support enum in defaultdict (#144235)
Summary: Followup from D66269157, add support for enum in defaultdict.

Test Plan: Added unit test

Differential Revision: D67832100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144235
Approved by: https://github.com/henrylhtsang, https://github.com/houseroad
2025-01-07 00:10:46 +00:00
d8c8ba2440 Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964
Approved by: https://github.com/justinchuby, https://github.com/albanD
2024-12-18 23:02:30 +00:00
7edeb1005a [dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten / tree_unflatten / tree_structure (#137398)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137398
Approved by: https://github.com/jansel
2024-12-12 18:05:25 +00:00
30d907c6fb When serializing treespec context, support enum as well (#141525)
Following https://github.com/pytorch/pytorch/pull/102716, per @angelayi's suggestion.

Note that in general enum as an input is not supported.

repro:
```
class TestEnum(enum.Enum):
    A = auto()
    B = auto()

    @staticmethod
    def from_string(s):
        return TestEnum[s.upper()]

class M(torch.nn.Module):
    def forward(self, x, en):
        return x.clone()

input1 = (
    torch.rand(10, device="cuda"),
    {TestEnum.A: torch.rand(10, device="cuda")},
)
inputs = [input1]
model = M().cuda()

_ = model(*input1)

ep = torch.export.export(model, input1, strict=False)
path = torch._inductor.aot_compile(ep.module(), input1)
```

Differential Revision: D66269157
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141525
Approved by: https://github.com/angelayi
2024-12-04 03:08:50 +00:00
4226ed1585 [BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #132574
2024-08-04 17:13:31 +00:00
480ae51f85 [pytree] Only import optree if it's used (#131478)
torch.utils._pytree imports optree if it's available. Instead, we change
it to if it gets used. The motivation for this is better isolation.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131478
Approved by: https://github.com/albanD
2024-07-24 00:10:49 +00:00
ba48cf6535 [BE][Easy][6/19] enforce style for empty lines in import segments in test/ (#129757)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang
2024-07-17 06:42:37 +00:00
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
1be2126ff6 [pytree] Fix namedtuple serialization (#123388)
Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:

TreeSpec {
  type='collections.namedtuple',
  context={class_name='Point', class_fields={'x', 'y'}},
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the is class is now a completely different class than the original one:

TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""

spec == reconstructed_spec  # False
```

So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, "Point")

p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
  type='collections.namedtuple',
  context='Point',
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

spec == reconstructed_spec  # True
```

Test Plan: `python test/test_pytree.py`

Differential Revision: D55771058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
2024-04-08 20:55:19 +00:00
cbbc309cae [pytree][reland] Require pytree serialized_type_name (#120636)
Relanding https://github.com/pytorch/pytorch/pull/119718 as the diff which prevents breakages of torchrec [D53857843](https://www.internalfb.com/diff/D53857843) has landed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120636
Approved by: https://github.com/avikchaudhuri
2024-02-27 06:53:33 +00:00
be0ee93467 [pytree] support X | Y union type in tree_map_only (#120389)
Follow-up PR for #119974 with some small tweaks.

1. Support `X | Y` union type for Python 3.10+
2. Enable predicate function in `tree_map_only` in CXX pytree.
3. Remove unnecessary function definition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120389
Approved by: https://github.com/zou3519
2024-02-22 18:17:13 +00:00
2e77629b9f [pytrees] Allow tree_map_only to support predicate function as filter (#119974)
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.

Alternative: wrap nested int SymNodes with something other than SymInt.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
2024-02-21 21:10:02 +00:00
3f4dd9bfa4 Back out "[pytree] Require serialized_type_name" (#120041)
Summary:
D53785493 breaks apf.rec.ir.tests.ir_export_deserialize_test.IRExportDeserializeTest: test_export_deserialize_ebc failed:

https://www.internalfb.com/sandcastle/workflow/3436246515685789584

Test Plan: buck2 test mode/opt apf/rec/ir/tests:ir_export_deserialize_test

Differential Revision: D53834881

Co-authored-by: Wilson Hong <wilsonhong@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120041
Approved by: https://github.com/ydwu4
2024-02-16 10:02:25 +00:00
b4c7afe101 [pytree] Require serialized_type_name (#119718)
Differential Revision: [D53785493](https://our.internmc.facebook.com/intern/diff/D53785493)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119718
Approved by: https://github.com/suo
2024-02-15 20:32:44 +00:00
1562dae62c [BE]: Apply RUF025 dict.fromkeys preview rule (#118637)
Simplifies and optimizes dict construction using the `fromkeys` classmethod ctor. This also makes it really obvious when all the keys will have the same static value, which could be a bug if unintentional. It is also significantly faster than using a dict comprehension. The rule is in preview, but I am adding a forward fix for when it becomes stable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
2024-01-30 20:46:54 +00:00
suo
e732adf0a7 [pytree] add access api (#117771)
This PR introduces an API to use KeyPaths to actually access values on pytrees.

Differential Revision: [D52881260](https://our.internmc.facebook.com/intern/diff/D52881260/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117771
Approved by: https://github.com/zou3519, https://github.com/XuehaiPan
2024-01-20 04:03:26 +00:00
suo
9448065061 [pytree] add key path api (#116786)
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are places it would be useful.

Some design notes:
- I only implemented the API for  the Python pytree impl. optree has some differences in how their keypath APIs are designed (see https://github.com/pytorch/pytorch/issues/113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116786
Approved by: https://github.com/voznesenskym
2024-01-17 07:24:35 +00:00
ab1ac43752 [pytree] extend pytree operations with is_leaf prediction function (#116419)
Add an extra `is_leaf` prediction function to pytree operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116419
Approved by: https://github.com/zou3519
2024-01-09 19:50:08 +00:00
suo
902807a86d enable pytree tests in fbcode (#116787)
these were not runnable before

Differential Revision: [D52547846](https://our.internmc.facebook.com/intern/diff/D52547846/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116787
Approved by: https://github.com/zou3519
2024-01-09 19:12:43 +00:00
36c6c0c7dc [pytree] expand tree_map to accept multi-inputs (#115642)
Fixes #115419
Fixes #91323
Closes #115549

- #115419
- #91323

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115642
Approved by: https://github.com/vmoens, https://github.com/zou3519
2023-12-14 06:16:42 +00:00
ec124b90b8 [pytree] hardcode values for none_is_leaf and namespace in C++ pytree (#114858)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114858
Approved by: https://github.com/zou3519
2023-12-01 15:01:33 +00:00
d6c0d1b58b [pytree] support collections.deque type for Python pytree (#113256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113256
Approved by: https://github.com/zou3519
ghstack dependencies: #112485, #113255
2023-12-01 05:12:09 +00:00
2ab2e8e1c0 [pytree] support collections.defaultdict type for Python pytree (#113255)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113255
Approved by: https://github.com/zou3519
ghstack dependencies: #112485
2023-11-30 20:46:25 +00:00
2a3d8e50fb [pytree] test aligned API signature for C++ and Python pytree (#112485)
Add tests to ensure the C++ and Python pytree provide the same APIs with identical signatures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112485
Approved by: https://github.com/zou3519
2023-11-30 17:50:06 +00:00
89a1fe6966 [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-28 11:41:38 +00:00
01366efcc9 Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 4e4a6ad6ecd71a1aefde3992ecf7f77e37d2e264.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1824099658))
2023-11-23 09:59:32 +00:00
4e4a6ad6ec [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-21 19:53:13 +00:00
23e0923c74 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit eeeb40b32717bab75bd7d8f28f8343385688b3ab.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to Reverting this pr as the one under it in the stack is causing regressions in torchrec ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1806044435))
2023-11-10 16:30:36 +00:00
eeeb40b327 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-10 05:41:32 +00:00
bf452dcde6 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit fa895da968ec6f1ae128ee95fcb96ba9addac8a0.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1804870560))
2023-11-10 00:12:52 +00:00
fa895da968 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-08 06:05:39 +00:00
3904b81420 [pytree] Add back a default serialized name (#112748)
Previously we added a change which required users to pass in a serialized name if they want to serialize a pytree so that the serialized name does not depend on the python environment. However this is currently breaking AOTInductor benchmark tests as AOTInductor will serialize the pytree into the .so for flattening/unflattening the inputs. However, the registration for those pytree types in the AOTInductor benchmarks are in the huggingface repo, so I'm not sure what's a good fix for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112748
Approved by: https://github.com/zhxchen17, https://github.com/malfet
2023-11-02 22:34:42 +00:00