Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"

This reverts commit a6fa4f9c283971c0fb6f60a89674a1f35370ac79.

Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/izaitsevfb due to introduces numeric issues internally, see [D84326613](https://www.internalfb.com/diff/D84326613) ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3392203314))
This commit is contained in:
PyTorch MergeBot
2025-10-10 20:21:12 +00:00
parent 306b344a18
commit 5c3fe9fb30
13 changed files with 39 additions and 181 deletions

View File

@ -1024,22 +1024,8 @@ def gen_functionalization_registration(
) -> list[str]:
@with_native_function
def emit_registration_helper(f: NativeFunction) -> str:
if f.has_composite_implicit_autograd_kernel:
metadata = composite_implicit_autograd_index.get_kernel(f)
assert metadata is not None
native_api_name = metadata.kernel
sig = NativeSignature(f.func, symint=metadata.supports_symint())
# Note [Composite view ops in the functionalization pass]
# We don't need to worry about implemententing functionalization kernels for views with
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
# We can't just opt the entire Functionalization dispatch key into the composite keyset though,
# because we don't want to decompose non-view ops that are composite, like `at::ones`.
registration_str = (
f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
)
else:
# non-composite view ops (and inplace ops) get a normal registration.
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
assert not f.has_composite_implicit_autograd_kernel
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
return f'm.impl("{f.func.name}", {registration_str});'
# Don't generate kernels in mobile build
@ -1052,8 +1038,12 @@ def gen_functionalization_registration(
if str(g.view.func.name) == "lift_fresh":
return []
view_str = []
view_str.append(emit_registration_helper(g.view))
if g.view_inplace is not None:
if not g.view.has_composite_implicit_autograd_kernel:
view_str.append(emit_registration_helper(g.view))
if (
g.view_inplace is not None
and not g.view_inplace.has_composite_implicit_autograd_kernel
):
assert g.view_inplace.is_view_op
view_str.append(emit_registration_helper(g.view_inplace))
return view_str