208 Commits

Author SHA1 Message Date
8ae4f82243 [aotd] Support HOP effects in backward (#132638)
Support of effectful operations in backward:

1/ AOTD collects metadata from forward fn only, so we can have usage of effectful ops in backward, that were not used in forward => Allowing tokens discovery during joint function .

FunctionalTensorMode holds _tokens, in Joint function after tracing forward we memoize _tokens as `_tokens_forward_output`.

2/ Tokens are added as primals inputs (forward) in EffectTokensWrapper.
Tokens that will be used in backward are in partitioner saved values. We do not have control on which positions they are saved in forward outputs.

2/ If new tokens discovered in backward after tracing joint_fn, the result graph will be manually added in the end of primals.
_aot_autograd/utils.py

3/ All effectful ops during backward are marked with 'must_be_in_backward' partitioner_tag, to prevent partiitoner to place them in forward.

For that functional_tensor_mode got new optional state `self._effects_partitioner_tag` for effectful ops, to set after tracing forward.

There are additional changes in partitioner to improve functionality of 'must_be_in_backward'

4/ Unlift tokens now should run for both forward and backward.
- As saved for backward tokens are placed on non static places - we identify input and output tokens to erase, by input and output of `with_effects` operation
- In forward we can have input tokens, discovered in backward, that are not used in with_effects ops in forward, but saved for backward. We identify them by position in forward inputs.

5/ Adding aot debug logging for graphs before unlifting and before adding additional primal for backward tokens.

Tests:
```
python test/higher_order_ops/test_with_effects.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132638
Approved by: https://github.com/bdhirsh
2024-08-23 15:30:58 +00:00
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
57625bacea [partitioner] Fix must_be_in_backward corner cases (#134002)
Preparation PR for https://github.com/pytorch/pytorch/pull/132638

"must_be_in_backward" fails the partitioner, if partitioner picks this node as saved_values.

The fix is to prevent partitioner to pick those nodes during nodes classification.

It's hard to make a test without making effectful ops in backward "must_be_in_backward", which will be testing this ( https://github.com/pytorch/pytorch/pull/132638 )

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134002
Approved by: https://github.com/bdhirsh
ghstack dependencies: #134003
2024-08-21 15:58:49 +00:00
5cb05a82b4 [BC breaking] move benchmarking + prefer inductor path (#132827)
move benchmarking out of `torch._inductor.runtime.runtime_utils` and into `torch._inductor.runtime.benchmarking`, and prefer this path over directly accessing Triton's benchmarking

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132827
Approved by: https://github.com/eellison
2024-08-08 00:47:45 +00:00
a74e5abda4 Fix issues in activation_memory_budget for float8 (#132687)
Summary:
When using activation_memory_budget for float8 training, two issues were noticed:

- When `aggressive_options` (https://fburl.com/code/m1yoskxw) is called , all fp8 gemms (the scaled_mm op) are saved for recomputation.
- After adding "scaled_mm" in the `compute_intensive_ops`, we got the next error from `estimate_runtime`: `mat2 must be col_major` from `meta_scaled_mm`.
To fix it, modified `materialize_arg` to also include the stride of the original tensor.

Test Plan: Run float8 training with `activation_memory_budget`.

Differential Revision: D60777297

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132687
Approved by: https://github.com/Chillee
2024-08-05 23:01:35 +00:00
fd4b649e6c [BE]: Simplify some list comps to generators C419 (#132578)
Simplifies some list comprehensions to generator which is more efficient. Automatically applied diffs for the most part with ruff

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132578
Approved by: https://github.com/ezyang
2024-08-04 17:46:26 +00:00
85f19ce14a Support meta["val"] that is a dict, for triton kernels and for the partitioner (#132466)
Internally there's a model that's using memory_budget with the partitioner, and using custom triton kernels. The partitioner fails when encountering the triton ops because they don't have `meta["val"]`. This PR adds `meta["val"]`  to these fx graph nodes and then adds handling for `meta["val"]` being a dict in the partitioner.

Differential Revision: [D60627813](https://our.internmc.facebook.com/intern/diff/D60627813)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132466
Approved by: https://github.com/zou3519
ghstack dependencies: #132356
2024-08-02 23:24:29 +00:00
7d8b95e8fb [easy] more debug in partitioner assert (#132456)
Print the name of the node that didn't have good meta['val']. An internal model is failing with this assert, we need this info to debug further.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132456
Approved by: https://github.com/Chillee
2024-08-02 05:07:01 +00:00
e7eeee473c [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129765
Approved by: https://github.com/ezyang
2024-07-31 10:42:50 +00:00
945bf78894 Revert "[BE] typing for decorators - fx/_compatibility (#131568)"
This reverts commit 193f62fde91ee20deb5ddcd9ff4593cd78d74c64.

Reverted https://github.com/pytorch/pytorch/pull/131568 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
115994fea2 [aotd] Align partitioner graph output type to tuple (#131759)
Brian debugged the difference of the output type for inference and train graph.
Partitioner sometimes return list output type.

After this PR it will always return tuple.

Potentially there can be some new graphs inside tests that will be landed between this PR ci jobs finish and landing.
This could be easily fixed with fast-forward fix on:
```
EXPECTTEST_ACCEPT=1 python test/test.py
```

Adding ciflows/periodic to minimize this probability

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131759
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
2024-07-26 09:46:29 +00:00
193f62fde9 [BE] typing for decorators - fx/_compatibility (#131568)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
2024-07-25 22:24:19 +00:00
58b8704f28 [aot] Keep backward mutations in backward (#129130)
https://github.com/pytorch/pytorch/issues/127561

Mutations of inputs in backward are emitted manually, after joint_fn tracing.
With default partitioner logic they will be moved to "forward" graph, as this is operation on forward inputs.

To keep those mutations in backward:
- Introduce "subgraph" node key, that can be specified with contextmanager. When we do manual `copy_` in backward on forward input - we know that his is for backward - set subgraph="backward"

In partitioner:
Introducing optional argument subgraph, to filter out nodes with specified subgraph (node_subgraph) and not to add them to subgraph if node_subgraph is different.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129130
Approved by: https://github.com/Chillee
2024-07-25 20:02:25 +00:00
120fdf7ee2 Revert "[aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890)"
This reverts commit e98135d1ad2f999fec649ecd21b35f3d5676be43.

Reverted https://github.com/pytorch/pytorch/pull/128890 on behalf of https://github.com/zou3519 due to broke trunk tests, probably a landrace ([comment](https://github.com/pytorch/pytorch/pull/128890#issuecomment-2236790805))
2024-07-18 14:58:25 +00:00
e98135d1ad [aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890)
Reland of:  https://github.com/pytorch/pytorch/pull/128016

Summary from previous PR:
We assume only two possible mutually exclusive scenarios:

Running compiled region for training (Any of inputs has requires_grad)

Produced differentiable outputs should have requires_grad.
Running compiled region for inference (None of inputs has requires_grad)

All outputs do not have requires_grad.
Even if user runs the region under no_grad(), but has an input Tensor with requires_grad - we go Training scenario (1).

With current state that means:
1/ needs_autograd should not check torch.is_grad_enabled(), only that any of inputs requires_grad
2/ if needs_autograd => trace_joint (We are in training scenario 1.) => always run compiled region under with.enable_grad()

Changes in partitioner?

Inference and Training graphs had difference in return container, list/tuple.
The changes in partitioner are done to unify and return always tuple.
As a result - some changes in test_aotdispatch.py for graph contents list -> tuple.

Why was revert?

There was a regression of hf_Reformer model on inference.
```
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend inductor --device cuda --only hf_Reformer --cold-start-latency --use-eval-mode
```

Because one of the compiled graphs contained outputs, which are aliases to the inputs that are nn.Parameter(requires_grad=True).

Even if inference bencharmsk torchbench runs inside with` torch.no_grad()` - alias (specifically for hf_Reformer - expand) ops preserve requires_grad.

As a result we started compiling training graph instead of inference.

Fix for view ops:

If we have outputs, that are aliases to inputs that requires_grad, those outputs requires grad is not a reason to generate training graph.

This is handled in aot_autograd.py, where output_and_mutation_safe are calculated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128890
Approved by: https://github.com/bdhirsh
2024-07-18 08:27:53 +00:00
973037be6a [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199)
This PR changes the empty collection factory call to Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` only has one bytecode `BUILD_MAP`, while the factory call `dict()` has three `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing with `OrderedDict` above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00
b81767161e Revert "[aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890)"
This reverts commit 08d5423d339ac4b302f8ae6b63b334e032104753.

Reverted https://github.com/pytorch/pytorch/pull/128890 on behalf of https://github.com/clee2000 due to broke inductor/test_flex_attention https://github.com/pytorch/pytorch/actions/runs/9879109008/job/27286339304 08d5423d33 test was not run on PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/128890#issuecomment-2221368245))
2024-07-10 20:22:24 +00:00
08d5423d33 [aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890)
Reland of:  https://github.com/pytorch/pytorch/pull/128016

Summary from previous PR:
We assume only two possible mutually exclusive scenarios:

Running compiled region for training (Any of inputs has requires_grad)

Produced differentiable outputs should have requires_grad.
Running compiled region for inference (None of inputs has requires_grad)

All outputs do not have requires_grad.
Even if user runs the region under no_grad(), but has an input Tensor with requires_grad - we go Training scenario (1).

With current state that means:
1/ needs_autograd should not check torch.is_grad_enabled(), only that any of inputs requires_grad
2/ if needs_autograd => trace_joint (We are in training scenario 1.) => always run compiled region under with.enable_grad()

Changes in partitioner?

Inference and Training graphs had difference in return container, list/tuple.
The changes in partitioner are done to unify and return always tuple.
As a result - some changes in test_aotdispatch.py for graph contents list -> tuple.

Why was revert?

There was a regression of hf_Reformer model on inference.
```
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend inductor --device cuda --only hf_Reformer --cold-start-latency --use-eval-mode
```

Because one of the compiled graphs contained outputs, which are aliases to the inputs that are nn.Parameter(requires_grad=True).

Even if inference bencharmsk torchbench runs inside with` torch.no_grad()` - alias (specifically for hf_Reformer - expand) ops preserve requires_grad.

As a result we started compiling training graph instead of inference.

Fix for view ops:

If we have outputs, that are aliases to inputs that requires_grad, those outputs requires grad is not a reason to generate training graph.

This is handled in aot_autograd.py, where output_and_mutation_safe are calculated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128890
Approved by: https://github.com/bdhirsh
2024-07-10 17:56:32 +00:00
6c2a8b6b38 [Ez][BE]: Enable new stable ruff rules (#129825)
Applies a bunch of new ruff lint rules that are now stable. Some of these improve efficiency or readability. Since I already did passes on the codebase for these when they were in preview, there should be relatively few changes to the codebase. This is just more for future hardening of it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129825
Approved by: https://github.com/XuehaiPan, https://github.com/jansel, https://github.com/malfet
2024-07-02 14:47:10 +00:00
7bb558fd6e add _flash_attention_forward and _efficient_attention_forward to compute intensive ops in partitioner (#129533)
Avoid recompute of SDPA during the backward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129533
Approved by: https://github.com/drisspg
2024-06-27 00:49:00 +00:00
575bc1e3af [Reopen #114036] Allow "must recompute" in torch.compile + selective checkpointing (SAC) (#129295)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129295
Approved by: https://github.com/Chillee
2024-06-25 23:47:08 +00:00
a2b1673dfb [Horace's PR #126446] Prevent partitioner from ever saving views (#129039)
Most work is done by Horace in https://github.com/pytorch/pytorch/issues/126446, this PR just additionally adds the config for it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129039
Approved by: https://github.com/Chillee
2024-06-19 23:21:16 +00:00
ea614fb2b1 Flip default value for mypy disallow_untyped_defs [2/11] (#127839)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839
Approved by: https://github.com/oulgen
2024-06-08 18:23:08 +00:00
310f80995b Added memory budget to partitioner (#126320)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126320
Approved by: https://github.com/shunting314
2024-06-08 05:52:40 +00:00
128952625b Revert "Added memory budget to partitioner (#126320)"
This reverts commit 2184cdd29128a924583e4702489177f83fb8270a.

Reverted https://github.com/pytorch/pytorch/pull/126320 on behalf of https://github.com/ZainRizvi due to The new test_ac.py fails on ROCm machines ([comment](https://github.com/pytorch/pytorch/pull/126320#issuecomment-2155141886))
2024-06-07 16:15:03 +00:00
2184cdd291 Added memory budget to partitioner (#126320)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126320
Approved by: https://github.com/shunting314
2024-06-06 20:32:29 +00:00
2d47385f0f [BE]: Enable ruff TCH rules and autofixes for better imports (#127688)
Automated fixes to put imports that are only used in type hints into TYPE_CHECKING imports. This also enables the RUFF TCH rules which will automatically apply autofixes to move imports in and out of TYPE_CHECKING blocks as needed in the future, this will make the initial PyTorch import faster and will reduce cyclic dependencies.

Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127688
Approved by: https://github.com/XuehaiPan, https://github.com/ezyang, https://github.com/malfet
2024-06-06 16:55:58 +00:00
85172fbe84 Back out "Prevent partitioner from ever saving views (#126446)" (#127316)
Summary: Revert "Prevent partitioner from ever saving views (#126446)" due to a torchinductor failure on CU Training Framework tests.

Reviewed By: Chillee

Differential Revision: D57868343

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127316
Approved by: https://github.com/Chillee
2024-05-29 00:29:44 +00:00
d4ec18bdad Prevent partitioner from ever saving views (#126446)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126446
Approved by: https://github.com/anijain2305
ghstack dependencies: #126615
2024-05-22 17:28:46 +00:00
0f37fd06d9 Revert "Prevent partitioner from ever saving views (#126446)"
This reverts commit da2292ce6b37028746bf5beeae04442eef1e803d.

Reverted https://github.com/pytorch/pytorch/pull/126446 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/126615#issuecomment-2124169157))
2024-05-22 08:23:40 +00:00
da2292ce6b Prevent partitioner from ever saving views (#126446)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126446
Approved by: https://github.com/anijain2305
ghstack dependencies: #126615
2024-05-20 23:40:56 +00:00
e3230f87aa Cached required_fw_nodes creation (#126613)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126613
Approved by: https://github.com/anijain2305
2024-05-19 01:48:52 +00:00
f9a7033194 Refactor partitioner and clean it up (#126318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126318
Approved by: https://github.com/anijain2305
2024-05-17 06:15:00 +00:00
1a27e24ff5 Make inductor scheduler graph extension configurable (#125578)
This patch makes the inductor scheduler graph extension configurable.
It enables ease of debugging by changing the graph format (dot, png, etc.).

Particularly, it's very convenient to work with the graph interactively using tools like https://github.com/tintinweb/vscode-interactive-graphviz

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125578
Approved by: https://github.com/Chillee
2024-05-17 04:19:23 +00:00
a5c93a6899 Speed up _extract_graph_with_inputs_outputs (#125937)
_extract_graph_with_inputs_outputs() does membership testing on the input nodes but often that collection is a list so the test is O(n).  Ensure it's a set before looping over all the nodes.

This change speeds up the internal repro (D57090987) by about 18%:
before:
```
708.88user 15.86system 12:16.19elapsed 98%CPU (0avgtext+0avgdata 12898628maxresident)k
0inputs+91968outputs (3major+3532970minor)pagefaults 0swaps
```
after:
```
583.39user 15.98system 10:10.11elapsed 98%CPU (0avgtext+0avgdata 12895108maxresident)k
0inputs+87488outputs (4major+3374582minor)pagefaults 0swaps
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125937
Approved by: https://github.com/oulgen, https://github.com/anijain2305
2024-05-11 00:20:39 +00:00
445a0c01da Retry: Low mem max_pool2d_with_indices (#122832)
Based on #105687

The low memory path does not need to strictly return the int8 offsets
instead the offset to index computation can be separated from the
inner function of the max pool lowering. The partitioner can then choose
to move the offset to index computation into the backward pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122832
Approved by: https://github.com/peterbell10, https://github.com/eellison
2024-05-08 19:37:08 +00:00
1dd42e42c4 [BE]: Try TCH autofixes on torch/ (#125536)
Tries TCH autofixes and see what breaks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125536
Approved by: https://github.com/ezyang
2024-05-05 23:13:59 +00:00
2f3b0befed [BE]: Apply ruff FURB 118. (#124743)
Replaces various lambdas with operator.itemgetter which is more efficient (as it's a builtin function). Particularly useful for when lambdas are used as 'key' functions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124743
Approved by: https://github.com/albanD, https://github.com/malfet
2024-04-26 14:34:52 +00:00
29cc293725 [BE]: FURB142 - Remove set mutations. Use set update (#124551)
Uses set mutation methods instead of manually reimplementing (update, set_difference etc).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
2024-04-21 14:12:33 +00:00
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
73f0ecc1ac [BE] UFMT directory torch/_functorch (#123723)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123723
Approved by: https://github.com/Skylion007
2024-04-12 08:04:51 +00:00
5aab2b9acf Use graph.find_nodes in functorch (#122258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122258
Approved by: https://github.com/jansel
ghstack dependencies: #121565, #122255, #122256, #122257
2024-04-07 18:51:22 +00:00
b9c9f037d1 Added some checkpointing tests (#122848)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122848
Approved by: https://github.com/anijain2305
2024-03-29 03:49:19 +00:00
0348773655 Forward fix for subtly breaking AC with compile in the case of stacked (#122841)
checkpoint layers separated by recomputable op
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122841
Approved by: https://github.com/anijain2305
ghstack dependencies: #122686, #122688, #121692
2024-03-27 23:23:04 +00:00
a54ea7bbd8 Made several changes to min-cut partitioner that allow it to recompute more things (#121692)
Perf results
<img width="862" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/8d44e633-8941-46a6-8e7d-806330a8c890">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121692
Approved by: https://github.com/shunting314, https://github.com/eellison
ghstack dependencies: #122686, #122688
2024-03-27 22:45:52 +00:00
773ae817f7 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-18 21:01:30 +00:00
fd0dbcd891 Revert "Batch Norm Consolidation (#116092)"
This reverts commit 7b4f70eda519ccd7f28de17689edd43c52743bc9.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/osalpekar due to Causes build failure in //caffe2:aten-hip (AMD build) target. See [D54707318](https://www.internalfb.com/diff/D54707318) for more details, may require internal build system changes to resolve. ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1989542965))
2024-03-11 22:22:41 +00:00
7b4f70eda5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-08 15:07:15 +00:00
b529c19bdf Revert "Batch Norm Consolidation (#116092)"
This reverts commit 5680f565d5b7d4aa412a3988d3d91ca4c5679303.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/jeffdaily due to broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1981373237))
2024-03-06 17:10:01 +00:00
5680f565d5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-06 04:50:46 +00:00