Commit Graph

20 Commits

Author SHA1 Message Date
b11593c31b [8/N] Apply ruff UP035 rule (#165214)
This is follow-up of #164653 to continue applying `UP035` fixes. The purpose is to finally enable this rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165214
Approved by: https://github.com/ezyang
2025-10-15 03:18:57 +00:00
11c07c848c [BE][14/16] fix typos in torch/ (torch/fx/) (#156604)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156604
Approved by: https://github.com/jingsh
ghstack dependencies: #156318, #156320, #156602
2025-07-02 22:55:29 +00:00
a821d69d92 Fix register constant to be usable in exportz (#147533)
Differential Revision: [D69939737](https://our.internmc.facebook.com/intern/diff/D69939737)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147533
Approved by: https://github.com/zou3519
2025-02-25 21:10:47 +00:00
cf05f6a134 [BE]: Improve typing for torch/fx/_pytree.py and torch/utils/_pytree.py (#145173)
Improve type inference in _pytree.py utility functions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145173
Approved by: https://github.com/bobrenjc93
2025-01-20 22:18:19 +00:00
0b2a3687b9 PEP585 update - torch/fx (#145166)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
2025-01-20 18:11:54 +00:00
53256edff9 [export] Support module inputs for non strict mode. (#143925)
Summary:
Add experimental support for torch.nn.Module as input types.

Before this change, we don't support module inputs but recently we saw some interesting use cases like gpt-fast https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L68 where we directly pass in a module input for different variants of the same models.

Since we don't really care about non-param or non-buffer states in non strict mode, we don't care about those either and pretend they are like plain constants during tracing. We treat any module input like a nested container of tensor, and each time we will automatically register a pytree handler for these module types to flatten its state dict into a group of tensors. We will just inline any module method call during tracing like we did for `self` module in export_for_training. This will make input modules' behavior very similar to the training module in typical case, except that we don't record the inputs as parameter or buffers but rather just plain user inputs.

Test Plan: buck run mode/opt caffe2/test:test_export -- -r test_module_input

Differential Revision: D67680827

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143925
Approved by: https://github.com/tugsbayasgalan
2025-01-16 17:30:36 +00:00
f3fce597e9 [BE][Easy][17/19] enforce style for empty lines in import segments in torch/[a-c]*/ and torch/[e-n]*/ (#129769)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769
Approved by: https://github.com/ezyang
2024-08-04 10:24:09 +00:00
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
f7e79299c7 register torch.return_types in torch.fx._pytree (#120027)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120027
Approved by: https://github.com/lezcano, https://github.com/zou3519, https://github.com/XuehaiPan
ghstack dependencies: #119284
2024-02-23 21:52:42 +00:00
42062e2622 [pytree][BE] update treespec is_leaf() access (#116371)
Change `isinstance(treespec, LeafSpec) -> treespec.is_leaf()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116371
Approved by: https://github.com/zou3519
2024-01-27 11:44:57 +00:00
199e07f108 [pytree][BE] update treespec num_children access (#116370)
Change `len(treespec.children_spes) -> treespec.num_children`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116370
Approved by: https://github.com/Skylion007
2023-12-24 20:54:32 +00:00
2a3d8e50fb [pytree] test aligned API signature for C++ and Python pytree (#112485)
Add tests to ensure the C++ and Python pytree provide the same APIs with identical signatures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112485
Approved by: https://github.com/zou3519
2023-11-30 17:50:06 +00:00
5e2adc8650 [pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
2023-11-10 02:37:48 +00:00
66150b29e3 Revert "[pytree] align function signature between C++ and Python pytree (#112482)"
This reverts commit 4893a2814ffb5adeec102c17d71d2f25ba5eeb3c.

Reverted https://github.com/pytorch/pytorch/pull/112482 on behalf of https://github.com/PaliC due to changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112482#issuecomment-1804909926))
2023-11-10 00:59:23 +00:00
4893a2814f [pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
2023-11-07 01:26:41 +00:00
ac967e9dad [export] Fix tree spec matching behavior. (#109679)
Summary:

Test Plan:
Internal test.
Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109679
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
2023-09-21 14:24:09 +00:00
78b28e884a Fix error formatting string (#105935)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105935
Approved by: https://github.com/Skylion007
2023-07-26 01:20:19 +00:00
3ce1ebb6fb Apply some safe comprehension optimizations (#94323)
Optimize unnecessary collection cast calls, unnecessary calls to list, tuple, and dict, and simplify calls to the sorted builtin. This should strictly improve speed and improve readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94323
Approved by: https://github.com/albanD
2023-02-07 23:53:46 +00:00
52d1ffb789 Teach pytrees about namedtuple (#62292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292

This PR adds pytree support for namedtuples. The challenge about namedtuple
is that each namedtuple class is actually different. This PR does the
following:
- it adds a namedtuple flatten/unflatten. The flatten function returns
a context that is the actual type of the namedtuple subclass. The
unflatten function uses that type to reconstruct the namedtuple
- Special cases all pytree logic to consider all namedtuples the same.
This is done by creating a `_get_node_type(pytree)` helper function that
returns `namedtuple` if `pytree` is any namedtuple subclass. The effect
of this is that all namedtuple subclasses will go through the namedtuple
flatten/unflatten functions
- Adds a `_namedtuple_flatten_spec` function for FX pytrees. This function
flattens the namedtuple based on the spec and is equivalent to the
`_tuple_flatten_spec`.

Test Plan
- new tests in test/test_pytree.py and test/test_fx.py

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29947302

Pulled By: zou3519

fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
2021-07-28 06:27:44 -07:00
8d363d37da [FX] Adds PyTree support to FX through concrete_args (#55888)
Summary:
```
class Foo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y, x):
        for k in x:
            for v in x[k]:
                v += y
        return x

example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})

fx.symbolic_trace(new_f, concrete_args=example_dict)
```

prints out
```
def forward(self, y, x):
    y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
    add = tree_2 + y
    add_1 = tree_3 + y
    add_2 = tree_4 + y;  y = None
    return {'a': [tree_2], 'z': [tree_3, tree_4]}
```

Currently, I store `in_spec` as an extra attribute on `fx.Graph`, and then include it when we do the codegen. I'm not sure if this is the right approach - it introduces a divergence between what's in `fx.Graph` and what's in the python code.

Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit ... more verbose.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888

Reviewed By: jamesr66a

Differential Revision: D27884694

Pulled By: Chillee

fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
2021-05-07 04:48:35 -07:00