mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it went to the fused one), we now default to preserving the fused call by default. This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel. Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
d73416642f
commit
de8d81275a
@ -1024,8 +1024,22 @@ def gen_functionalization_registration(
|
||||
) -> list[str]:
|
||||
@with_native_function
|
||||
def emit_registration_helper(f: NativeFunction) -> str:
|
||||
assert not f.has_composite_implicit_autograd_kernel
|
||||
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
|
||||
if f.has_composite_implicit_autograd_kernel:
|
||||
metadata = composite_implicit_autograd_index.get_kernel(f)
|
||||
assert metadata is not None
|
||||
native_api_name = metadata.kernel
|
||||
sig = NativeSignature(f.func, symint=metadata.supports_symint())
|
||||
# Note [Composite view ops in the functionalization pass]
|
||||
# We don't need to worry about implemententing functionalization kernels for views with
|
||||
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
|
||||
# We can't just opt the entire Functionalization dispatch key into the composite keyset though,
|
||||
# because we don't want to decompose non-view ops that are composite, like `at::ones`.
|
||||
registration_str = (
|
||||
f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
|
||||
)
|
||||
else:
|
||||
# non-composite view ops (and inplace ops) get a normal registration.
|
||||
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
|
||||
return f'm.impl("{f.func.name}", {registration_str});'
|
||||
|
||||
# Don't generate kernels in mobile build
|
||||
@ -1038,12 +1052,8 @@ def gen_functionalization_registration(
|
||||
if str(g.view.func.name) == "lift_fresh":
|
||||
return []
|
||||
view_str = []
|
||||
if not g.view.has_composite_implicit_autograd_kernel:
|
||||
view_str.append(emit_registration_helper(g.view))
|
||||
if (
|
||||
g.view_inplace is not None
|
||||
and not g.view_inplace.has_composite_implicit_autograd_kernel
|
||||
):
|
||||
view_str.append(emit_registration_helper(g.view))
|
||||
if g.view_inplace is not None:
|
||||
assert g.view_inplace.is_view_op
|
||||
view_str.append(emit_registration_helper(g.view_inplace))
|
||||
return view_str
|
||||
|
Reference in New Issue
Block a user