Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager by avoiding the decomposition.  You can still force the
decomposition by having it in the dispatch table, but if eager mode
would not have decomposed (because it dispatched to the fused
kernel), we now preserve the fused call by default.
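As a rough illustration (not part of this PR; it assumes a CUDA build
where eager rms_norm takes the fused kernel, and uses hypothetical
shapes), the following sketch checks that a compiled rms_norm now
matches eager bitwise:

import torch
import torch.nn.functional as F

# Assumption: on a CUDA build, eager F.rms_norm dispatches to the fused
# kernel; with the fused call preserved through AOTAutograd, the compiled
# result should match eager exactly.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
x = torch.randn(16, 1024, device=device, dtype=dtype)
w = torch.randn(1024, device=device, dtype=dtype)

eager_out = F.rms_norm(x, (1024,), w)
compiled_out = torch.compile(F.rms_norm)(x, (1024,), w)

# rtol=0, atol=0 makes assert_close demand exact (bitwise) equality
torch.testing.assert_close(eager_out, compiled_out, rtol=0, atol=0)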

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.
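A quick way to see this (again just a sketch under the same CUDA
assumption, not code from this PR) is to run a compiled rms_norm with
TORCH_LOGS="output_code" and look for rms_norm in the generated kernel
names:

# Run as: TORCH_LOGS="output_code" python this_script.py
import torch
import torch.nn.functional as F

@torch.compile
def f(x, w):
    return F.rms_norm(x, (1024,), w)

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
w = torch.randn(1024, device="cuda", dtype=torch.float16)
# The logged output code should contain kernels with "rms_norm" in the name.
f(x, w)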

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164573
Commit d40a9bfb8d (parent e532f62e0d)
Authored by Edward Z. Yang on 2025-10-08 16:29:15 -07:00
Committed by PyTorch MergeBot
10 changed files with 163 additions and 26 deletions


@@ -404,6 +404,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.max_unpool3d,
aten.mish,
aten.mish_,
aten.mish_backward,
aten.mse_loss,
aten.mse_loss_backward,
aten.multi_margin_loss,
@@ -419,6 +420,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
aten._fused_rms_norm,
aten._fused_rms_norm_backward,
aten.new_empty,
aten.new_full,
@@ -475,6 +477,7 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.silu,
aten.silu_,
aten.silu_backward.grad_input,
aten.silu_backward,
aten.sinc,
aten.sinc_,
aten.slice_backward,
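The hunk above adds aten._fused_rms_norm to the post-autograd
decomposition table, which backs the opt-in path described in the
message: supplying a decomposition table that contains the entry forces
the decomposition, while tracing without it keeps the fused call. A
minimal sketch, assuming the fused op has a kernel for the device being
traced on and the (input, normalized_shape, weight, eps) signature shown
below:

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    # Call the fused op directly so the sketch does not depend on how
    # eager aten::rms_norm routes to it.
    return torch.ops.aten._fused_rms_norm(x, [x.shape[-1]], w, None)[0]

x = torch.randn(4, 64)
w = torch.randn(64)

# Default: the fused call should be preserved in the traced graph.
print(make_fx(f)(x, w).graph)

# Opt in: supplying the decomposition table forces the decomposition into
# elementwise/reduction ops (pow, mean, rsqrt, mul).
table = get_decompositions([torch.ops.aten._fused_rms_norm])
print(make_fx(f, decomposition_table=table)(x, w).graph)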


@@ -1757,6 +1757,58 @@ def native_layer_norm_backward_out(
    return grad_input


@register_decomposition(aten._fused_rms_norm.default)
def _fused_rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Optional[Tensor],
    eps: Optional[float],
) -> tuple[Tensor, Tensor]:
    dims_to_reduce: list[int] = []
    for i in range(len(normalized_shape)):
        dims_to_reduce.append(input.dim() - i - 1)

    # upcast is needed for fp16 and bf16
    computation_dtype = utils.get_computation_dtype(input.dtype)
    upcasted_input = input.to(computation_dtype)

    # computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
    if eps is None:
        if computation_dtype in (torch.float32, torch.complex64):
            # float accumulation: default to float32 machine epsilon,
            # mirroring the eager fused kernel's default
            eps_val = torch.finfo(torch.float32).eps
        else:
            # double/complex-double accumulation: float64 machine epsilon
            eps_val = sys.float_info.epsilon
    else:
        eps_val = eps
    rqrst_input = torch.rsqrt(
        # NB: don't inplace here, will violate functional IR invariant
        torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
    )
    upcasted_result = upcasted_input.mul(rqrst_input)

    if weight is not None:
        upcasted_result = upcasted_result.mul(weight)

    # NB: nested should be dead here, just here for fidelity
    is_nested = input.is_nested or (weight is not None and weight.is_nested)
    memory_format = utils.suggest_memory_format(input)
    is_channels_last = memory_format in (
        torch.channels_last,
        torch.channels_last_3d,
    )

    if not is_nested and not is_channels_last:
        upcasted_result = upcasted_result.contiguous()
        rqrst_input = rqrst_input.contiguous()

    # Cast normalized result back to original input type
    result = upcasted_result.type_as(input)
    return result, rqrst_input


@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
    grad_out: Tensor,