56 Commits

Author SHA1 Message Date
c6329524d8 Revert "Add magic TORCH_MAKE_PYBIND_ENUM_FASTER macro (#163527)"
This reverts commit 50c0550f5a5b1e35885d892081a7d5115d8b4489.

Reverted https://github.com/pytorch/pytorch/pull/163527 on behalf of https://github.com/swolchok due to breaking import torch in debug builds, see #164297 ([comment](https://github.com/pytorch/pytorch/pull/163527#issuecomment-3361919142))
2025-10-02 15:42:42 +00:00
50c0550f5a Add magic TORCH_MAKE_PYBIND_ENUM_FASTER macro (#163527)
See comment on the macro definition. In short, pybind11 3.x
added `py::native_enum`, and also had to add overhead for that new way
to bind enums on the critical path for calling functions that take
regular old `py::enum_`s as arguments (for example, `__eq__`).

Differential Revision: [D82873169](https://our.internmc.facebook.com/intern/diff/D82873169/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163527
Approved by: https://github.com/ezyang
2025-09-26 17:59:22 +00:00
97eb7a281d torchdim Python port (#160236)
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new algorithm for handling dispatch is that we first land in `__torch_function__` and we see if a special FCD implementation needs to be dispatch to first, and if there is nothing we fallback to the standard level strategy.

Because there is no longer C binding equivalent of classes, we've condensed _C.Dim and Dim together, and similar for Tensor. This resulted in some bugs as the Python API is sometimes different from the C API. I've attempted to disambiguate these but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful as Dim must abide by Tensor equality semantics, but is pointer equality in C (DimEntry doesn't have this problem). Another difference between C/Python that is subtle is we no longer get implicit conversions from Dim to DimEntry, this also caused some bugs.

Much of the mechanical porting work was done by claude code. I have a separate PR that deletes functorch._C, but it was useful having dim.cpp to point claude at it so I haven't done it in this PR. From a reviewing perspective, I need to re-review that I didn't forget to port anything, some noticeably missing "small" things are patched_dim_method. I am still in progress of carefully doing a side-by-side review of ports; "simplifications" from claude code were also a major source of bugs.

There are two major feature gaps in the implementation:

- DelayedTensor and dot handling are not implemented yet. This should be reasonably easy, just need to do it.  However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls.
- Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn into a tensor, instead of being eligible for `__torch_function__`. I think I might need to hard code a special case for this or something?

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236
Approved by: https://github.com/zdevito, https://github.com/albanD
2025-09-21 03:01:04 +00:00
ffd58293f7 [dynamo] Guard serialization for FUNCTORCH_STACK_MATCH (#152616)
Make Functorch interpreters serializable most of the time, so that we can save the guards on functorch states.

## Test Cases:

0. torch.compile() without functorch layers present. Guard should fail with any layer being pushed.
1. torch.compile() nested in vmap.
2. torch.compile() nested in grad.
3. torch.compile() nested in jvp + vmap
4. torch.compile() nested functionalize
5. torch.compile() nested in vmap + grad

Differential Revision: [D74008787](https://our.internmc.facebook.com/intern/diff/D74008787/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152616
Approved by: https://github.com/zou3519
ghstack dependencies: #152615
2025-05-05 18:05:56 +00:00
cyy
8fa81a6066 Enable misc-use-internal-linkage check and apply fixes (#148948)
Enables clang-tidy rule [`misc-use-internal-linkage`](https://clang.llvm.org/extra/clang-tidy/checks/misc/use-internal-linkage.html). This new check was introduced in Clang-Tidy 18 and is available due to recent update of Clang-Tidy 19.

The check marks functions and variables used only in the translation unit as static. Therefore undesired symbols are not leaked into other units, more link time optimisations are possible and the resulting binaries may be smaller.

The detected violations were mostly fixed by using static. In other cases, the symbols were indeed consumed by others files, then their declaring headers were included. Still some declarations were wrong and have been fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148948
Approved by: https://github.com/Skylion007
2025-03-12 14:22:56 +00:00
8a81f7a4b6 Refactor functions in functorch for functional (#141808)
As the title stated
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141808
Approved by: https://github.com/Skylion007
2024-11-30 20:15:40 +00:00
6b8e3022f2 Remove c10::optional usages in PyTorch (#139525)
Test Plan: Sandcastle

Reviewed By: swolchok

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139525
Approved by: https://github.com/malfet, https://github.com/Skylion007
2024-11-04 15:35:23 +00:00
cyy
0c0d8c8ff0 [1/N] Fix extra warnings brought by clang-tidy-17 (#137407)
Before we can use clang-tidy-17
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137407
Approved by: https://github.com/Skylion007, https://github.com/aaronenyeshi
2024-10-07 17:53:59 +00:00
a408cfcbf1 [torch.compile] torch.vmap supports dynamic shapes + enable flex attention create_block_mask dynamic shapes (#137163)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137163
Approved by: https://github.com/Chillee
2024-10-04 05:16:04 +00:00
cyy
6b12dc0224 [Reland] [11/N] Use std::nullopt and std::optional (#132622)
Reland of #132396, which was reverted due to dependency reversion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132622
Approved by: https://github.com/ezyang
2024-08-05 20:36:33 +00:00
e4e3575fb0 Revert "[11/N] Use std::nullopt and std::optional (#132396)"
This reverts commit d7d61904936617a6a43782868d0b1004cb70dfc0.

Reverted https://github.com/pytorch/pytorch/pull/132396 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR has a dependency on another PR (https://github.com/pytorch/pytorch/pull/128898) that has to be reverted ([comment](https://github.com/pytorch/pytorch/pull/132396#issuecomment-2265952528))
2024-08-02 18:49:42 +00:00
cyy
d7d6190493 [11/N] Use std::nullopt and std::optional (#132396)
Follows #132364
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132396
Approved by: https://github.com/ezyang
2024-08-01 14:46:33 +00:00
cyy
f4dcf2ae93 [1/N] Change #include <c10/util/Optional.h> to #include <optional> (#128301)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128301
Approved by: https://github.com/ezyang, https://github.com/r-barnes
2024-07-08 07:03:53 +00:00
846bb30e13 Revert "[1/N] Change #include <c10/util/Optional.h> to #include <optional> (#128301)"
This reverts commit bd72e28314d8d63bb347becb8309f5ac7761c6b5.

Reverted https://github.com/pytorch/pytorch/pull/128301 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it fails XLA build bd72e28314. Please rebase your PR before relanding because I think the failure is hidden by an unrelated broken trunk XLA failure from your current base commit ([comment](https://github.com/pytorch/pytorch/pull/128301#issuecomment-2169035822))
2024-06-15 01:58:20 +00:00
cyy
bd72e28314 [1/N] Change #include <c10/util/Optional.h> to #include <optional> (#128301)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128301
Approved by: https://github.com/ezyang
2024-06-14 23:21:01 +00:00
70dc59c55f Fix perf regression caused by #122074 (#126996)
The original change was about 9.5% slower than then before #122074 .
This improves it to be only about 1.4% slower.

Also touched up some unrelated nits that the linter complained about.

Fixes #126293

Ran torchbench 3 times on each change. Perf values before (stable), after (fix),
and with #122074 backed out (backout):
```
../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench pyhpc_isoneutral_mixing amp first dynamic cpp
stable:
43.948x
45.754x
44.906x

fix:
47.505x
49.987x
47.493x

backout:
48.243x
48.199x
48.192x

../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench pyhpc_equation_of_state amp first static default
stable:
15.224x
13.286x
15.354x

fix:
16.402x
16.370x
16.183x

backout:
16.554x
16.675x
16.787x

../inductor-tools/scripts/modelbench/inductor_single_run.sh single inference performance torchbench lennard_jones float32 first static default
stable:
1.712x
1.651x
1.640x

fix:
1.804x
1.798x
1.792x

backout:
1.864x
1.824x
1.836x
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126996
Approved by: https://github.com/jansel
2024-05-24 04:27:22 +00:00
ed327876f5 [codemod] c10:optional -> std::optional (#126135)
Generated by running the following from PyTorch root:
```
find . -regex ".*\.\(cpp\|h\|cu\|hpp\|cc\|cxx\)$" | grep -v "build/" | xargs -n 50 -P 4 perl -pi -e 's/c10::optional/std::optional/'
```

`c10::optional` is just an alias for `std::optional`. This removes usages of that alias in preparation for eliminating it entirely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126135
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi
2024-05-14 19:35:51 +00:00
b23b6e7108 Ensure that vmap is restored properly if an exception is thrown during frame eval (#122074)
We save and restore the DynamicLayerStack during frame eval but since fx graph has no way to express a try/finally we just assume it will happen. If we throw an exception between the push and pop to the stack then we're left in a state that affects following operations poorly.  Make sure that if it's in a bad state we restore it after frame eval.

Repro:
before:
```
$ rm test/dynamo_skips/TestSparseCPU.test_log1p_cpu_uint8
$ rm test/dynamo_expected_failures/FuncTorchHigherOrderOpTests.test_vmap_free_tensor
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest test/jit/test_sparse.py test/dynamo/test_dynamic_shapes.py test/inductor/test_torchinductor_dynamic_shapes.py test/test_sparse.py -k 'test_log1p_cpu_uint8'
============= 1 passed, 8588 deselected in 9.75s =============
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest test/jit/test_sparse.py test/dynamo/test_dynamic_shapes.py test/inductor/test_torchinductor_dynamic_shapes.py test/test_sparse.py -k
'test_vmap_free_tensor_dynamic_shapes or test_log1p_cpu_uint8'
================== short test summary info ===================
FAILED [0.0632s] test/test_sparse.py::TestSparseCPU::test_log1p_cpu_uint8 - AssertionError: "only Tensors of floating point dtype can require gradients"
does not match "You are attempting to call Tensor.requires_grad_() (or perhaps using torch.autograd.functional.* APIs) inside of a function ...
======= 1 failed, 1 skipped, 8587 deselected in 10.99s =======
```
(Note that adding test_vmap_free_tensor_dynamic_shapes causes test_vmap_free_tensor_dynamic_shapes to fail)
after:
```
$ rm test/dynamo_skips/TestSparseCPU.test_log1p_cpu_uint8
$ rm test/dynamo_expected_failures/FuncTorchHigherOrderOpTests.test_vmap_free_tensor
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest test/jit/test_sparse.py test/dynamo/test_dynamic_shapes.py test/inductor/test_torchinductor_dynamic_shapes.py test/test_sparse.py -k 'test_log1p_cpu_uint8'
============= 1 passed, 8588 deselected in 9.89s =============
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest test/jit/test_sparse.py test/dynamo/test_dynamic_shapes.py test/inductor/test_torchinductor_dynamic_shapes.py test/test_sparse.py -k
'test_vmap_free_tensor_dynamic_shapes or test_log1p_cpu_uint8'
======= 1 passed, 1 skipped, 8587 deselected in 11.34s =======
```
(test_vmap_free_tensor_dynamic_shapes passes either way)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122074
Approved by: https://github.com/oulgen
2024-05-07 19:36:52 +00:00
4eaa000acc Teach dynamo about torch.func.jvp (#119926)
List of changes:
- Replace JVP_NESTING by torch._C._functorch.maybe_current_level()
- Remove all increment nesting functions from wrap_fx_proxy_cls
- fwAD.make_dual receives the dual_level as keyword argument
- Add jvp_increment_nesting, set_fwd_grad_enabled and dual_level context managers to dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119926
Approved by: https://github.com/zou3519
2024-03-22 20:25:47 +00:00
0696db8202 Revert "Teach dynamo about torch.func.jvp (#119926)"
This reverts commit 17489784b635187316c6c856c5fe6b6a28d8a15a.

Reverted https://github.com/pytorch/pytorch/pull/119926 on behalf of https://github.com/peterbell10 due to broken mac jobs on main ([comment](https://github.com/pytorch/pytorch/pull/119926#issuecomment-2010327997))
2024-03-20 18:34:43 +00:00
17489784b6 Teach dynamo about torch.func.jvp (#119926)
List of changes:
- Replace JVP_NESTING by torch._C._functorch.maybe_current_level()
- Remove all increment nesting functions from wrap_fx_proxy_cls
- fwAD.make_dual receives the dual_level as keyword argument
- Add jvp_increment_nesting, set_fwd_grad_enabled and dual_level context managers to dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119926
Approved by: https://github.com/zou3519
2024-03-20 13:09:19 +00:00
36e5c1dcab Revert "Teach dynamo about torch.func.jvp (#119926)"
This reverts commit edd04b7c16cc6715411119bb7db234a9df59065f.

Reverted https://github.com/pytorch/pytorch/pull/119926 on behalf of https://github.com/jeanschmidt due to lots of breakages in pull jobs, checking if reverting this one will help ([comment](https://github.com/pytorch/pytorch/pull/119926#issuecomment-2007915919))
2024-03-19 18:59:46 +00:00
edd04b7c16 Teach dynamo about torch.func.jvp (#119926)
List of changes:
- Replace JVP_NESTING by torch._C._functorch.maybe_current_level()
- Remove all increment nesting functions from wrap_fx_proxy_cls
- fwAD.make_dual receives the dual_level as keyword argument
- Add jvp_increment_nesting, set_fwd_grad_enabled and dual_level context managers to dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119926
Approved by: https://github.com/zou3519
2024-03-19 13:06:42 +00:00
3319dbcd23 Update vmap guard to avoid recompilations (#119061)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119061
Approved by: https://github.com/zou3519
2024-02-13 20:50:23 +00:00
80cf0ce153 Enhance torch.vmap support from inside torch.compile (#116050)
This work rewrites vmap support in torch.compile by inlining most of
the frames into the existing FX graph. It also unlocks to PyTorch to
support features that were previously missing, such as keyword args.

Fixes: https://github.com/pytorch/pytorch/issues/114306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116050
Approved by: https://github.com/zou3519
2024-01-22 17:53:45 +00:00
cyy
20f769544c [12/N] Apply clang-tidy and fix warnings in headers of torch/csrc (#116486)
This PR follows #116751.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116486
Approved by: https://github.com/albanD
2024-01-10 08:48:14 +00:00
0aa50909f3 Revert "[12/N] Apply clang-tidy and fix warnings in headers of torch/csrc (#116486)"
This reverts commit 5aa258eb09d5ecd62aea4d2bd02bbfa5eda0d554.

Reverted https://github.com/pytorch/pytorch/pull/116486 on behalf of https://github.com/izaitsevfb due to Reverting, as it depends on https://github.com/pytorch/pytorch/pull/116353, which has to be reverted ([comment](https://github.com/pytorch/pytorch/pull/116486#issuecomment-1876042948))
2024-01-03 22:18:54 +00:00
cyy
5aa258eb09 [12/N] Apply clang-tidy and fix warnings in headers of torch/csrc (#116486)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116486
Approved by: https://github.com/albanD
2023-12-30 18:38:53 +00:00
165f4f6ccf [PyTorch] Redirect c10::optional to std::optional (#101995)
We have C++17 now!

I am intentionally dropping the `c10::optional<c10::ArrayRef>` size optimization. It was intended to improve dispatch, but thanks to D34602980 / #70864 we don't use `optional<ArrayRef>` in function arguments anymore anyway.

Differential Revision: [D46079028](https://our.internmc.facebook.com/intern/diff/D46079028/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101995
Approved by: https://github.com/malfet, https://github.com/Skylion007, https://github.com/ezyang
2023-11-30 02:46:41 +00:00
7ce69d5dbe [RELAND] Remove some unnecessary <iostream> includes from headers (#108150)
In almost all cases this is only included for writing the output formatter, which
only uses `std::ostream` so including `<ostream>` is sufficient.

The istream header is ~1000 lines so the difference is non-trivial.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108150
Approved by: https://github.com/albanD, https://github.com/malfet
ghstack dependencies: #108149
2023-09-20 21:55:15 +00:00
b928e08f3d Initial vmap + NT support with unbind fallback (#106786)
PoC demonstrating vmap + NT based on the [design doc](https://docs.google.com/document/d/1dVVk6TOqz93PLTIneU2T3xaxCs9qZ0MaJyCvOAp_bC0). This PR:
* Allows `BatchedTensorImpl`s to contain NTs
* Introduces a `BatchedNestedTensor` dispatch key for NT-specific batching rules
* Provides a batching rule fallback that unbinds the NTs -> performs computation on constituent -> rebinds results into NT

Restrictions:
* Only supports one level of vmap
* Only supports vmapping over dim=0 for NTs
    * For operations with mixed NT / dense inputs, support is also limited to dim=0 for the dense inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106786
Approved by: https://github.com/zou3519
2023-09-07 13:53:20 +00:00
378ffde8c1 Revert "Remove some unnecessary <iostream> includes from headers (#106914)"
This reverts commit a6c29b722772816804d54eed070fbb38450d3e6f.

Reverted https://github.com/pytorch/pytorch/pull/106914 on behalf of https://github.com/izaitsevfb due to Causing metal breakage internally, see D48709279 ([comment](https://github.com/pytorch/pytorch/pull/106914#issuecomment-1696670027))
2023-08-29 02:22:33 +00:00
cyy
054f3f1d8f [3/N] fix clang-tidy warnings in torch/csrc (#108024)
Apply fixes to some found issues by clang-tidy in torch/csrc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108024
Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/malfet
2023-08-28 18:00:00 +00:00
a6c29b7227 Remove some unnecessary <iostream> includes from headers (#106914)
In almost all cases this is only included for writing the output formatter, which
only uses `std::ostream` so including `<ostream>` is sufficient.

The istream header is ~1000 lines so the difference is non-trivial.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106914
Approved by: https://github.com/lezcano
2023-08-25 18:24:05 +00:00
28dc1a093f Revert "Remove some unnecessary <iostream> includes from headers (#106914)"
This reverts commit 60936e4c296e79f56cac2431a560970bb4529d03.

Reverted https://github.com/pytorch/pytorch/pull/106914 on behalf of https://github.com/ZainRizvi due to Sorry, but this is breaking internal builds. Seems like a lot of internal code depends on some of the removed imports ([comment](https://github.com/pytorch/pytorch/pull/106914#issuecomment-1688605975))
2023-08-22 17:16:48 +00:00
60936e4c29 Remove some unnecessary <iostream> includes from headers (#106914)
In almost all cases this is only included for writing the output formatter, which
only uses `std::ostream` so including `<ostream>` is sufficient.

The istream header is ~1000 lines so the difference is non-trivial.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106914
Approved by: https://github.com/lezcano
2023-08-19 20:21:58 +00:00
f66d5dd788 SymIntify functorch vmap (#101409)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101409
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2023-05-19 03:07:41 +00:00
0247ed27cc Apply Clang-Tidy readability-container-size-empty (#93236)
Not only is this change usually shorter and more readable, it also can yield better performance. size() is not always a constant time operation (such as on LinkedLists), but empty() always is.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93236
Approved by: https://github.com/malfet
2023-01-29 23:28:19 +00:00
7aaad0b832 Rename flag that enables/disables _SingleLevelFunction for functorch (#92025)
functorch used to have a switch that enables/disables autograd.Function.
That switch now enables/disables torch.autograd.function._SingleLevelFunction, so
I've renamed it accordingly.

We could just delete the switch because users should not be directly
working with torch.autograd.function._SingleLevelFunction. However,
it was useful for debugging when something went wrong when I was
implementing the autograd.Function <> functorch interaction, so I want
to keep it around as a debugging tool for a while since the code is
already there.

Test Plan:
- updated tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92025
Approved by: https://github.com/soulitzer
2023-01-17 13:36:41 +00:00
76a3869fc6 Support functionalization on torch.cond (#89966)
This PR adds functionalization path for torch.cond. As it is the first pass, we only functionalize for very restrictive use cases. We explicitly restrict following:

- Output of each branch aliasing input
- In-place mutation on inputs given to each branch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89966
Approved by: https://github.com/zou3519
2022-12-22 22:01:47 +00:00
ffa37c9fca Add VmapInterpreter.randomness (in pyfunctorch) provide it in info object (#90789)
This PR:
- adds VmapInterpreter.randomness. This returns the randomness option
the user provided in vmap(..., randomness=...)
- adds randomness in the info object passed to the vmap staticmethod of
autograd.Function. This is so that the user can handle random operations
on their own terms (if randomness="error", and if the autograd.Function
has random operations, then it is the user's responsiblity to raise an
error).

Test Plan:
- updated unittest
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90789
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-17 00:43:43 +00:00
abc54f9314 Revert "Revert "[functorch] Refactor life handle storage (#90317)"" (#90856)
Adds the fix for -Wsign-compare.

See original PR (https://github.com/pytorch/pytorch/pull/90317) for
commit message
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90856
Approved by: https://github.com/samdow
2022-12-15 16:03:16 +00:00
0cd69d7cda Revert "[functorch] Refactor life handle storage (#90317)"
This reverts commit 4d494986af5201a0c487a9b7f3c68cfa6c4e28d0.

Reverted https://github.com/pytorch/pytorch/pull/90317 on behalf of https://github.com/osalpekar due to Causing contbuilds to fail when pytorch is built with -Wsign-compare internally - details in [D42019543](https://www.internalfb.com/diff/D42019543)
2022-12-14 19:08:33 +00:00
4809e838c1 functorch.jvp support for autograd.Function (#90077)
This PR adds functorch.jvp support for autograd.Function. It does so by
adding a jvp rule for custom_function_call.

For a regular PyTorch operation (like at::sin), the VariableType kernel:
- re-dispatches to at::sin
- calls the jvp rule for at::sin

The jvp rule for custom_function_call does just that. It constructs a
new autograd.Function (because the above logic already exists). Inside
the forward, it re-dispatches to custom_function_call. In the jvp rule,
it just calls whatever the jvp rule is supposed to be.

Since this logic is really close to the custom_function_call_grad, I
just put them together.

Test Plan:
- added jvp rules to the autograd.Function in autograd_function_db
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90077
Approved by: https://github.com/albanD, https://github.com/soulitzer
2022-12-14 16:20:53 +00:00
4d494986af [functorch] Refactor life handle storage (#90317)
A "life handle" is a pointer-to-boolean that says whether or not a
TensorWrapper is alive. A TensorWrapper is alive if we are currently
inside of its corresponding transform. An Interpreter is alive if we are
currently inside of its corresponding transform. I.e., for vmap(f)(x),
the BatchedTensor(x, level=1) is alive inside of the execution of f; and
the corresponding VmapInterpreter is alive inside of f.

Previously, there was a global map of level to life handle. It is
possible to get into a state where we have multiple levels that refer to
different Interpreters (if the implementation of an operator calls into
functorch) and that messes up the global map.

This PR changes it so that
- every Interpreter holds a life handle that says if it is alive
- to construct a TensorWrapper, one must either (a) directly pass it a life
handle, or (b) one must create the TensorWrapper when the corresponding
Interpreter is on the stack (and we will automatically grab the life
handle by indexing into the DynamicLayerStack with the level)

(a) is more robust so I changed most of our C++ callsites to do that.
(b) feels a bit hacky to me, but it seems fine for now:
- It'll raise a nice error message if the interpreter isn't on the stack
- all of our Python callsites already follow this convention (we construct
TensorWrappers after pushing the Interpreter onto the stack).

The alternative to (b) is that we always do (a), which we can do in the
future if (b) runs us into any problems.

Test Plan:
- all functorch tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90317
Approved by: https://github.com/samdow
2022-12-13 14:45:18 +00:00
24c3ad7851 Move private forward grad mode helpers to torch.autograd.forward_ad (#90240)
Motivation
- These were previously defined in functorch. They are not
functorch-specific, so I'm moving them to torch.autograd.forward_ad and
the autograd python bindings.
- I need this to avoid some of my cyclic import problems.

Should these be public APIs? Probably. Though this needs discussion, so
punting it to the future.

Test Plan:
- moved the tests of these from test/functorch/test_eager_transforms.py
to test/test_autograd.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90240
Approved by: https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
3049d99027 autograd.Function supports vmap staticmethod (#90037)
This PR adds a `vmap` staticmethod to autograd.Function and a
corresponding vmap kernel for custom_function_call. These two items mean
that autograd.Function with a vmap staticmethod can be used with vmap.

```py
class NumpyMul(torch.autograd.Function)
    staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    staticmethod
    def setup_context(ctx, outputs, x, y):
        ctx.save_for_backward(x, y)

    staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```

API Spec
- the staticmethod takes two arguments (info, in_dims) as well as the
unexpanded inputs (x, y).
- If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a
pytree with the same tree structure as args. It has None if the arg is
not being vmapped over and an integer vmapped dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one
field, `info.batch_size`. In the future we can extend this by adding
things like the randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors,
they've already been unpacked.
- We expect the user to return a `(outputs, out_dims)` tuple. `out_dims`
must "broadcast" to the same pytree structure as `outputs`.

Semantics
- vmap(NumpyMul.apply)(x) will apply the vmap staticmethod if there is
one and will never actually run NumpyMul.forward.
- In order for the autograd.Function to support nested vmap (e.g.,
`vmap(vmap(NumpyMul.apply))(x)`, then the vmap staticmethod must call
into operations that vmap understands (i.e. PyTorch operators or more
autograd.Function).

At a high level, this PR:
- adds a vmap rule for custom_function_call

Testing
- Added some tests for in_dims and info
- Added vmap staticmethod to most of the autograd.Function in
autograd_function_db and sent them through functorch's vmap-related
OpInfo tests

Future
- Better error messages if the user gets the return contract wrong. I
didn't include them in this PR because it might involve a refactor of
some of the existing code in functorch/_src/vmap.py that will add
~200LOC to the PR, but LMK if you'd prefer it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90037
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
7342251281 functorch.grad support for autograd.Function (#89860)
Happy to split this PR more if it helps.

This PR adds functorch.grad support for autograd.Function. There's a lot
going on; here is the high level picture and there are more details as
comments in the code.

Mechanism (PyOperator)
- Somehow, autograd.Function needs to dispatch with functorch. This is
necessary because every layer of functorch needs to see the
autograd.Function; grad layers need to preserve the backward pass.
- The mechanism for this is via PyOperator. If functorch transforms are
active, then we wrap the autograd.Function in a `custom_function_call`
PyOperator where we are able to define various rules for functorch
transforms.
- `custom_function_call` has a rule for the functorch grad transform.

autograd.Function changes
- I needed to make some changes to autograd.Function to make this work.
- First, this PR splits autograd.Function into a _SingleLevelFunction
(that works with a single level of functorch transform) and
autograd.Function (which works with multiple levels). This is necessary
because functorch's grad rule needs some way of specifying a backward
pass for that level only.
- This PR changes autograd.Function's apply to eitehr call
`custom_function_call` (if functorch is active) or super().apply (if
functorch isn't active).

Testing
- Most of this PR is just testing. It creates an autograd.Function
OpInfo database that then gets passed to the functorch grad-based tests
(grad, vjp, vjpvjp).
- Since functorch transform tests are autogenerated from OpInfo tests,
this is the easiest way to test various autograd.Function with
functorch.

Future
- jvp and vmap support coming next
- better error message (functorch only supports autograd.Function that
have the optional setup_context staticmethod)
- documentation to come when we remove the feature flag

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89860
Approved by: https://github.com/soulitzer
2022-12-08 19:31:04 +00:00
3bc327993f PyDispatcher integration with functorch (#88785)
This PR teaches PyDispatcher and PyOperator about functorch transforms.
It is important that PyDispatcher/PyOperator dispatch with functorch
transforms, because this is our plan for higher-order operators
(operators that accept functions as arguments). Examples of these
include:
- functorch transforms over the existing cond operator (control flow)
- autograd.Function support for functorch (which I am working towards),
- AOTDispatcher (should be a higher order operator)

Concretely, the problem with teaching PyDispatcher/PyOperator about
functorch is that the stack-based dispatching logic (DynamicLayerStack)
is hidden inside the fallbacks for two dispatch keys
(DynamicLayer{Front, Back}). PyDispatcher doesn't know about C++ boxed
fallbacks, our plan on record for that is that we need to reimplement
all of them in Python (but can call helper functions in C++ to make our
lives easier).

Instead of exposing all of what DynamicLayer{Front, Back} do to python,
this PR takes the approach of re-implementing part of the stack-based
dispatching in Python. The motivation is that this is more sane and
follows what the "ideal" implementation of functorch would have been:
- each transform should be a "mode"
- there should be no TLS dispatch key set hackery. functorch needs to do
this hackery today to re-use VariableType implementations.

This PR:
- exposes the DynamicLayerStack to Python
- The DynamicLayerStack is a stack of Interpreters.
These get exposed to Python as well.
- Interpreters can run operations (Interpreter.process) or lower them to
the next interpreter in the stack (Interpreter.lower)
- To use a PyOperator with functorch transforms, a developer needs to
register a rule for each transform (vmap, grad, jvp, ...).
- The PyOperator API is NOT user-facing. Things like autograd.Function
support for functorch will end up going through the autograd.Function
API.

Question for reviewers:
- Does this design make sense?
- I'm trying to split up the "functorch support for autograd.Function"
work into logical pieces. Would it be better if I didn't? (the full
thing is a bit long - 1000-2000 LOC).

Test Plan:
- new tests that construct PyOperator and compose them with functorch
transforms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88785
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-11-16 00:46:59 +00:00
2268a3215c [functorch] add switch to enable autograd.Function (#88784)
This is mostly a debug or "if you know what you're doing" switch for
now. It is not public API.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88784
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-11-16 00:46:59 +00:00