Commit Graph

81 Commits

Author SHA1 Message Date
5c3fe9fb30 Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit a6fa4f9c283971c0fb6f60a89674a1f35370ac79.

Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/izaitsevfb due to introduces numeric issues internally, see [D84326613](https://www.internalfb.com/diff/D84326613) ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3392203314))
2025-10-10 20:21:12 +00:00
a6fa4f9c28 Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now preserve the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
2025-10-10 00:15:00 +00:00
06d86e58d0 Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit d40a9bfb8da0dc1ac1e6e56b33a25979112874de.

Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3385056722))
2025-10-09 09:50:59 +00:00
d40a9bfb8d Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now preserve the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164573
2025-10-09 04:49:44 +00:00
a43c4c3972 [5/N] Apply ruff UP035 rule (#164423)
Continued code migration to enable ruff `UP035`. Most changes are about moving `Callable` from `typing` to `from collections.abc`.
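
As a minimal sketch of the kind of change this rule asks for (the function `apply_twice` is hypothetical and not taken from the PR), the import moves from `typing` to `collections.abc` while the annotation itself stays the same:

```python
# Before (flagged by ruff UP035):
#     from typing import Callable
# After:
from collections.abc import Callable


def apply_twice(fn: Callable[[int], int], value: int) -> int:
    """Apply ``fn`` to ``value`` two times (illustrative helper only)."""
    return fn(fn(value))


result = apply_twice(lambda x: x + 1, 0)
```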

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164423
Approved by: https://github.com/ezyang
2025-10-02 07:31:11 +00:00
7f14b42adf [BE][2/16] fix typos in torch/ (torch/_*/) (#156312)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156312
Approved by: https://github.com/albanD
2025-07-12 05:47:06 +00:00
e15f4248ad Revert "[BE][2/16] fix typos in torch/ (torch/_*/) (#156312)"
This reverts commit 7a92b5119654c07d15f5c0818e6ae804b01e836c.

Reverted https://github.com/pytorch/pytorch/pull/156312 on behalf of https://github.com/XuehaiPan due to landrace ([comment](https://github.com/pytorch/pytorch/pull/156312#issuecomment-3064672250))
2025-07-12 04:40:52 +00:00
7a92b51196 [BE][2/16] fix typos in torch/ (torch/_*/) (#156312)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156312
Approved by: https://github.com/albanD
2025-07-12 01:47:22 +00:00
e95e8eed0a mypy 1.16.0 (#155821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-06-14 18:18:43 +00:00
0a7eef140b Add torch.Tensor._make_wrapper_subclass to torch/_C/__init__.pyi (#154022)
Fixes #153790

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154022
Approved by: https://github.com/Skylion007
2025-05-27 14:10:00 +00:00
90ddb33141 [export] specialize for aten.to (#149235)
Changes decomposition behavior of `aten.to` to respect the aliasing/non-aliasing behavior in eager, and to specialize to the input/conversion dtype & device.

Before change: we always decompose `aten.to` into `_to_copy`, regardless of aliasing behavior. This leads us to ban mutations on the result of `_to_copy` when aliased, since we can't guarantee correct program semantics. This meant users had to explicitly call `.clone()` before mutating. In the special cases where we don’t ban mutations (e.g. dtype conversion), we add runtime assertions on the input & conversion dtype/devices in the decomposed program (see https://github.com/pytorch/pytorch/pull/142420).

After change: we decompose to the aliasing/non-aliasing behavior that matches eager, allowing mutations in all cases. We also add dtype/device assertions for all `aten.to` ops, starting in the pre-dispatch graph, basically specializing the program to the dtype/devices.

Differential Revision: D71229547

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149235
Approved by: https://github.com/tugsbayasgalan
2025-04-03 05:20:10 +00:00
fb566c5aea Fix auto_functionalize x inference_mode (#147925)
Fixes #147924

We were using the wrong FunctionalTensorMode to construct
FunctionalTensors. FunctionalTensors modify the FunctionalTensorMode on
construction, so that led to the wrong FunctionalTensorMode being
modified. This PR threads the FunctionalTensorMode through correctly.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147925
Approved by: https://github.com/bdhirsh
2025-02-26 18:05:30 +00:00
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.
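
For illustration, UP006 replaces the `typing` aliases with the PEP 585 builtin generics. A minimal sketch, where `count_words` is a hypothetical function and not code from this PR:

```python
# Before (flagged by ruff UP006):
#     from typing import Dict, List
#     def count_words(words: List[str]) -> Dict[str, int]: ...
# After (PEP 585 builtin generics, Python 3.9+):
def count_words(words: list[str]) -> dict[str, int]:
    """Count how often each word occurs (illustrative helper only)."""
    counts: dict[str, int] = {}
    for word in words:
        counts[word] = counts.get(word, 0) + 1
    return counts
```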

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
f40e013787 Fix aten.to when input is a tensor constant (#146220)
Summary:
Fix aten.to when input is a tensor constant.

In this case, `args_unwrapped` could just be a constant, so not a functional tensor.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  tensor_constant_aten_to
```

Differential Revision: D68984244

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146220
Approved by: https://github.com/JacobSzwejbka
2025-02-01 11:07:33 +00:00
805c4b597a PEP585 update - torch/_higher_order_ops torch/_subclasses torch/backends torch/compiler torch/cuda torch/masked torch/mtia torch/nested (#145202)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145202
Approved by: https://github.com/bobrenjc93
2025-01-20 22:37:26 +00:00
0e1675a89b Relax aten.to restriction (#142420)
Summary: If we have `a.to(b)` and `b` has a different dtype from `a`, the result must be a copy. In this case, we do not need to freeze the tensor. Instead, we use `torch.ops.aten._assert_tensor_metadata.default` to assert that `a` does not have the same dtype as `b`.

Fixes https://github.com/pytorch/pytorch/issues/139718

Update executorch pin to include https://github.com/pytorch/executorch/pull/7277.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_float_conversion
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_device_to_mutation_float
```

Differential Revision: D66988295

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142420
Approved by: https://github.com/bdhirsh
2025-01-08 18:11:31 +00:00
11c786dcb5 [BE] Make maybe_aliasing_or_mutating proper tag (#131990)
For better tracking, we need to mark maybe-aliasing/mutating ops with a proper tag. We need to special-case native_batch_norm because it is not a CIA but has a wrong schema. I guess native_batch_norm will be removed at some point, so until then we just keep it around.

D60347117
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131990
Approved by: https://github.com/bdhirsh
2024-11-24 00:12:49 +00:00
05b6200ccd Do not compute base in export mode (#137760)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137760
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
2024-10-15 19:04:42 +00:00
44653895cc override bool(), is_nonzero for real tensor tracing (#136788)
Fixes bool() and is_nonzero() calls for real tensor tracing, non-strict export

Differential Revision: D63482693

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136788
Approved by: https://github.com/ezyang
2024-10-15 17:13:44 +00:00
9409274bc1 Fix bug in functional tensor decomp (#136600)
Summary: Previously we had a very bad bug where we didn't allow any decomp on CIA. This never mattered before because we never had to actually push CIA decomp to the Python key level in export.

Test Plan: CI

Differential Revision: D63363749

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136600
Approved by: https://github.com/bdhirsh
2024-09-25 17:37:50 +00:00
31715be72a [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.2 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-16 19:44:11 +00:00
3117f2cf67 Revert "[BE]: Update mypy to 1.11.2 (#133816)"
This reverts commit 55299cfc223fa838aadd8d6d6fa3ed541fa5acd1.

Reverted https://github.com/pytorch/pytorch/pull/133816 on behalf of https://github.com/jeanschmidt due to seems to have broken https://github.com/pytorch/pytorch/actions/runs/10865710499/job/30155699792 on main ([comment](https://github.com/pytorch/pytorch/pull/133816#issuecomment-2352377684))
2024-09-16 09:11:16 +00:00
55299cfc22 [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.2 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-14 21:40:36 +00:00
ba6e0f31ab Remove cycle dependency by localizing the import. (#135926)
Summary:
Since https://www.internalfb.com/diff/D62215095 landed, there have been many silent errors due to the dependency between functional_tensor and config.

```
 File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/export/__init__.py", line 64, in <module>
  File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/export/dynamic_shapes.py", line 23, in <module>
  File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/export/exported_program.py", line 26, in <module>
  File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/_higher_order_ops/__init__.py", line 1, in <module>
  File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/_higher_order_ops/cond.py", line 6, in <module>
  File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/_subclasses/functional_tensor.py", line 9, in <module>
  File "/tmp/torch_deploy_zip5YRJC1/torch_python_modules.zip/torch/_inductor/config.py", line 44, in <module>
```

The log at https://fburl.com/logarithm/ol5kx0ee complains about a cyclic dependency.

This PR fixes it by localizing the import.
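
A minimal sketch of the "localize the import" pattern, assuming the general shape of such a fix rather than the exact change in this PR. The standard-library `json` module stands in for the module on the other side of the cycle, and `dump_config` is a hypothetical function:

```python
# Module-level import (the pattern that can participate in an import cycle):
#     import json
#
# Localized import: the dependency is resolved when the function is first
# called, not while the enclosing module is still being imported.
def dump_config(config: dict) -> str:
    import json  # stand-in for the cyclic dependency (hypothetical)

    return json.dumps(config, sort_keys=True)
```

The trade-off is a small lookup cost on each call in exchange for breaking the import-time dependency.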

Test Plan: buck test multipy/runtime:test_deploy_embedded_cuda_interp_without_cuda_available -- --run-disabled TorchpyTest.AcquireMultipleSessionsInDifferentPackages

Reviewed By: aorenste

Differential Revision: D62616765

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135926
Approved by: https://github.com/aorenste, https://github.com/oulgen, https://github.com/Skylion007
2024-09-13 15:05:41 +00:00
66dd4577b1 Track base of FunctionalTensor in inference mode. (#135141)
The idea behind the tracking is as follows: whenever we see a tensor that is a root tensor (it does not have any view metas), we consider it the base of all the tensors that share its storage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135141
Approved by: https://github.com/zou3519
2024-09-06 00:10:25 +00:00
c8ab9b06a2 Redesign custom op functionalization for better re-inplace (#134409)
- The new implementation (auto_functionalized_v2) is enabled by default but can be disabled
 using an inductor flag.
- In export mode the old implementation is used.

**Motivation**
The previous functionalization fails to re-inplace arguments when they are views over other tensors;
see issue https://github.com/pytorch/pytorch/issues/131192.
The new functionalization makes it easier to re-inplace views.

**A) Functionalizations pass**
consider a program:

```
def func(t):
    x = t[0]
    y = t[1]
    foo(x, y) # custom operator with x, y mutable
    return (x, y, t)
```

- To functionalize `foo`, we generate a function that operates on the base tensors of the inputs (`x.base()` and `y.base()`) and record how to regenerate each view from its base. For argument x we record `ViewInfo=(x.base(), x.size(), x.stride(), x.storage_offset())`.

- Due to some limitations of the torch.export argument format, we have to generate a lot of arguments, but this is something we can simplify in the future. For the example above we get the following function.

   ```
   auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default,
     _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0 ,
     _y_base_index = 0,_y_size = (), _y_stride = (), _y_storage_offset = 1   ,
     _all_bases = [arg0_1])
   ```
 -  In the code above:
        - _all_bases (here `[t]`): refers to a unique set of bases for all foo arguments.
        - for each argument x we have _x_base_index, _x_size, _x_stride, _x_storage_offset, which can be used to regenerate x from _all_bases[_x_base_index] or from a copy of that base.

-  The output of auto_functionalized is the output of foo, followed by one tensor for each base in _all_bases; each is a copy of the base tensor after observing the mutations of all the arguments that are views of that base.

-  Each use of a base in _all_bases, or of a view of it, that comes after the call to foo is replaced with a view of the new output.

 For the function above, after functionalization we get:
 ```
    def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized[1];  auto_functionalized = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = copy_ = None

        # No stacktrace found for following nodes
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1);  getitem_1 = None
        return (select_2, select_3)
```

**B) Semantics of  auto_functionalize**
The new semantics of auto_functionalize are as follows:
1. For each base in all_bases, copy the base, producing one copy per base (if a base is inplaced, we do not need to copy it).
2. For each arg, regenerate the arg from the copy of its base using the view information above.
3. Return the original foo output followed by the new bases.
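
The three steps above can be sketched with a pure-Python toy model. This is not the PyTorch implementation: flat lists stand in for base tensors, `ScalarView` is a hypothetical 0-d view identified only by a base index and a storage offset, and `foo` is a made-up mutating op.

```python
class ScalarView:
    """Toy 0-d view into a flat base list at a fixed storage offset."""

    def __init__(self, base, storage_offset):
        self.base = base
        self.storage_offset = storage_offset

    def get(self):
        return self.base[self.storage_offset]

    def set(self, value):
        self.base[self.storage_offset] = value


def foo(x, y):
    """Hypothetical custom op that mutates both arguments in place."""
    x.set(x.get() + 1.0)
    y.set(y.get() * 2.0)


def auto_functionalized_sketch(op, view_infos, all_bases):
    # Step 1: copy every base so the caller's data is never mutated.
    new_bases = [list(base) for base in all_bases]
    # Step 2: regenerate each argument as a view of its copied base.
    args = [ScalarView(new_bases[index], offset) for index, offset in view_infos]
    # Step 3: run the op, then return its output followed by the new bases.
    out = op(*args)
    return (out, *new_bases)


t = [1.0, 5.0]
# x = t[0] and y = t[1], both views of base 0.
out, new_t = auto_functionalized_sketch(foo, [(0, 0), (0, 1)], [t])
```

Here `t` is left untouched and `new_t` carries the mutations of both views, which is what lets later uses of `t` or its views be rewritten in terms of the returned base.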

**C) Re-inplace pass**
Since auto_functionalize now copies the bases, what we actually inplace is the bases
 (the pass runs just like before, but on the bases instead of the args).

1. For each base b in _all_bases, check whether there is any use of the base (or its aliases/views) after auto_functionalize (before it is overwritten with a copy). If there is none, inplace it (avoid copying it in step 1 above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134409
Approved by: https://github.com/zou3519
2024-09-04 17:08:58 +00:00
92e38a476f preserve aten::to device in export training (#134622)
Summary:
With training IR, we cannot rely on trapping `to()` in `FunctionalTensor` because the regular decomposition kicks in first, and that can cause it to be optimized away.

So instead we preserve it until we functionalize, and then replace it explicitly with `_to_copy()`.

Test Plan: expected test failures go away

Differential Revision: D61883878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134622
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2024-08-29 14:53:30 +00:00
1a0d00f1f4 [traced-graph][sparse] enable to_dense() for compressed (#133371)
Fixes https://github.com/pytorch/pytorch/issues/133174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133371
Approved by: https://github.com/ezyang
2024-08-24 20:33:23 +00:00
8ae4f82243 [aotd] Support HOP effects in backward (#132638)
Support of effectful operations in backward:

1/ AOTD collects metadata from the forward fn only, so effectful ops can be used in backward that were not used in forward. This PR therefore allows token discovery while tracing the joint function.

FunctionalTensorMode holds _tokens; in the joint function, after tracing forward, we memoize _tokens as `_tokens_forward_output`.

2/ Tokens are added as primals inputs (forward) in EffectTokensWrapper.
Tokens that will be used in backward are in partitioner saved values. We do not have control on which positions they are saved in forward outputs.

2/ If new tokens discovered in backward after tracing joint_fn, the result graph will be manually added in the end of primals.
_aot_autograd/utils.py

3/ All effectful ops during backward are marked with the 'must_be_in_backward' partitioner_tag, to prevent the partitioner from placing them in forward.

For that, functional_tensor_mode got a new optional state `self._effects_partitioner_tag` for effectful ops, to set after tracing forward.

There are additional changes in partitioner to improve functionality of 'must_be_in_backward'

4/ Unlift tokens now should run for both forward and backward.
- As saved for backward tokens are placed on non static places - we identify input and output tokens to erase, by input and output of `with_effects` operation
- In forward we can have input tokens, discovered in backward, that are not used in with_effects ops in forward, but saved for backward. We identify them by position in forward inputs.

5/ Adding aot debug logging for graphs before unlifting and before adding additional primal for backward tokens.

Tests:
```
python test/higher_order_ops/test_with_effects.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132638
Approved by: https://github.com/bdhirsh
2024-08-23 15:30:58 +00:00
4af4910b1a Reland "Construct NJT without graph breaks" (#133196)
This reverts commit 154d40ca488e6979ce9c2de89d8a35b53129ebea.

and adds changes from https://github.com/pytorch/pytorch/pull/133061

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133196
Approved by: https://github.com/ezyang
ghstack dependencies: #133145
2024-08-14 01:11:13 +00:00
05de2b2d0f Revert "Construct NJT without graph breaks" (#133145)
This reverts commit 911154271309667b55dfb963ec6384bd0048019b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145
Approved by: https://github.com/YuqingJ
2024-08-10 03:11:16 +00:00
f50621989b Construct NJT without graph breaks (#130292)
Combines contributions from https://github.com/pytorch/pytorch/pull/130505

Some context can be found in this large comment block:

a5b64d39fd/test/dynamo/test_subclasses.py (L1667-L1681)

Changes in this PR
- For each tensor fakified, check the nested int registry in eager, and eagerly symbolicize if that tensor has already been associated with nested int in eager.
- Adds a separate counter stored on FakeTensorMode as a fake analog to _tensor_id_counter (which keeps track of unique tensors). This counter is initialized to the global eager tensor id counter upon creation of the FakeTensorMode, and needs to be reset when the same FakeTensorMode is reused to trace again (in this PR, we piggyback on the epoch incrementing logic).
- (refactor) Today, we store FakeTensor -> symbolic nested int in the global registry. With this PR, symbolic nested int is stored directly on the FakeTensor. (Eager still caches nested int in the registry, though we should avoid this at some point.)

Basically unchanged, but worth noting:
- `__tensor_unflatten__` is still responsible for determining whether we should cache for now. The logic is somewhat simplified.
- to_copy is still using the trick of updating two different tensors in the registry to point to the same nested int. This is kind of broken, but we try to leave it as is, and plan a better fix with the UnionFind stack.

Differential Revision: [D60406772](https://our.internmc.facebook.com/intern/diff/D60406772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130292
Approved by: https://github.com/bdhirsh
ghstack dependencies: #131916, #131803
2024-08-06 17:03:39 +00:00
a8490a0762 [traced-graph][sparse] propagate sparsity in fx graph (#131920)
This PR proceeds with implementing the feature request #117188 by generalizing more cases that already work with COO to work with the compressed sparse formats as well.

Feature request:
https://github.com/pytorch/pytorch/issues/117188

Rebranch of older PRs (for history):
https://github.com/pytorch/pytorch/pull/131474
https://github.com/pytorch/pytorch/pull/128549

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131920
Approved by: https://github.com/ezyang
2024-08-05 15:49:53 +00:00
997f64af38 fastpath FunctionalTensor sizes() (#132084)
Another attempt at fast-pathing sizes() in FunctionalTensor, since it appears to improve compile time perf by up to ~10%. See the investigation from https://github.com/pytorch/pytorch/issues/125977#issuecomment-2122915602.

After looking at some failing tests locally I realized that we need to manually handle metadata mutations now, since the previous "smarter" size dispatch was handling the updates

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132084
Approved by: https://github.com/ezyang
2024-08-01 21:09:22 +00:00
93979e7063 Skip frame if torch dispatch mode enabled (#131828)
Fixes https://github.com/pytorch/pytorch/issues/105929

We now skip frames if a dispatch mode is enabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131828
Approved by: https://github.com/bdhirsh, https://github.com/anijain2305
2024-08-01 19:06:20 +00:00
e7eeee473c [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129765
Approved by: https://github.com/ezyang
2024-07-31 10:42:50 +00:00
f093cd4086 Fix custom ops warning during export (#130623)
Fixes https://github.com/pytorch/pytorch/issues/130588

The problem was we were warning on all custom ops, not just ones marked
as CompositeImplicitAutograd. This PR changes the warning to just warn
on CompositeImplicitAutograd ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130623
Approved by: https://github.com/williamwen42
2024-07-12 21:34:29 +00:00
e019540c9e Revert "Fix the SDPA AOT export issue (#130164)"
This reverts commit 1927c406844affbfe3496d5cbc31d4ebe11c8bfb.

Reverted https://github.com/pytorch/pytorch/pull/130164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is breaking ExecuTorch tests in trunk 1927c40684 ([comment](https://github.com/pytorch/pytorch/pull/130164#issuecomment-2211667777))
2024-07-06 05:59:49 +00:00
1927c40684 Fix the SDPA AOT export issue (#130164)
Summary:
## Context
TL;DR: aot_export failed for SDPA memory efficient backend when using `inference_mode`

The CMF AOTI lowering started to fail on trunk. We have a script (https://fburl.com/code/kfk64i5s) to reproduce the issue quickly (log: P1469307638). By bisecting the stack, we found that the issue started with D58701607.

## Root Cause
In `inference_mode()`,
`aten::scaled_dot_product_attention` was not decomposed before `functionalization`, and the op itself is an out-of-place op, so `functionalization` made no change. The op was then decomposed into `masked_fill_`, which was in turn decomposed into `copy_`.
So it's `aten::sdpa` --- (functionalization) ---> `aten::sdpa` --- (decompose) ---> `masked_fill_` --- (decompose) ---> `copy_` ---> failure

In the `torch.no_grad()`,
`aten::sdpa` was decomposed before `functionalization`, so the story is
`aten::sdpa` --- (decompose) ---> `masked_fill_` --- (functionalization) ---> `masked_fill` --- (decompose) ---> `out-place ops` ---> good
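
The ordering problem can be illustrated with a toy model in plain Python. This is a sketch only: the rewrite tables below are hypothetical stand-ins for the real decomposition and functionalization passes, and ops are just strings, with a trailing underscore marking an in-place op as in ATen naming.

```python
# Hypothetical rewrite tables; the real passes are far richer than this.
DECOMPOSE = {"sdpa": ["masked_fill_"]}
FUNCTIONALIZE = {"masked_fill_": "masked_fill"}


def decompose(ops):
    """Expand each op into its decomposition, if it has one."""
    out = []
    for op in ops:
        out.extend(DECOMPOSE.get(op, [op]))
    return out


def functionalize(ops):
    """Replace each known in-place op with its out-of-place variant."""
    return [FUNCTIONALIZE.get(op, op) for op in ops]


def has_mutation(ops):
    """In-place ops end with an underscore."""
    return any(op.endswith("_") for op in ops)


# inference_mode ordering: functionalize first, decompose afterwards.
inference_mode_graph = decompose(functionalize(["sdpa"]))
# no_grad ordering: decompose first, functionalize afterwards.
no_grad_graph = functionalize(decompose(["sdpa"]))
```

In this model the first ordering leaves an in-place op in the graph because functionalization never sees it, while the second ordering removes it.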

## How to fix
Long-term:
The issue is tracked in the ticket (https://github.com/pytorch/pytorch/issues/129418). The long-term fix could be to do one more round of `functionalization` after the `decompose`, like

`aten::sdpa` --- (functionalization) ---> `aten::sdpa` --- (decompose) ---> `masked_fill_` --- (functionalization) ---> `masked_fill` ---> good

Short-term:
The long-term fix would be a big change, I guess. To unblock the production use case, this diff marks `aten::sdpa` as an op that should be decomposed.

Test Plan:
local repro works now

buck run mode/opt scripts/sijiac/prototypes:sdpa_aoti

Differential Revision: D59385876

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130164
Approved by: https://github.com/zou3519
2024-07-06 00:57:47 +00:00
90f6043368 Don't decompose functional composite ops in export inference IR (#128077)
Recently we decided to split export IR into two different IRs (training vs inference). One major change we decided to introduce in the inference IR is to keep the composite ops that the user specified. This PR does that by overriding the CompositeImplicitAutograd decomp in the export inference path.

Differential Revision: [D58701607](https://our.internmc.facebook.com/intern/diff/D58701607)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128077
Approved by: https://github.com/bdhirsh
2024-06-26 23:07:55 +00:00
ba19ed9a1a FunctionalTensor: dispatch metadata directly to inner tensor (#127927)
Fixes https://github.com/pytorch/pytorch/issues/127374

The error in the linked repro is:
```
AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.sym_storage_offset.default(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(16, 4), dtype=torch.uint8),
       device='cuda:0'))
```

Where we hit FakeTensor.__torch_dispatch__, but our input is a C++ `FunctionalTensorWrapper`.

What should actually have happened is that the call to `aten.sym_storage_offset` hits the `Functionalize` dispatch key, which should remove the `FunctionalTensorWrapper`  and redispatch. I spent some time debugging and haven't actually figured out why this isn't happening. Instead, this PR just skips that step completely, and asks `FunctionalTensor` to directly unwrap the C++ `FunctionalTensorWrapper` when querying tensor metadata.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127927
Approved by: https://github.com/tugsbayasgalan
2024-06-15 00:08:44 +00:00
afe15d2d2f Flip default value for mypy disallow_untyped_defs [3/11] (#127840)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127840
Approved by: https://github.com/oulgen
2024-06-08 18:28:01 +00:00
f9dda37a74 [export] Cover more cases to copy tensor conversions. (#125628)
Summary:
Previously we tried to convert all .to() calls to to_copy in the graph. Some users now report that other methods, such as .float(), are not covered: https://github.com/pytorch/PiPPy/issues/1104#issuecomment-2093352734

I think fundamentally .float() should look similar to .to() in export, and this diff tries to expand the coverage of the tensor conversion methods here.

Test Plan: buck run mode/opt caffe2/test:test_export -- -r float_conversion

Differential Revision: D56951634

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125628
Approved by: https://github.com/tugsbayasgalan
2024-05-15 15:50:21 +00:00
d7fe3c4123 [RELAND] Switch default behavior of export IR to be predispatch (#125860)
This PR switches export IR from aot-dispatch to pre-dispatch IR.

**What is pre-dispatch IR and why should you care?**

Currently the default IR returned by torch.export can contain only functional ATen operators after ALL pytorch dispatcher decompositions (for example, CompositeImplicitAutograd) run.

In contrast, pre-dispatch IR refers to an IR that can contain all functional ATen operators (i.e., not just from the core subset), before any decomposition happens, as well as operators that manipulate autograd state. Pre-dispatch IR closely resembles eager PyTorch computation, but is still functional and serializable by torch.export. As a result:

- You can train the pre-dispatch IR in eager mode as the IR contains necessary information for the autograd engine to automatically generate a backward graph.
- You can write sound graph transformations more easily as the IR is functional.
- Since it is an ATen IR, it is still normalized. For example, torch.add has multiple overloads, but aten.add.Tensor is unique in this IR.

If you want to get the core aten IR out of `torch.export`, you will need to:
```
ep = torch.export.export(M(), inputs)
ep_for_core_aten = ep.run_decompositions()
```

Differential Revision: [D57172986](https://our.internmc.facebook.com/intern/diff/D57172986)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125860
Approved by: https://github.com/zhxchen17
2024-05-10 17:36:53 +00:00
02ed2992d9 [export] Capture tensor.to() under export. (#123732)
Summary: We used to skip tensor.to() during tracing when the device is the same. This brings some performance improvement in eager, but it makes graph capture lose the semantics of the original model. In this diff, we add an additional condition to skip the fast path when we don't have actual data inside a tensor, which is the case when we're using FakeTensor / FunctionalTensor to trace the model. This has no perf impact on existing eager models while making sure we can capture the _to_copy() node in the graph.

Test Plan: buck test mode/opt caffe2/test:test_export -- -r device_to

Differential Revision: D55969674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123732
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
2024-04-24 23:12:19 +00:00
674e15ae07 Back out "Switch to predispatch" (#124860)
Summary:
Original commit changeset: 1f155b3a0bfc

Original Phabricator Diff: D56273267

Test Plan: CI

Differential Revision: D56526505

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124860
Approved by: https://github.com/angelayi
2024-04-24 17:28:33 +00:00
c933af2709 Switch to predispatch (#123573)
This PR switches export IR from aot-dispatch to pre-dispatch IR.

**What is pre-dispatch IR and why should you care?**

Currently the default IR returned by torch.export can contain only functional ATen operators after ALL pytorch dispatcher decompositions (for example, CompositeImplicitAutograd) run.

In contrast, pre-dispatch IR refers to an IR that can contain all functional ATen operators (i.e., not just from the core subset), before any decomposition happens, as well as operators that manipulate autograd state. Pre-dispatch IR closely resembles eager PyTorch computation, but is still functional and serializable by torch.export. As a result:
- You can train the pre-dispatch IR in eager mode as the IR contains necessary information for the autograd engine to automatically generate a backward graph.
- You can write sound graph transformations more easily as the IR is functional.
- Since it is an ATen IR, it is still normalized. For example, torch.add has multiple overloads, but aten.add.Tensor is unique in this IR.

If you want to get the core aten IR out of `torch.export`, you will need to:
```
ep = torch.export.export(M(), inputs)
ep_for_core_aten = ep.run_decompositions()
```

Differential Revision: [D56273267](https://our.internmc.facebook.com/intern/diff/D56273267)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123573
Approved by: https://github.com/gmagogsfm
2024-04-24 00:51:09 +00:00
f4e2a226aa ScoreMod API (#121845)
# Summary

This PR adds a new higher-order op: `templated_attention`. This op is designed to extend the functionality of torch.nn.functional.scaled_dot_product_attention. PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention), which traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top-level SDPA op is [Attention with Linear Biases (ALiBi)](https://arxiv.org/abs/2108.12409).

This higher-order op instead accepts a callable (`score_mod`) that, through torch.compile, is used to create an efficient attention kernel instantiation.

### Details

This HOP utilizes the existing FX and HOP infra to capture the user's `score_mod` function and convert it to an FX graph module. Inductor then consumes this HOP, which has an `ir.Subgraph` input, and inlines the lowered subgraph into a Triton kernel that performs fused attention with the modification to the scores matrix inlined.

### API

The API for a score_mod function should be as follows:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_q: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor: ...
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `token_q`, `token_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, and value, each of shape (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16).

The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```
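As a rough illustration of "vectorized over each element", the sketch below applies a `score_mod`-style function to every entry of a small score matrix stored as nested Python lists. It is a plain-Python model of the semantics only, not how the op is implemented; `apply_score_mod` and `causal_mask` are hypothetical names, and plain numbers stand in for the scalar tensors.

```python
def apply_score_mod(scores, score_mod):
    """Apply score_mod to each element of a nested list shaped [B][H][Q][KV].

    Hypothetical helper for illustration; indices are passed as Python ints
    instead of scalar tensors.
    """
    return [
        [
            [
                [score_mod(s, b, h, q, kv) for kv, s in enumerate(row)]
                for q, row in enumerate(head_scores)
            ]
            for h, head_scores in enumerate(batch_scores)
        ]
        for b, batch_scores in enumerate(scores)
    ]


def causal_mask(score, batch, head, token_q, token_kv):
    # Keep the score only when the key position is not ahead of the query.
    return score if token_q >= token_kv else float("-inf")


scores = [[[[1.0, 2.0], [3.0, 4.0]]]]  # shape (1, 1, 2, 2)
masked = apply_score_mod(scores, causal_mask)
```

Here only the entry at query 0, key 1 is replaced, since it is the only position where the key index exceeds the query index.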

### Examples
```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Let's create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Let's create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.

def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Let's call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```
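To make concrete where `score_mod` sits inside attention, here is a minimal plain-Python reference for a single batch and head. It is a sketch of the math only, not the Triton kernel or the eager implementation from this PR. The `1/sqrt(head_dim)` scaling applied before `score_mod` is an assumption borrowed from standard scaled dot-product attention, and `reference_attention` and `causal_mask` are hypothetical names.

```python
import math


def reference_attention(q, k, v, score_mod, batch=0, head=0):
    """softmax(score_mod(q @ k^T * scale)) @ v for one (batch, head) slice.

    q, k, v are lists of rows (lists of floats). Assumes every query row keeps
    at least one finite score after score_mod.
    """
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)  # assumption: standard SDPA scaling
    out = []
    for qi, q_row in enumerate(q):
        # Raw score for each key, then the user-supplied modification.
        scores = [
            score_mod(scale * sum(a * b for a, b in zip(q_row, k_row)), batch, head, qi, ki)
            for ki, k_row in enumerate(k)
        ]
        # Numerically stable softmax over the key dimension.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Weighted sum of the value rows.
        out.append([
            sum(w / total * v_row[d] for w, v_row in zip(weights, v))
            for d in range(len(v[0]))
        ])
    return out


def causal_mask(score, batch, head, token_q, token_kv):
    return score if token_q >= token_kv else float("-inf")


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = reference_attention(q, k, v, causal_mask)
```

With the causal mask, query 0 can attend only to key 0, so the first output row equals the first value row; query 1 mixes both value rows.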

### Future Work
- This PR is currently forward only. However, a Triton kernel for backwards, for the case where score modifications do not rely on external buffers, has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel improvements: there have been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable
- Should we error on CPU?
- There are some issues with dynamic shapes
- Capturing of free variables and lifting to inputs to the subgraph is not working correctly today

### Performance
Comparisons generated by this benchmark:

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     5.412 |              |             |             |             |            |               |                |
| Max     |     8.882 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     3.645 |            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |
| Min     |     0.345 |            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |

For reference

| Configuration                                                              | Forward Time (µs) | Backend                    | Speedup |
|----------------------------------------------------------------------------|-------------------|----------------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608              | Templated Attention        | 1.0     |
| Compiled SDPA (No Mask)                                                    | 9928              | Math                       | 2.75x   |
| Compiled SDPA (With Mask)                                                  | 11898             | Math                       | 3.29x   |
| Compiled SDPA (With Mask)                                                  | 8704              | Memory Efficient Attention | 2.42x   |
| Compiled SDPA (No Mask)                                                    | 2548              | FlashAttention2            | 0.706x  |

The speedups compare compiled templated attention against different calls to torch.nn.functional.scaled_dot_product_attention.

<details>

<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

|   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |   eager_time |   compiled_time |   speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
|            1 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      331.444 |          67.221 |     4.931 |
|            1 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      335.300 |          64.187 |     5.224 |
|            1 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      352.039 |          63.806 |     5.517 |
|            1 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      371.699 |         711.349 |     0.523 |
|            1 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |      333.488 |          86.455 |     3.857 |
|            1 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |      322.363 |          82.469 |     3.909 |
|            1 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |      349.967 |          82.233 |     4.256 |
|            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |      486.359 |        1412.453 |     0.344 |
|            1 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |     2794.597 |         551.188 |     5.070 |
|            1 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |     3965.150 |         513.101 |     7.728 |
|            1 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |     2408.013 |         504.759 |     4.771 |
|            1 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |     6850.531 |       16733.675 |     0.409 |
|            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      441.939 |         123.576 |     3.576 |
|            8 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      560.379 |         116.710 |     4.801 |
|            8 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      421.172 |         115.825 |     3.636 |
|            8 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      994.492 |        2132.806 |     0.466 |
|            8 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     1436.430 |         309.495 |     4.641 |
|            8 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     1892.216 |         290.186 |     6.521 |
|            8 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     1360.665 |         282.956 |     4.809 |
|            8 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     3525.532 |        8359.702 |     0.422 |
|            8 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    22026.839 |        3864.604 |     5.700 |
|            8 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    31262.746 |        3609.551 |     8.661 |
|            8 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    20219.079 |        3480.402 |     5.809 |
|            8 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |    54654.647 |      116652.357 |     0.469 |
|           16 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      820.606 |         188.683 |     4.349 |
|           16 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |     1058.362 |         179.295 |     5.903 |
|           16 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      784.372 |         175.714 |     4.464 |
|           16 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |     1890.792 |        4212.877 |     0.449 |
|           16 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     2781.830 |         557.017 |     4.994 |
|           16 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     3694.050 |         525.249 |     7.033 |
|           16 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     2634.164 |         507.613 |     5.189 |
|           16 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     6959.917 |       15331.116 |     0.454 |
|           16 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    43889.096 |        7582.018 |     5.789 |
|           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    62784.293 |        7075.846 |     8.873 |
|           16 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    40308.606 |        6829.587 |     5.902 |
|           16 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |   108892.137 |      233090.953 |     0.467 |
</details>
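
The `speedup` column in the full sweep appears to be `eager_time / compiled_time`, so values below 1 (the `pathological` rows) mean the compiled kernel was slower than eager. This reading is inferred from the numbers rather than stated in the PR; the quick check below recomputes three rows copied from the sweep table.

```python
# (eager_time, compiled_time, listed_speedup) copied from three sweep rows
rows = [
    (331.444, 67.221, 4.931),   # batch 1, seq 512, causal_mask
    (335.300, 64.187, 5.224),   # batch 1, seq 512, relative_bias
    (371.699, 711.349, 0.523),  # batch 1, seq 512, pathological
]


def speedup(eager_time, compiled_time):
    # Ratio above 1 means the compiled version is faster.
    return eager_time / compiled_time


recomputed = [round(speedup(e, c), 3) for e, c, _ in rows]
listed = [s for _, _, s in rows]
```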

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
2024-04-06 01:10:44 +00:00
557e7c9c16 Add some type hints to functions and update a few spelling mistakes (#123015)
# Summary
While working on this PR: https://github.com/pytorch/pytorch/pull/121845
I found that these type hints made the code easier to reason about in my IDE, especially as a newcomer to it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123015
Approved by: https://github.com/Skylion007
2024-03-30 21:15:01 +00:00
c81c9ba472 Disallow {FakeTensor,FunctionalTensor}.data_ptr (#122514)
This PR:
- disallows FakeTensor.data_ptr when it is called inside PT2 or fx tracing.
- disallows FunctionalTensor.data_ptr (python FunctionalTensor is only used in
  PT2)

The motivation behind this is that the leading cause of segfaults when
using custom ops with PT2 is calling .data_ptr on FunctionalTensor or
FakeTensor.

This change is BC-breaking. If your code broke as a result of this, it's
because there was a bug in it (these .data_ptr should never be
accessed!). You can either fix the bug (recommended) or get the previous
behavior back with:
```
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```

Test Plan:
- existing tests

Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122514
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
2024-03-26 23:55:42 +00:00