From de8d81275a3799fae09d5907cb984c71a9b7fe50 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 10 Oct 2025 13:25:49 -0700 Subject: [PATCH] Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939) This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it went to the fused one), we now default to preserving the fused call by default. This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel. Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939 Approved by: https://github.com/bdhirsh --- aten/src/ATen/native/ts_native_functions.yaml | 1 + c10/core/DispatchKeySet.cpp | 4 +- test/functorch/test_aotdispatch.py | 1 - test/lazy/test_ts_opinfo.py | 22 +++---- test/test_decomp.py | 7 +- torch/_decomp/__init__.py | 3 + torch/_decomp/decompositions.py | 52 +++++++++++++++ torch/_subclasses/functional_tensor.py | 10 ++- .../lazy/ts_backend/ts_native_functions.cpp | 8 +++ torch/export/decomp_utils.py | 4 ++ torch/fx/experimental/proxy_tensor.py | 16 +++-- torch/utils/_python_dispatch.py | 66 ++++++++++++++++++- torchgen/gen_functionalization_type.py | 26 +++++--- 13 files changed, 181 insertions(+), 39 deletions(-) diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml index 17c9bd4234f3..4ef380704de8 100644 --- a/aten/src/ATen/native/ts_native_functions.yaml +++ b/aten/src/ATen/native/ts_native_functions.yaml @@ -202,6 +202,7 @@ supported: - select_backward - _trilinear - linalg_pinv.atol_rtol_tensor + - svd - logsumexp.out symint: - empty.memory_format diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 96ef6b3522ba..72e72f49a5e4 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -52,9 +52,7 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | // where we would like to support composite implicit kernels but not // explicit kernels therefore we manually add the key to the // math_dispatch_keyset - DispatchKeySet{DispatchKey::NestedTensor} | - // Functionalize should always reuse CompositeImplicit decomps. - DispatchKeySet{DispatchKey::Functionalize}; + DispatchKeySet{DispatchKey::NestedTensor}; constexpr DispatchKeySet nested_dispatch_keyset = DispatchKeySet( diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 080002999964..41b37a687fae 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -7207,7 +7207,6 @@ metadata incorrectly. aot_eager = torch.compile(backend="aot_eager")(fn)(x) self.assertEqual(eager, aot_eager, atol=0, rtol=0) - @unittest.expectedFailure @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_rms_norm(self): # Only CUDA rms norm fails to be decomposed diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py index 7c467dc62413..e4652a465d72 100644 --- a/test/lazy/test_ts_opinfo.py +++ b/test/lazy/test_ts_opinfo.py @@ -85,6 +85,7 @@ def init_lists(): "linalg_inv_ex", "linalg_pinv.atol_rtol_tensor", "logsumexp", + "svd", } # For some ops, we don't support all variants. Here we use formatted_name # to uniquely identify the variant. @@ -220,20 +221,15 @@ class TestLazyOpInfo(TestCase): torch._lazy.wait_device_ops() prefix = "aten" if op.name in FALLBACK_LIST else "lazy" symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else "" - found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes( - torch._lazy.metrics.counter_names() - ) + metrics = remove_suffixes(torch._lazy.metrics.counter_names()) + cands = [f"{prefix}::{op.name}{symint_suffix}"] # check aliases - if not found: - for alias in op.aliases: - alias_found = ( - f"{prefix}::{alias.name}{symint_suffix}" - in remove_suffixes(torch._lazy.metrics.counter_names()) - ) - found = found or alias_found - if found: - break - self.assertTrue(found) + for alias in op.aliases: + cands.append(f"{prefix}::{alias.name}{symint_suffix}") + + self.assertTrue( + any(c in metrics for c in cands), f"none of {cands} not found in {metrics}" + ) @ops( [ diff --git a/test/test_decomp.py b/test/test_decomp.py index 610465db4c48..a534b643997b 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -1255,11 +1255,10 @@ class DecompOneOffTests(TestCase): ) # check RMSNorm was fused with sinh + self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0]) self.assertTrue( - "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0] - ) - self.assertTrue( - "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] + "triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul" + in generated_codes[1] ) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index c4396932818d..69ef0b901bed 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -404,6 +404,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.max_unpool3d, aten.mish, aten.mish_, + aten.mish_backward, aten.mse_loss, aten.mse_loss_backward, aten.multi_margin_loss, @@ -419,6 +420,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, + aten._fused_rms_norm, aten._fused_rms_norm_backward, aten.new_empty, aten.new_full, @@ -475,6 +477,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.silu, aten.silu_, aten.silu_backward.grad_input, + aten.silu_backward, aten.sinc, aten.sinc_, aten.slice_backward, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 18c6ac5945e5..597c28ad0029 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1757,6 +1757,58 @@ def native_layer_norm_backward_out( return grad_input +@register_decomposition(aten._fused_rms_norm.default) +def _fused_rms_norm( + input: Tensor, + normalized_shape: list[int], + weight: Optional[Tensor], + eps: Optional[float], +) -> tuple[Tensor, Tensor]: + dims_to_reduce: list[int] = [] + for i in range(len(normalized_shape)): + dims_to_reduce.append(input.dim() - i - 1) + + # upcast is needed for fp16 and bf16 + computation_dtype = utils.get_computation_dtype(input.dtype) + upcasted_input = input.to(computation_dtype) + + # computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble] + if eps is None: + if computation_dtype in (torch.float32, torch.complex64): + eps_val = torch.finfo(torch.float32).eps + else: + eps_val = torch.finfo(torch.float64).eps + else: + eps_val = eps + + rqrst_input = torch.rsqrt( + # NB: don't inplace here, will violate functional IR invariant + torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val) + ) + + upcasted_result = upcasted_input.mul(rqrst_input) + + if weight is not None: + upcasted_result = upcasted_result.mul(weight) + + # NB: nested should be dead here, just here for fidelity + is_nested = input.is_nested or (weight is not None and weight.is_nested) + memory_format = utils.suggest_memory_format(input) + is_channels_last = memory_format in ( + torch.channels_last, + torch.channels_last_3d, + ) + + if not is_nested and not is_channels_last: + upcasted_result = upcasted_result.contiguous() + rqrst_input = rqrst_input.contiguous() + + # Cast normalized result back to original input type + result = upcasted_result.type_as(input) + + return result, rqrst_input + + @register_decomposition(aten._fused_rms_norm_backward.default) def _fused_rms_norm_backward( grad_out: Tensor, diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 15ed56ddca3c..d3b9ac7858ce 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -15,6 +15,7 @@ from torch._subclasses.meta_utils import is_sparse_any from torch.utils._python_dispatch import ( _detect_infra_mode, _disable_infra_mode, + autograd_would_have_decomposed, return_and_correct_aliasing, TorchDispatchMode, ) @@ -409,8 +410,13 @@ class FunctionalTensorMode(TorchDispatchMode): return False return True - # in normal torch.compile IR, we decompose functional composite ops - return True + # in normal torch.compile IR, we only decompose an op if autograd + # would have decomposed it (NB: autograd may have been skipped if + # we are in inference mode) + # TODO: the flatten here can potentially be deduped with the + # unwrapping pytree_map later + flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs)) + return autograd_would_have_decomposed(func, flat_args_kwargs) if ( func not in FunctionalTensor.metadata_fns diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp index 1bb720b810f9..f1f69e092591 100644 --- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp @@ -466,6 +466,14 @@ at::Tensor LazyNativeFunctions::linalg_pinv( linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian); } +std::tuple LazyNativeFunctions::svd( + const at::Tensor& self, + bool some, + bool compute_uv) { + return at::functionalization::functionalize_aten_op::call( + self, some, compute_uv); +} + // functionalize_aten_op can't handle out= ops directly. // Instead, we can call the composite kernel from core, and copy and mutations // back to the inputs. diff --git a/torch/export/decomp_utils.py b/torch/export/decomp_utils.py index a261ce3c8b2c..d3097734c8a3 100644 --- a/torch/export/decomp_utils.py +++ b/torch/export/decomp_utils.py @@ -21,6 +21,10 @@ backends are ready, this list allows opt-in one at a time. PRESERVED_ATEN_CIA_OPS = { torch.ops.aten.upsample_bilinear2d.vec, torch.ops.aten.upsample_nearest2d.vec, + # NB: don't use the C++ decomp, because it is not functional! + torch.ops.aten.silu_backward.default, + torch.ops.aten.mish_backward.default, + torch.ops.aten._fused_rms_norm.default, } diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 2bccd906aa93..2e877ff4fa0d 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -63,6 +63,7 @@ from torch.utils._python_dispatch import ( _disable_infra_mode, _push_mode, _unset_infra_mode, + autograd_would_have_decomposed, TorchDispatchMode, ) from torch.utils._stats import count @@ -908,11 +909,16 @@ def proxy_call( return r # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. - if not pre_dispatch and func not in [ - torch.ops.aten.size.default, - torch.ops.aten.stride.default, - torch.ops.aten.storage_offset.default, - ]: + if ( + not pre_dispatch + and func + not in [ + torch.ops.aten.size.default, + torch.ops.aten.stride.default, + torch.ops.aten.storage_offset.default, + ] + and autograd_would_have_decomposed(func, flat_args_kwargs) + ): with proxy_mode: r = func.decompose(*args, **kwargs) if r is not NotImplemented: diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index fa756892c342..7d844cd3f91b 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -1,11 +1,12 @@ # mypy: allow-untyped-defs +from __future__ import annotations + import contextlib import functools import warnings from collections import deque -from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional, overload, Protocol, Union +from typing import Optional, overload, Protocol, TYPE_CHECKING, Union from typing_extensions import TypeIs import torch @@ -20,6 +21,10 @@ from torch._C import ( ) +if TYPE_CHECKING: + from collections.abc import Sequence + + # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it: # - We need a better user-facing api for _DisableTorchDispatch that # is able to selectively disable __torch_dispatch__ of a particular class. @@ -414,7 +419,7 @@ class TensorWithFlatten(Protocol): @overload def to( self, - device: Optional["torch._prims_common.DeviceLikeType"] = None, + device: Optional[torch._prims_common.DeviceLikeType] = None, dtype: Optional[torch.types._dtype] = None, non_blocking: bool = False, copy: bool = False, @@ -682,6 +687,61 @@ def get_alias_info(func) -> SchemaInfo: return schema_info +def autograd_would_have_decomposed( + func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]] +) -> bool: + """ + Suppose that an operator has CompositeImplicitAutograd decomp registered. + Would autograd have used this decomposition? It will only use it if there + isn't an explicit backend registration for the device as well. This function + will tell if this would have occurred. + + Why do we need to apply these decompositions later? When inference mode is + on, the autograd key is bypassed entirely, so a lower level mode cannot rely + on the decomposition have been applied. It's easy to accidentally never apply + the decomposition, resulting in an operator showing up in a graph that + is unexpected. + + Why do we need to AVOID applying the decomposition when autograd wouldn't + have decomposed? If autograd doesn't decompose, this means in eager mode + we would have run the fused kernel. It must be possible to trace this + fused kernel directly into the graph for fidelity with eager (NB: a user + has the option of then further decomposing at proxy tensor mode via + decomposition table, but we must preserve it to proxy mode to have the + choice.) + + Why does functionalization need to also perform the test here? This is + because some CompositeImplicitAutograd decompositions are not functional. + If we are eventually going to decompose, we need to do this while we can + still turn functionalization back on, so those decompositions get functionalized. + So an early decomposition in functionalization may still be necessary. Note that + if proxy tensor decomposition process could turn functionalization back on, this + wouldn't be necessary, and maybe that is a useful thing to do anyway because + the decomposition table is user specified and a user could violate the functional + decomp requirement with a bad decomp. If this happened, then you could always + pass through functionalization. + """ + has_backend_registration = False + for a in flat_args: + if isinstance(a, torch.Tensor): + backend_key = torch._C._parse_dispatch_key( + torch._C._dispatch_key_for_device(a.device.type) + ) + assert backend_key is not None + # TODO: use func.has_kernel_for_dispatch_key(backend_key) + # but this one checks py_impl and CompositeImplicitAutograd + # incorrectly shows up as has backend reg here + has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), backend_key + ) + + # in theory we should take all backend keys and take the highest priority one + # to properly mimic the dispatcher, + # this just grabs the first tensor and takes its device key + break + return not has_backend_registration + + # See NOTE[SchemaInfo int_tags] above. _TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload] diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index c396941cf913..1cb681ba19d3 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1024,8 +1024,22 @@ def gen_functionalization_registration( ) -> list[str]: @with_native_function def emit_registration_helper(f: NativeFunction) -> str: - assert not f.has_composite_implicit_autograd_kernel - registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})" + if f.has_composite_implicit_autograd_kernel: + metadata = composite_implicit_autograd_index.get_kernel(f) + assert metadata is not None + native_api_name = metadata.kernel + sig = NativeSignature(f.func, symint=metadata.supports_symint()) + # Note [Composite view ops in the functionalization pass] + # We don't need to worry about implemententing functionalization kernels for views with + # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators. + # We can't just opt the entire Functionalization dispatch key into the composite keyset though, + # because we don't want to decompose non-view ops that are composite, like `at::ones`. + registration_str = ( + f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})" + ) + else: + # non-composite view ops (and inplace ops) get a normal registration. + registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})" return f'm.impl("{f.func.name}", {registration_str});' # Don't generate kernels in mobile build @@ -1038,12 +1052,8 @@ def gen_functionalization_registration( if str(g.view.func.name) == "lift_fresh": return [] view_str = [] - if not g.view.has_composite_implicit_autograd_kernel: - view_str.append(emit_registration_helper(g.view)) - if ( - g.view_inplace is not None - and not g.view_inplace.has_composite_implicit_autograd_kernel - ): + view_str.append(emit_registration_helper(g.view)) + if g.view_inplace is not None: assert g.view_inplace.is_view_op view_str.append(emit_registration_helper(g.view_inplace)) return view_str