Fixes#156052 and #156444.
This PR setup the privateuseone key in Python to be used as a python backend for pytorch.
Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass to with that device to hold arbitrary python data as "device data" and use `torch.library` to register ops that takes that Tensor.
Changes done in this PR:
1. Register an vanilla Device Guard: I extended NoOpDeviceGuard to have allow device index of 0 and to not raise errors when event related functions are accessed. If I don't do those, when calling backward I would get errors. (CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.
2. Tensor subclass allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment of the check suggests it was to avoid segfault when calling into ops that expects a storage. Here we have a different device so will not call into those ops.
3. python function that invokes the other incantations to setup the privateusekey backend.
This took inspiration of https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
Fixes#156052 and #156444.
This PR setup the privateuseone key in Python to be used as a python backend for pytorch.
Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass to with that device to hold arbitrary python data as "device data" and use `torch.library` to register ops that takes that Tensor.
Changes done in this PR:
1. Register an vanilla Device Guard: I extended NoOpDeviceGuard to have allow device index of 0 and to not raise errors when event related functions are accessed. If I don't do those, when calling backward I would get errors. (CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.
2. Tensor subclass allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment of the check suggests it was to avoid segfault when calling into ops that expects a storage. Here we have a different device so will not call into those ops.
3. python function that invokes the other incantations to setup the privateusekey backend.
This took inspiration of https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
Fixes#120648
During issue scrubbing I could not repro these failing tests, so reenabling them to close out the issue
### Test
Original repro command:
```
PYTORCH_TEST_WITH_DYNAMO=1 pytest test/test_openmp.py -v -k test_one_thread
```
Now results in
```
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0 -- /home/lucaskabela/.conda/envs/pytorch-3.12/bin/python3.12
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /home/lucaskabela/pytorch
configfile: pytest.ini
plugins: hypothesis-6.138.0
collected 2 items / 1 deselected / 1 selected
Running 1 items in this shard
test/test_openmp.py::TestOpenMP_ParallelFor::test_one_thread PASSED [3.6874s] [100%]
===================================================== 1 passed, 1 deselected in 6.07s =====================================================
```
And:
```
PYTORCH_TEST_WITH_DYNAMO=1 python test/test_openmp.py TestOpenMP_ParallelFor.test_one_thread
```
```
PYTORCH_TEST_WITH_DYNAMO=1 python test/test_sort_and_select.py TestSortAndSelectCPU.test_sort_overflow_cpu_int16
```
Both result in:
```
.
----------------------------------------------------------------------
Ran 1 test in 0.003s
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160562
Approved by: https://github.com/zou3519
Implement traceable config patching for Dynamo: enables restricted patching of Dynamo config where user can use a context manager/decorator to change tracing behavior for parts of the code.
The new `dont_skip_tracing` decorator/context manager for ignoring most trace rules is easily implemented with this more generic traceable config patching feature.
Implementation:
- Create a new specialized context manager class representing a wrapper around torch._dynamo.config.patch
- Dynamo doesn't trace into the context manager but updates config at compile time
- Correctness is based on our correctness for handling supported context managers
- Implementation is inspired by how `GradModeVariable` is implemented.
Previous attempts: https://github.com/pytorch/pytorch/pull/148736 (decorator-only global approach) and https://github.com/pytorch/pytorch/pull/149439 (decorator-only traceback approach)
See https://docs.google.com/document/d/1vWNwKL_jpg-PLopifcaSa338wks3GqSVF4GHRguybGg/edit?tab=t.0 for more details on implementation - including previous approaches.
NOTE: this PR fixes a bug where skipped code objects were not tracked by convert_frame.py, leading to cases where code objects would be automatically skipped even after `torch._dynamo.reset()`. This exposed some latent dynamo-wrapped test failures in CI that previously passed in CI but not locally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150586
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/anijain2305
Implement traceable config patching for Dynamo: enables restricted patching of Dynamo config where user can use a context manager/decorator to change tracing behavior for parts of the code.
The new `dont_skip_tracing` decorator/context manager for ignoring most trace rules is easily implemented with this more generic traceable config patching feature.
Implementation:
- Create a new specialized context manager class representing a wrapper around torch._dynamo.config.patch
- Dynamo doesn't trace into the context manager but updates config at compile time
- Correctness is based on our correctness for handling supported context managers
- Implementation is inspired by how `GradModeVariable` is implemented.
Previous attempts: https://github.com/pytorch/pytorch/pull/148736 (decorator-only global approach) and https://github.com/pytorch/pytorch/pull/149439 (decorator-only traceback approach)
See https://docs.google.com/document/d/1vWNwKL_jpg-PLopifcaSa338wks3GqSVF4GHRguybGg/edit?tab=t.0 for more details on implementation - including previous approaches.
NOTE: this PR fixes a bug where skipped code objects were not tracked by convert_frame.py, leading to cases where code objects would be automatically skipped even after `torch._dynamo.reset()`. This exposed some latent dynamo-wrapped test failures in CI that previously passed in CI but not locally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150586
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/anijain2305
This allows Configs to handle setting their defaults (or overriding
themselves) via environment variables.
The environment variables are resolved at install time (which is usually
import time). This is done 1) to avoid any race conditions between
threads etc..., but 2) to help encourage people to just go modify the
configs directly, vs overriding environment variables to change
pytorch behaviour.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138956
Approved by: https://github.com/ezyang
ghstack dependencies: #138766
Construct frame localsplus in 3.12+ using our own simplified way rather than copypasting from CPython.
This is necessary for 3.13 since we can no longer generate frame `f_locals` before executing the interpreter frame.
We also enable this for 3.12 since the `f_locals` construction between 3.12 and 3.13 is the same, so we can test for correctness with 3.12.
This is also one of the first steps to completing https://github.com/pytorch/pytorch/issues/93753 - we will implement simplified f_locals generation of previous Python versions in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129185
Approved by: https://github.com/jansel
More details further down, but first a more high-level description of "how do we functionalize storage resizing"
Today, dynamo converts `param.untyped_storage().resize_(x)` calls that it sees from fsdp into a custom op, `ops.inductor.resize_storage_bytes_(x)`
So given this setup, there are 3 main cases that I think we want to handle:
(1) graph input starts with a real storage size, gets resized down to zero in the graph
(2) graph input starts with 0 storage size, gets resized up in the graph
(3) graph input starts with 0 storage size, gets resized up and used in some compute, then resized back down to 0
For case (1) we need to emit a `resize_storage_bytes_` at the end of the graph, similar to how we emit `copy_()` for data mutations.
For case (2), we need to emit a `resize_storage_bytes_` in the graph, and we **also** need to emit a `copy_()` (the input had its storage resized up, and filled in with data, which is we need to reflect as an input mutation)
For case (3), the net effect is that the input had no data on entry and exit of the function, so we don't need to emit any mutable ops in the end of the graph.
The main thing to call out is that: we need to write a functionalization rule for `resize_storage_byte_`, (`FunctionalTensorWrapper::storage_resize_()`) and this rule actually does very little. We would like to **not** emit any new ops in the graph (like say, a functional resize op). Instead, we should expect / rely on the fact that any resize up will be immediately followed by a `copy_()`/`foreach_copy_`/`out=` op, that will fill in the data of the tensor. So `FunctionalTensor` can temporarily live in a state where its data is invalid, until the `x.copy_(y)` "updates" its data with the new tensor.
So effectively, all that this rule does is:
(1) it stores metadata on the storage, indicating that the tensor was resized, as well as the updated storage size. We need this info in AOTAutograd, so it knows whether to emit a mutable resize_() op in the graph epilogue
(2) There is also a corner case: if we are resizing down to zero, but our tensor had **previously** had a zero size storage, then we update `value_` to point to the original value of the tensor. The reason this seems safe is because if we have a zero storage sized tensor `x`, and we resize it up, use it in some compute, resize it back down to zero, and use it somewhere, we would want the functional version of this code to use the original `x` after the second resize. For FSDP, this is important because we end up saving parameters (graph inputs) for backward, and we want to make sure that the thing we save (and the output to the forward graph) is the original, zero-storage-sized parameter, and not the "version 2" of the parameter after the first resize_()
I think a good order to look at changes in this PR would be:
(1) `test_aotdispatch.py` shows the 3 main cases I focused on as well as the expected functionalized graphs
(2) In `FunctionalStorageImpl.h/cpp`, I had to add a notion of "original base", and "original/curr_size". The first is so I can re-use the zero-size tensor after multiple resizes, and the second is so I can tell in AOTAutograd whether any resizes canceled each other out into a no-op
(3) FunctionalTensorWrapper.h/cpp has the new resize functionalizion rule + some extra utils
(4) `_functorch/_autograd`: the main changes in this folder were around adding the logic at trace-time to detect when we need to put a resize_() in the graph. I also have some assertions to check that any inputs that experience storage resizing will **always be in the graph** and not the opaque epilogue, and I also limited the resize_() mutation case so that you can only ever start with zero storage, or end with zero storage (you can't do e.g. `torch.ones(2).storage().resize_(3)`), and banned it on tensor subclasses
(5) `fake_tensor.py`/`meta_utils.py`: we now need to be able to fakeify tensors with zero storage, so I added a quick version of it in meta_utils.py. This also.. has ramifications for fake tensor caching that I need to fix (include the storage size on the cache key, maybe?)
------------------
This PR subsumes https://github.com/pytorch/pytorch/pull/120971.
This PR is enough to **almost** get a simple ppFSDP forward pass tracing with a functionalized resize_() properly. It also attempts to do the updated version from @jansel, where we don't have any notion of `resize_()` in the graph at all, post functionalization. It would probably be good to test it with @yf225 's FSDP changes, and see how many of the FX passes it allows us to remove. I think that in theory, it should allow us to remove all FX passes that affect the forward graph / partitioner, **except** the one that forces views to be recomputed in the backward (more details below).
There are a few things worth calling out:
(1) failed attempt at functionalizing `aten.copy_()`. I originally wanted to get a version takes these operations:
```
param.storage().resize_(all_gather_size)
param.copy_(all_gather_buffer)
out = aten.matmul(param, param)
```
and functionalizes them into:
```
out = aten.matmul(all_gather_buffer, all_gather_buffer)
```
This would involve getting functionalization to turn `x.copy_(y)` into a giant no-op that just returns `y`. Unfortunately, we can't actually do this in a reasonable way within functionalization (instead, there's a functional `aten.copy` in the graph - see the test case graph expecttest for details). Why? In order for that transformation to be safe, `x` and `y` need to have the same metadata. However, it's possible for `x` and `y` to be subclasses of different types. This is not something we can easily tell from within functionalization, and would be a layering violation. So for now I'm leaving it to downstream code to optimize away the `aten.copy` (this is already the case today, so I think inductor can handle this)
(2) The forward doesn't **actually** run successfully in this PR (see the `assertRaisesRegex` in the test). Why?
The final forward graph looks like this:
```
def forward(self, primals_1, primals_2):
_foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_2 = None
getitem = _foreach_copy[0]; _foreach_copy = None
mm = torch.ops.aten.mm.default(getitem, getitem); getitem = None
t_1 = torch.ops.aten.t.default(primals_1); primals_1 = None
return [mm, t_1]
```
Where `primals_1` starts out as a secretly-zero-storage-size parameter, and gets resized up and back down within the forward (these are functionalized away).
Importantly, the matmul happy on the result of the `foreach_copy`, **but** the activation that we save for backward (`t_1`) is the result of transposing the **original parameter** (the zero-storage-size param). This is exactly the optimization in fsdp that allows us to have good peak memory usage.
The problem is that the min-cut partitioner decides to save `t_1` for backward. Running this code in eager breaks, because the kernel for `aten.permute(x)` is not happy when `x` has secretly-zero-sized-storage.
The real problem here is that in eager mode the `permute` kernel runs during the backward, after backward hooks have properly resized the saved activation. Here, we are running the transpose in the forward.
One option would be to turn off the checks in our view kernels and allow them to work on zero-storage-sized tensors, which feels pretty bad. Another option is to tweak the partitioner (or use one of Will's FX passes) to force the partitioner to not save views for backward, and allow the views to be recomputed in the backward. This seems kind of silly, but is also probably harmless.
(3) The backward is still broken. To be fair, this issue is pretty separable from "functionalizing storage resize calls", and can be fixed later (either by a real fix to our tracing infra, or via another hacky FX pass). More description of this problem is described at issue (8) of my PR description in https://github.com/pytorch/pytorch/pull/120971
(4) I only added support for "full graph" resizing: basically, the limited case where a param starts with zero storage size, and gets resized up and back down. I think we can add support for the graph break case, but I think we can keep that add-on separate from this PR unless we need it immediately. I also added asserts so we should fail loudly when we hit this case
(5) I have a change to FakeTensor creation when inputs have zero storage size that.. is probably ok. But I also removed FakeTensor caching on view ops, which I probably need to fix before I can land this PR
(6) I added a notion of "original_base" to `FunctionalStorageImpl`. More details are in the comments, but my rational for this was that we basically need it to ensure that autograd saves the **original**, zero-storage-sized param for backward, after resizing up and back down
(7) I had to update our eager kernels for `aten.copy` and `aten._foreach_copy`, to handle the case where the `self` argument has secretly-zero-storage. Inductor can probably generate correct code for this case, but we need these ops to work properly in this situation for the `aot_eager` backend to do the right thing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122434
Approved by: https://github.com/jansel
When fakifying a grad tracking tensor, if the level is -2 (sentinel
value) we can just unwrap the grad tensor and return a fake version of
it. In this PR, we update the `assert_metadata_eq` to not compare if
the grad tensor and the unwrapped ones are leafs or not, as this may
not be always true.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122728
Approved by: https://github.com/zou3519
Fixes#118795
This is a graph breaking partial fix for #120914. We still need -actual- module parametrization tracing support, but at least it doesn't blow up hard now.
**Background**: Module parametrization injects a property as the module parameter attribute that calls a `nn.Module` whose forward takes in a module parameter and returns a reparametrized module parameter.
Example:
```
class MyParametrization(nn.Module):
def forward(X):
# This reparametrization just negates the original parameter value
return -X
m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)
# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight
# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```
**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and the parametrized attribute, so now these rules wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.
**Problem 2**: The issue seen in #118795 is that Dynamo will see a dynamically constructed tensor when `m.weight` is called and introduce that to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made to be a graph input, since it's a module parameter. When guards are created for this module parameter input, the logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
* This cache was originally introduced in #100128.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121041
Approved by: https://github.com/anijain2305
This removes the duplicate handling of comparison ops between symbolic_convert and bultin and refactors the handling to use the binop infrastructure. This change regresses overheads a bit, but this is fixed in the next PR.
New test skips are variants of `type(e) is np.ndarray` previously falling back to eager.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122043
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
Summary:
The profiling, even when disabled, takes up about 1.5% cpu for a model I'm looking into.
This patch just splits into with/without profile runs.
The potential downside is that now the script can't enable profiling in itself. It doesn't seem to be used anywhere. If that's a crusial usecase, we can do something about it but ideally we wouldn't.
Test Plan:
Link with profiles:
https://fburl.com/scuba/strobelight_services/ihxsl7pj
```
buck2 run fbcode//caffe2/test/cpp/jit:jit
```
Reviewed By: zhxchen17
Differential Revision: D54066589
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121404
Approved by: https://github.com/zhxchen17
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
These are just tests that I noticed passed on current main
Running:
```
PYTORCH_TEST_WITH_DYNAMO=1 pytest test/dynamo/test_dynamic_shapes.py test/dynamo/test_compile.py -k 'test_export_decomp_dynamic_shapes or test_export_dynamic_dim_cleanup_dynamic_shapes or test_export_multi_dynamic_dim_constraint_dynamic_shapes or test_export_multi_dynamic_dim_unsafe_relationship_dynamic_shapes or test_export_no_raise_dynamic_shapes or test_export_preserve_constraints_as_metadata_scalar_dynamic_shapes or test_export_raise_on_relationship_dynamic_shapes or test_exported_graph_serialization_dynamic_shapes or test_retracibility_dict_container_inp_out_dynamic_shapes or test_retracibility_nested_list_out_dynamic_shapes or test_exception_table_e2e_2_dynamic_shapes or test_exception_table_e2e_dynamic_shapes or test_exception_table_parsing_dynamic_shapes or test_inference_mode_dynamic_shapes or test_inplace_view_on_graph_input_dynamic_shapes or test_numpy_torch_operators_dynamic_shapes or test_py311_jump_offset_dynamic_shapes or test_lazy_module_no_cls_to_become_dynamic_shapes or test_batchnorm_e2e_dynamic_shapes or test_functools_wraps_dynamic_shapes or test_jit_trace_errors_dynamic_shapes or test_multi_import_dynamic_shapes or test_requires_grad_guards_with_grad_mode2_dynamic_shapese or test_dynamo_signatures'
```
BEFORE: `1 failed, 1 passed, 22 skipped, 1372 deselected`
AFTER: `24 passed, 1372 deselected`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121378
Approved by: https://github.com/oulgen
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD