Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can still force the
decomposition by putting the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the
fused kernel), we now preserve the fused call by default.
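
As a rough illustration (a sketch, not a test from this PR: whether a fused
rms_norm kernel exists for your backend, and the exact op name that shows up
in the graph, are assumptions), the new default can be observed with
proxy-tensor tracing:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    return torch.nn.functional.rms_norm(x, (8,), weight=w)

x, w = torch.randn(4, 8), torch.randn(8)

# No decomposition table is passed, so if eager would have hit a fused
# rms_norm kernel, the traced graph should now keep that fused call
# instead of lowering it to mul/rsqrt/etc.
g = make_fx(f, tracing_mode="fake")(x, w)
print(g.graph)

# Supplying decomposition_table=... with an entry covering rms_norm would
# still force the decomposed form, as described above.
```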

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.
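
A quick way to check this (again a sketch: the exact generated kernel name is
illustrative, and a CUDA device is assumed so Inductor emits Triton kernels):

```python
# Run with: TORCH_LOGS="output_code" python check_rms_norm_kernel.py
import torch

@torch.compile
def f(x, w):
    return torch.nn.functional.rms_norm(x, (1024,), weight=w)

x = torch.randn(8, 1024, device="cuda")
w = torch.randn(1024, device="cuda")
f(x, w)
# In the logged output code, the generated kernel's name should now contain
# "rms_norm" (e.g. something like triton_per_fused_rms_norm_0) rather than
# only the names of the decomposed ops.
```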

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
Author: Edward Z. Yang
Date: 2025-10-10 13:25:49 -07:00
Committed by: PyTorch MergeBot
Parent: d73416642f
Commit: de8d81275a
13 changed files with 181 additions and 39 deletions


@@ -1024,8 +1024,22 @@ def gen_functionalization_registration(
 ) -> list[str]:
     @with_native_function
     def emit_registration_helper(f: NativeFunction) -> str:
-        assert not f.has_composite_implicit_autograd_kernel
-        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
+        if f.has_composite_implicit_autograd_kernel:
+            metadata = composite_implicit_autograd_index.get_kernel(f)
+            assert metadata is not None
+            native_api_name = metadata.kernel
+            sig = NativeSignature(f.func, symint=metadata.supports_symint())
+            # Note [Composite view ops in the functionalization pass]
+            # We don't need to worry about implementing functionalization kernels for views with
+            # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
+            # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
+            # because we don't want to decompose non-view ops that are composite, like `at::ones`.
+            registration_str = (
+                f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
+            )
+        else:
+            # non-composite view ops (and inplace ops) get a normal registration.
+            registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
         return f'm.impl("{f.func.name}", {registration_str});'
 
     # Don't generate kernels in mobile build
@@ -1038,12 +1052,8 @@ def gen_functionalization_registration(
         if str(g.view.func.name) == "lift_fresh":
             return []
         view_str = []
-        if not g.view.has_composite_implicit_autograd_kernel:
-            view_str.append(emit_registration_helper(g.view))
-        if (
-            g.view_inplace is not None
-            and not g.view_inplace.has_composite_implicit_autograd_kernel
-        ):
+        view_str.append(emit_registration_helper(g.view))
+        if g.view_inplace is not None:
+            assert g.view_inplace.is_view_op
             view_str.append(emit_registration_helper(g.view_inplace))
         return view_str
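
To make the codegen branch above concrete, here is a stripped-down sketch of
the two registration strings the helper now emits. The operator name, kernel
name, and the C++ signature in the cast are hypothetical stand-ins, not values
taken from this PR:

```python
def emit_registration(op_name: str, composite_kernel: str | None) -> str:
    # Toy version of emit_registration_helper's output shape (not real torchgen).
    if composite_kernel is not None:
        # Composite view op: register the existing CompositeImplicitAutograd
        # kernel directly for the Functionalization key (hypothetical C++ cast).
        registration_str = (
            "static_cast<at::Tensor (*)(const at::Tensor &, c10::SymInt)>"
            f"(at::native::{composite_kernel})"
        )
    else:
        # Non-composite view op (or inplace op): register the generated wrapper.
        registration_str = f"TORCH_FN(functionalization::{op_name})"
    return f'm.impl("{op_name}", {registration_str});'

print(emit_registration("my_view", "my_view_symint"))  # composite path
print(emit_registration("my_view_", None))             # wrapper path
```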