Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can still force the
decomposition by putting it in the decomposition table, but if eager
mode wouldn't have decomposed (because it dispatched to the fused
kernel), we now preserve the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.
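To make this concrete, here is a minimal check of the bitwise claim (a sketch only: it assumes a CUDA build where eager rms_norm takes the fused kernel, and the shapes/dtype are illustrative):

    import torch

    def fn(x, w):
        # on CUDA, eager dispatches this to the fused rms_norm kernel
        return torch.nn.functional.rms_norm(x, (512,), w)

    x = torch.randn(8, 512, device="cuda", dtype=torch.float16)
    w = torch.randn(512, device="cuda", dtype=torch.float16)

    eager = fn(x, w)
    aot_eager = torch.compile(fn, backend="aot_eager")(x, w)
    # atol=0 / rtol=0 asserts bitwise equality, mirroring the test updated below.
    torch.testing.assert_close(eager, aot_eager, atol=0, rtol=0)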

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164573
Author: Edward Z. Yang
Date: 2025-10-08 16:29:15 -07:00
Committed by: PyTorch MergeBot
Parent: e532f62e0d
Commit: d40a9bfb8d
10 changed files with 163 additions and 26 deletions

View File

@ -52,9 +52,7 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
DispatchKeySet{DispatchKey::NestedTensor};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(

View File

@ -7207,7 +7207,6 @@ metadata incorrectly.
aot_eager = torch.compile(backend="aot_eager")(fn)(x)
self.assertEqual(eager, aot_eager, atol=0, rtol=0)
@unittest.expectedFailure
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_rms_norm(self):
# Only CUDA rms norm fails to be decomposed

View File

@ -1255,11 +1255,10 @@ class DecompOneOffTests(TestCase):
)
# check RMSNorm was fused with sinh
self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0])
self.assertTrue(
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
)
self.assertTrue(
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
"triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul"
in generated_codes[1]
)

View File

@ -404,6 +404,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.max_unpool3d,
aten.mish,
aten.mish_,
aten.mish_backward,
aten.mse_loss,
aten.mse_loss_backward,
aten.multi_margin_loss,
@ -419,6 +420,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
aten._fused_rms_norm,
aten._fused_rms_norm_backward,
aten.new_empty,
aten.new_full,
@ -475,6 +477,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.silu,
aten.silu_,
aten.silu_backward.grad_input,
aten.silu_backward,
aten.sinc,
aten.sinc_,
aten.slice_backward,
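With aten._fused_rms_norm (and the other ops added above) registered post-autograd, the decomposition should be reachable from the standard table. A quick membership check, assuming the public core_aten_decompositions() helper exposes this registry:

    import torch
    from torch._decomp import core_aten_decompositions

    decomps = core_aten_decompositions()
    # The newly listed ops should now have entries in the core ATen decomp table.
    assert torch.ops.aten._fused_rms_norm.default in decomps
    assert torch.ops.aten.silu_backward.default in decomps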

View File

@ -1757,6 +1757,58 @@ def native_layer_norm_backward_out(
return grad_input
@register_decomposition(aten._fused_rms_norm.default)
def _fused_rms_norm(
input: Tensor,
normalized_shape: list[int],
weight: Optional[Tensor],
eps: Optional[float],
) -> tuple[Tensor, Tensor]:
dims_to_reduce: list[int] = []
for i in range(len(normalized_shape)):
dims_to_reduce.append(input.dim() - i - 1)
# upcast is needed for fp16 and bf16
computation_dtype = utils.get_computation_dtype(input.dtype)
upcasted_input = input.to(computation_dtype)
# computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
if eps is None:
if computation_dtype in (torch.float32, torch.complex64):
eps_val = torch.finfo(torch.float32).eps
else:
eps_val = sys.float_info.epsilon
else:
eps_val = eps
rqrst_input = torch.rsqrt(
# NB: don't inplace here, will violate functional IR invariant
torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
)
upcasted_result = upcasted_input.mul(rqrst_input)
if weight is not None:
upcasted_result = upcasted_result.mul(weight)
# NB: nested should be dead here, just here for fidelity
is_nested = input.is_nested or (weight is not None and weight.is_nested)
memory_format = utils.suggest_memory_format(input)
is_channels_last = memory_format in (
torch.channels_last,
torch.channels_last_3d,
)
if not is_nested and not is_channels_last:
upcasted_result = upcasted_result.contiguous()
rqrst_input = rqrst_input.contiguous()
# Cast normalized result back to original input type
result = upcasted_result.type_as(input)
return result, rqrst_input
@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
grad_out: Tensor,
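As a sanity check on the forward formula above, a hand-rolled reference can be compared against eager (a sketch only: it re-derives the same rsqrt-of-mean-of-squares math for a floating-point input rather than calling the registered decomposition, and uses loose tolerances since it makes no bitwise claim):

    import torch

    def rms_norm_ref(x, normalized_shape, weight=None, eps=None):
        # Mirror the decomposition: upcast, mean of squares over trailing dims,
        # rsqrt, scale by weight, cast back to the input dtype.
        dims = [x.dim() - i - 1 for i in range(len(normalized_shape))]
        xf = x.float()
        eps_val = torch.finfo(torch.float32).eps if eps is None else eps
        rstd = torch.rsqrt(xf.pow(2).mean(dim=dims, keepdim=True).add(eps_val))
        out = xf * rstd
        if weight is not None:
            out = out * weight
        return out.type_as(x)

    x = torch.randn(4, 64, dtype=torch.float16)
    w = torch.randn(64, dtype=torch.float16)
    expected = torch.nn.functional.rms_norm(x, (64,), w)
    torch.testing.assert_close(rms_norm_ref(x, (64,), w), expected, rtol=1e-2, atol=1e-3)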

View File

@ -15,6 +15,7 @@ from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
_detect_infra_mode,
_disable_infra_mode,
autograd_would_have_decomposed,
return_and_correct_aliasing,
TorchDispatchMode,
)
@ -409,8 +410,13 @@ class FunctionalTensorMode(TorchDispatchMode):
return False
return True
# in normal torch.compile IR, we decompose functional composite ops
return True
# in normal torch.compile IR, we only decompose an op if autograd
# would have decomposed it (NB: autograd may have been skipped if
# we are in inference mode)
# TODO: the flatten here can potentially be deduped with the
# unwrapping pytree_map later
flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
return autograd_would_have_decomposed(func, flat_args_kwargs)
if (
func not in FunctionalTensor.metadata_fns

View File

@ -21,6 +21,10 @@ backends are ready, this list allows opt-in one at a time.
PRESERVED_ATEN_CIA_OPS = {
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
# NB: don't use the C++ decomp, because it is not functional!
torch.ops.aten.silu_backward.default,
torch.ops.aten.mish_backward.default,
torch.ops.aten._fused_rms_norm.default,
}

View File

@ -63,6 +63,7 @@ from torch.utils._python_dispatch import (
_disable_infra_mode,
_push_mode,
_unset_infra_mode,
autograd_would_have_decomposed,
TorchDispatchMode,
)
from torch.utils._stats import count
@ -908,11 +909,16 @@ def proxy_call(
return r
# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
if not pre_dispatch and func not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]:
if (
not pre_dispatch
and func
not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]
and autograd_would_have_decomposed(func, flat_args_kwargs)
):
with proxy_mode:
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
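The effect on the traced IR can be seen directly with make_fx (a sketch, assuming a CUDA build where the fused kernel is the eager path; on CPU the composite would decompose regardless):

    import torch
    from torch._decomp import get_decompositions
    from torch.fx.experimental.proxy_tensor import make_fx

    def fn(x, w):
        return torch.nn.functional.rms_norm(x, (512,), w)

    x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
    w = torch.randn(512, device="cuda", dtype=torch.float16)

    # With no decomposition table, the fused op now survives tracing because
    # eager autograd would not have decomposed it on CUDA.
    gm = make_fx(fn)(x, w)
    assert any(
        n.op == "call_function" and n.target is torch.ops.aten._fused_rms_norm.default
        for n in gm.graph.nodes
    )

    # A user can still opt in to the decomposition by supplying it explicitly.
    gm2 = make_fx(fn, decomposition_table=get_decompositions([torch.ops.aten._fused_rms_norm]))(x, w)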

View File

@ -1,11 +1,12 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
import functools
import warnings
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, overload, Protocol, Union
from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
from typing_extensions import TypeIs
import torch
@ -20,6 +21,10 @@ from torch._C import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
@ -414,7 +419,7 @@ class TensorWithFlatten(Protocol):
@overload
def to(
self,
device: Optional["torch._prims_common.DeviceLikeType"] = None,
device: Optional[torch._prims_common.DeviceLikeType] = None,
dtype: Optional[torch.types._dtype] = None,
non_blocking: bool = False,
copy: bool = False,
@ -682,6 +687,61 @@ def get_alias_info(func) -> SchemaInfo:
return schema_info
def autograd_would_have_decomposed(
func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
) -> bool:
"""
Suppose that an operator has a CompositeImplicitAutograd decomp registered.
Would autograd have used this decomposition? It will only use it if there
isn't an explicit backend registration for the device as well. This function
tells you whether that would have happened.
Why do we need to apply these decompositions later? When inference mode is
on, the autograd key is bypassed entirely, so a lower-level mode cannot rely
on the decomposition having been applied. It's easy to accidentally never apply
the decomposition, resulting in an unexpected operator showing up in the graph.
Why do we need to AVOID applying the decomposition when autograd wouldn't
have decomposed? If autograd doesn't decompose, this means in eager mode
we would have run the fused kernel. It must be possible to trace this
fused kernel directly into the graph for fidelity with eager (NB: a user
still has the option of further decomposing at proxy tensor mode via the
decomposition table, but we must preserve the op until proxy mode for that
choice to exist).
Why does functionalization need to also perform the test here? This is
because some CompositeImplicitAutograd decompositions are not functional.
If we are eventually going to decompose, we need to do this while we can
still turn functionalization back on, so those decompositions get functionalized.
So an early decomposition in functionalization may still be necessary. Note that
if the proxy tensor decomposition process could turn functionalization back on, this
wouldn't be necessary, and maybe that is a useful thing to do anyway because
the decomposition table is user specified and a user could violate the functional
decomp requirement with a bad decomp. If this happened, then you could always
pass through functionalization.
"""
has_backend_registration = False
for a in flat_args:
if isinstance(a, torch.Tensor):
backend_key = torch._C._parse_dispatch_key(
torch._C._dispatch_key_for_device(a.device.type)
)
assert backend_key is not None
# TODO: use func.has_kernel_for_dispatch_key(backend_key)
# but this one checks py_impl and CompositeImplicitAutograd
# incorrectly shows up as has backend reg here
has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), backend_key
)
# In theory we should gather all backend keys and take the highest-priority
# one to properly mimic the dispatcher; this just grabs the first tensor and
# uses its device's key.
break
return not has_backend_registration
# See NOTE[SchemaInfo int_tags] above.
_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload]
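The heart of the new helper is "does this op have a real kernel for the tensor's backend dispatch key". A rough standalone probe of that idea, using the same private helpers the diff uses (illustrative only; the expected results assume a CUDA build with the fused rms_norm kernel):

    import torch

    def has_backend_kernel(op: torch._ops.OpOverload, device_type: str) -> bool:
        # Mirror of the check in autograd_would_have_decomposed: look up the
        # backend dispatch key for the device and ask whether the op has a
        # kernel registered for it.
        key = torch._C._dispatch_key_for_device(device_type)
        return torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), key)

    # rms_norm is CompositeImplicitAutograd-only, so autograd decomposes it;
    # _fused_rms_norm has an explicit CUDA kernel, so it would be preserved.
    print(has_backend_kernel(torch.ops.aten.rms_norm.default, "cuda"))         # expected: False
    print(has_backend_kernel(torch.ops.aten._fused_rms_norm.default, "cuda"))  # expected: True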

View File

@ -1024,8 +1024,22 @@ def gen_functionalization_registration(
) -> list[str]:
@with_native_function
def emit_registration_helper(f: NativeFunction) -> str:
assert not f.has_composite_implicit_autograd_kernel
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
if f.has_composite_implicit_autograd_kernel:
metadata = composite_implicit_autograd_index.get_kernel(f)
assert metadata is not None
native_api_name = metadata.kernel
sig = NativeSignature(f.func, symint=metadata.supports_symint())
# Note [Composite view ops in the functionalization pass]
# We don't need to worry about implementing functionalization kernels for views with
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
# We can't just opt the entire Functionalization dispatch key into the composite keyset though,
# because we don't want to decompose non-view ops that are composite, like `at::ones`.
registration_str = (
f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
)
else:
# non-composite view ops (and inplace ops) get a normal registration.
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
return f'm.impl("{f.func.name}", {registration_str});'
# Don't generate kernels in mobile build
@ -1038,12 +1052,8 @@ def gen_functionalization_registration(
if str(g.view.func.name) == "lift_fresh":
return []
view_str = []
if not g.view.has_composite_implicit_autograd_kernel:
view_str.append(emit_registration_helper(g.view))
if (
g.view_inplace is not None
and not g.view_inplace.has_composite_implicit_autograd_kernel
):
view_str.append(emit_registration_helper(g.view))
if g.view_inplace is not None:
assert g.view_inplace.is_view_op
view_str.append(emit_registration_helper(g.view_inplace))
return view_str
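For reference, the two branches of emit_registration_helper now yield registrations of roughly these shapes for the Functionalize key (hand-written illustration; the op name, native function, and pointer signature below are made up, not copied from real codegen output):

    # Illustrative only: the shapes of the two registration strings.
    op_name = "narrow"                  # hypothetical composite-implicit view op
    native_api_name = "narrow_symint"   # hypothetical at::native kernel name
    ptr_type = "at::Tensor (*)(const at::Tensor&, int64_t, c10::SymInt, c10::SymInt)"

    # Composite-implicit view op: re-use the CompositeImplicitAutograd kernel.
    composite_registration = (
        f'm.impl("{op_name}", static_cast<{ptr_type}>(at::native::{native_api_name}));'
    )
    # Non-composite view / inplace op: register the functionalization wrapper.
    wrapper_registration = 'm.impl("view", TORCH_FN(functionalization::view));'
    print(composite_registration)
    print(wrapper_registration)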