Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
ghstack dependencies: #164573
This commit is contained in:
Edward Z. Yang
2025-10-08 16:29:15 -07:00
committed by PyTorch MergeBot
parent e532f62e0d
commit d40a9bfb8d
10 changed files with 163 additions and 26 deletions

View File

@ -1255,11 +1255,10 @@ class DecompOneOffTests(TestCase):
)
# check RMSNorm was fused with sinh
self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0])
self.assertTrue(
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
)
self.assertTrue(
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
"triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul"
in generated_codes[1]
)