mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit a6fa4f9c283971c0fb6f60a89674a1f35370ac79. Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/izaitsevfb due to introduces numeric issues internally, see [D84326613](https://www.internalfb.com/diff/D84326613) ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3392203314))
This commit is contained in:
@ -1024,22 +1024,8 @@ def gen_functionalization_registration(
|
||||
) -> list[str]:
|
||||
@with_native_function
|
||||
def emit_registration_helper(f: NativeFunction) -> str:
|
||||
if f.has_composite_implicit_autograd_kernel:
|
||||
metadata = composite_implicit_autograd_index.get_kernel(f)
|
||||
assert metadata is not None
|
||||
native_api_name = metadata.kernel
|
||||
sig = NativeSignature(f.func, symint=metadata.supports_symint())
|
||||
# Note [Composite view ops in the functionalization pass]
|
||||
# We don't need to worry about implemententing functionalization kernels for views with
|
||||
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
|
||||
# We can't just opt the entire Functionalization dispatch key into the composite keyset though,
|
||||
# because we don't want to decompose non-view ops that are composite, like `at::ones`.
|
||||
registration_str = (
|
||||
f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
|
||||
)
|
||||
else:
|
||||
# non-composite view ops (and inplace ops) get a normal registration.
|
||||
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
|
||||
assert not f.has_composite_implicit_autograd_kernel
|
||||
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
|
||||
return f'm.impl("{f.func.name}", {registration_str});'
|
||||
|
||||
# Don't generate kernels in mobile build
|
||||
@ -1052,8 +1038,12 @@ def gen_functionalization_registration(
|
||||
if str(g.view.func.name) == "lift_fresh":
|
||||
return []
|
||||
view_str = []
|
||||
view_str.append(emit_registration_helper(g.view))
|
||||
if g.view_inplace is not None:
|
||||
if not g.view.has_composite_implicit_autograd_kernel:
|
||||
view_str.append(emit_registration_helper(g.view))
|
||||
if (
|
||||
g.view_inplace is not None
|
||||
and not g.view_inplace.has_composite_implicit_autograd_kernel
|
||||
):
|
||||
assert g.view_inplace.is_view_op
|
||||
view_str.append(emit_registration_helper(g.view_inplace))
|
||||
return view_str
|
||||
|
Reference in New Issue
Block a user