Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"

This reverts commit a6fa4f9c283971c0fb6f60a89674a1f35370ac79.

Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/izaitsevfb due to introduces numeric issues internally, see [D84326613](https://www.internalfb.com/diff/D84326613) ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3392203314))
PyTorch MergeBot
2025-10-10 20:21:12 +00:00
parent 306b344a18
commit 5c3fe9fb30
13 changed files with 39 additions and 181 deletions

View File

@@ -202,7 +202,6 @@ supported:
 - select_backward
 - _trilinear
 - linalg_pinv.atol_rtol_tensor
-- svd
 - logsumexp.out
 symint:
 - empty.memory_format

View File

@@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
     // where we would like to support composite implicit kernels but not
     // explicit kernels therefore we manually add the key to the
     // math_dispatch_keyset
-    DispatchKeySet{DispatchKey::NestedTensor};
+    DispatchKeySet{DispatchKey::NestedTensor} |
+    // Functionalize should always reuse CompositeImplicit decomps.
+    DispatchKeySet{DispatchKey::Functionalize};
 constexpr DispatchKeySet nested_dispatch_keyset =
     DispatchKeySet(
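Aside (not part of the diff): with Functionalize back in the composite ("math") keyset, an op that only registers a CompositeImplicitAutograd kernel is decomposed by functionalization again. A hedged sketch of how to inspect an op's registrations from Python; aten::matmul is just an assumed example of a composite-only op:

import torch
from torch._C import DispatchKey

op = torch.ops.aten.matmul.default
# matmul ships only a CompositeImplicitAutograd decomposition, no direct
# CPU/CUDA kernel, so keys that alias the composite keyset (which, after this
# revert, again includes Functionalize) fall through to the decomposition.
print(op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd))  # expected: True
print(op.has_kernel_for_dispatch_key(DispatchKey.CPU))  # expected: False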

View File

@@ -7207,6 +7207,7 @@ metadata incorrectly.
         aot_eager = torch.compile(backend="aot_eager")(fn)(x)
         self.assertEqual(eager, aot_eager, atol=0, rtol=0)
+    @unittest.expectedFailure
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_rms_norm(self):
         # Only CUDA rms norm fails to be decomposed

View File

@@ -85,7 +85,6 @@ def init_lists():
         "linalg_inv_ex",
         "linalg_pinv.atol_rtol_tensor",
         "logsumexp",
-        "svd",
     }
     # For some ops, we don't support all variants. Here we use formatted_name
     # to uniquely identify the variant.
@@ -221,15 +220,20 @@ class TestLazyOpInfo(TestCase):
         torch._lazy.wait_device_ops()
         prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
         symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
-        metrics = remove_suffixes(torch._lazy.metrics.counter_names())
-        cands = [f"{prefix}::{op.name}{symint_suffix}"]
-        # check aliases
-        for alias in op.aliases:
-            cands.append(f"{prefix}::{alias.name}{symint_suffix}")
-        self.assertTrue(
-            any(c in metrics for c in cands), f"none of {cands} not found in {metrics}"
+        found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(
+            torch._lazy.metrics.counter_names()
         )
+        # check aliases
+        if not found:
+            for alias in op.aliases:
+                alias_found = (
+                    f"{prefix}::{alias.name}{symint_suffix}"
+                    in remove_suffixes(torch._lazy.metrics.counter_names())
+                )
+                found = found or alias_found
+                if found:
+                    break
+        self.assertTrue(found)
     @ops(
         [

View File

@@ -1255,10 +1255,11 @@ class DecompOneOffTests(TestCase):
         )
         # check RMSNorm was fused with sinh
-        self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0])
         self.assertTrue(
-            "triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul"
-            in generated_codes[1]
+            "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
         )
+        self.assertTrue(
+            "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
+        )

View File

@@ -404,7 +404,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
             aten.max_unpool3d,
             aten.mish,
             aten.mish_,
-            aten.mish_backward,
             aten.mse_loss,
             aten.mse_loss_backward,
             aten.multi_margin_loss,
@@ -420,7 +419,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
             aten.native_dropout_backward,
             aten.native_group_norm_backward,
             aten.native_layer_norm_backward,
-            aten._fused_rms_norm,
             aten._fused_rms_norm_backward,
             aten.new_empty,
             aten.new_full,
@@ -477,7 +475,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
             aten.silu,
             aten.silu_,
             aten.silu_backward.grad_input,
-            aten.silu_backward,
             aten.sinc,
             aten.sinc_,
             aten.slice_backward,
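Aside: the table assembled by _core_aten_decompositions_post_autograd can be inspected directly; a hedged sketch (aten.mse_loss_backward is one of the entries kept above):

import torch
from torch._decomp import core_aten_decompositions

# Builds the OpOverload -> Python decomposition mapping that this list feeds.
table = core_aten_decompositions()
print(torch.ops.aten.mse_loss_backward.default in table)  # expected: True
print(len(table))  # several hundred entries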

View File

@@ -1757,58 +1757,6 @@ def native_layer_norm_backward_out(
     return grad_input
-@register_decomposition(aten._fused_rms_norm.default)
-def _fused_rms_norm(
-    input: Tensor,
-    normalized_shape: list[int],
-    weight: Optional[Tensor],
-    eps: Optional[float],
-) -> tuple[Tensor, Tensor]:
-    dims_to_reduce: list[int] = []
-    for i in range(len(normalized_shape)):
-        dims_to_reduce.append(input.dim() - i - 1)
-
-    # upcast is needed for fp16 and bf16
-    computation_dtype = utils.get_computation_dtype(input.dtype)
-    upcasted_input = input.to(computation_dtype)
-
-    # computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
-    if eps is None:
-        if computation_dtype in (torch.float32, torch.complex64):
-            eps_val = sys.float_info.epsilon
-        else:
-            eps_val = sys.float_info.epsilon
-    else:
-        eps_val = eps
-
-    rqrst_input = torch.rsqrt(
-        # NB: don't inplace here, will violate functional IR invariant
-        torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
-    )
-    upcasted_result = upcasted_input.mul(rqrst_input)
-
-    if weight is not None:
-        upcasted_result = upcasted_result.mul(weight)
-
-    # NB: nested should be dead here, just here for fidelity
-    is_nested = input.is_nested or (weight is not None and weight.is_nested)
-    memory_format = utils.suggest_memory_format(input)
-    is_channels_last = memory_format in (
-        torch.channels_last,
-        torch.channels_last_3d,
-    )
-
-    if not is_nested and not is_channels_last:
-        upcasted_result = upcasted_result.contiguous()
-        rqrst_input = rqrst_input.contiguous()
-
-    # Cast normalized result back to original input type
-    result = upcasted_result.type_as(input)
-
-    return result, rqrst_input
-
 @register_decomposition(aten._fused_rms_norm_backward.default)
 def _fused_rms_norm_backward(
     grad_out: Tensor,
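Aside: the removed decomposition implements the standard RMSNorm formula, out = input * rsqrt(mean(input**2 over the normalized dims) + eps), optionally scaled by weight, with the reduction upcast to float32 for half-precision inputs. A hedged, simplified sketch of that math checked against torch.nn.functional.rms_norm (available in recent releases); rms_norm_ref and the fixed eps are illustrative assumptions, and the removed code's eps=None handling and memory-format logic are not reproduced:

import torch

def rms_norm_ref(x, normalized_shape, weight=None, eps=1e-6):
    # Reduce over the trailing `normalized_shape` dimensions.
    dims = [x.dim() - i - 1 for i in range(len(normalized_shape))]
    # Upcast half-precision inputs for the reduction, as the removed decomp does.
    compute_dtype = torch.float32 if x.dtype in (torch.float16, torch.bfloat16) else x.dtype
    xf = x.to(compute_dtype)
    rstd = torch.rsqrt(xf.pow(2).mean(dim=dims, keepdim=True) + eps)
    out = xf * rstd
    if weight is not None:
        out = out * weight
    return out.type_as(x)

x = torch.randn(2, 8)
w = torch.randn(8)
ref = rms_norm_ref(x, [8], w)
fused = torch.nn.functional.rms_norm(x, [8], w, eps=1e-6)
print(torch.allclose(ref, fused, atol=1e-6, rtol=1e-5))  # expected: True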

View File

@@ -15,7 +15,6 @@ from torch._subclasses.meta_utils import is_sparse_any
 from torch.utils._python_dispatch import (
     _detect_infra_mode,
     _disable_infra_mode,
-    autograd_would_have_decomposed,
     return_and_correct_aliasing,
     TorchDispatchMode,
 )
@@ -410,13 +409,8 @@ class FunctionalTensorMode(TorchDispatchMode):
 return False
 return True
-# in normal torch.compile IR, we only decompose an op if autograd
-# would have decomposed it (NB: autograd may have been skipped if
-# we are in inference mode)
-# TODO: the flatten here can potentially be deduped with the
-# unwrapping pytree_map later
-flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
-return autograd_would_have_decomposed(func, flat_args_kwargs)
+# in normal torch.compile IR, we decompose functional composite ops
+return True
 if (
 func not in FunctionalTensor.metadata_fns

View File

@@ -466,14 +466,6 @@ at::Tensor LazyNativeFunctions::linalg_pinv(
       linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
 }
-std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::svd(
-    const at::Tensor& self,
-    bool some,
-    bool compute_uv) {
-  return at::functionalization::functionalize_aten_op<ATEN_OP(svd)>::call(
-      self, some, compute_uv);
-}
-
 // functionalize_aten_op can't handle out= ops directly.
 // Instead, we can call the composite kernel from core, and copy and mutations
 // back to the inputs.

View File

@@ -21,10 +21,6 @@ backends are ready, this list allows opt-in one at a time.
 PRESERVED_ATEN_CIA_OPS = {
     torch.ops.aten.upsample_bilinear2d.vec,
     torch.ops.aten.upsample_nearest2d.vec,
-    # NB: don't use the C++ decomp, because it is not functional!
-    torch.ops.aten.silu_backward.default,
-    torch.ops.aten.mish_backward.default,
-    torch.ops.aten._fused_rms_norm.default,
 }

View File

@@ -63,7 +63,6 @@ from torch.utils._python_dispatch import (
     _disable_infra_mode,
     _push_mode,
     _unset_infra_mode,
-    autograd_would_have_decomposed,
     TorchDispatchMode,
 )
 from torch.utils._stats import count
@@ -909,16 +908,11 @@ def proxy_call(
         return r
     # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
-    if (
-        not pre_dispatch
-        and func
-        not in [
-            torch.ops.aten.size.default,
-            torch.ops.aten.stride.default,
-            torch.ops.aten.storage_offset.default,
-        ]
-        and autograd_would_have_decomposed(func, flat_args_kwargs)
-    ):
+    if not pre_dispatch and func not in [
+        torch.ops.aten.size.default,
+        torch.ops.aten.stride.default,
+        torch.ops.aten.storage_offset.default,
+    ]:
         with proxy_mode:
             r = func.decompose(*args, **kwargs)
             if r is not NotImplemented:
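Aside: with this change, proxy tracing decomposes CompositeImplicit ops unconditionally (apart from the size/stride/storage_offset exceptions above) by calling func.decompose. A hedged sketch of the observable effect using make_fx; aten::matmul is used as the example composite op (here autograd would have decomposed it too, so the graph shows the base op either way):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # aten::matmul only has a CompositeImplicitAutograd kernel.
    return torch.matmul(x, x)

gm = make_fx(f, tracing_mode="fake")(torch.randn(4, 4))
# The traced graph contains the base op aten.mm rather than aten.matmul.
print(gm.graph)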

View File

@@ -1,12 +1,11 @@
 # mypy: allow-untyped-defs
-from __future__ import annotations
 import contextlib
 import functools
 import warnings
 from collections import deque
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
+from typing import Optional, overload, Protocol, Union
 from typing_extensions import TypeIs
 import torch
@@ -21,10 +20,6 @@ from torch._C import (
 )
-if TYPE_CHECKING:
-    from collections.abc import Sequence
 # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
 # - We need a better user-facing api for _DisableTorchDispatch that
 #   is able to selectively disable __torch_dispatch__ of a particular class.
@@ -419,7 +414,7 @@ class TensorWithFlatten(Protocol):
     @overload
     def to(
         self,
-        device: Optional[torch._prims_common.DeviceLikeType] = None,
+        device: Optional["torch._prims_common.DeviceLikeType"] = None,
         dtype: Optional[torch.types._dtype] = None,
         non_blocking: bool = False,
         copy: bool = False,
@@ -687,61 +682,6 @@ def get_alias_info(func) -> SchemaInfo:
     return schema_info
-def autograd_would_have_decomposed(
-    func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
-) -> bool:
-    """
-    Suppose that an operator has CompositeImplicitAutograd decomp registered.
-    Would autograd have used this decomposition? It will only use it if there
-    isn't an explicit backend registration for the device as well. This function
-    will tell if this would have occurred.
-
-    Why do we need to apply these decompositions later? When inference mode is
-    on, the autograd key is bypassed entirely, so a lower level mode cannot rely
-    on the decomposition have been applied. It's easy to accidentally never apply
-    the decomposition, resulting in an operator showing up in a graph that
-    is unexpected.
-
-    Why do we need to AVOID applying the decomposition when autograd wouldn't
-    have decomposed? If autograd doesn't decompose, this means in eager mode
-    we would have run the fused kernel. It must be possible to trace this
-    fused kernel directly into the graph for fidelity with eager (NB: a user
-    has the option of then further decomposing at proxy tensor mode via
-    decomposition table, but we must preserve it to proxy mode to have the
-    choice.)
-
-    Why does functionalization need to also perform the test here? This is
-    because some CompositeImplicitAutograd decompositions are not functional.
-    If we are eventually going to decompose, we need to do this while we can
-    still turn functionalization back on, so those decompositions get functionalized.
-    So an early decomposition in functionalization may still be necessary. Note that
-    if proxy tensor decomposition process could turn functionalization back on, this
-    wouldn't be necessary, and maybe that is a useful thing to do anyway because
-    the decomposition table is user specified and a user could violate the functional
-    decomp requirement with a bad decomp. If this happened, then you could always
-    pass through functionalization.
-    """
-    has_backend_registration = False
-    for a in flat_args:
-        if isinstance(a, torch.Tensor):
-            backend_key = torch._C._parse_dispatch_key(
-                torch._C._dispatch_key_for_device(a.device.type)
-            )
-            assert backend_key is not None
-            # TODO: use func.has_kernel_for_dispatch_key(backend_key)
-            # but this one checks py_impl and CompositeImplicitAutograd
-            # incorrectly shows up as has backend reg here
-            has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
-                func.name(), backend_key
-            )
-            # in theory we should take all backend keys and take the highest priority one
-            # to properly mimic the dispatcher,
-            # this just grabs the first tensor and takes its device key
-            break
-    return not has_backend_registration
-
 # See NOTE[SchemaInfo int_tags] above.
 _TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view)  # type: ignore[call-overload]
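Aside: the removed helper's core check can still be reproduced with the same internal bindings it called; a hedged sketch (aten::svd on a CPU tensor is an arbitrary example):

import torch

x = torch.randn(3, 3)
op = torch.ops.aten.svd.default
# Map the tensor's device to its backend dispatch key, then ask whether the op
# has an explicit kernel registered for that key.
backend_key = torch._C._parse_dispatch_key(
    torch._C._dispatch_key_for_device(x.device.type)
)
has_backend_kernel = torch._C._dispatch_has_kernel_for_dispatch_key(
    op.name(), backend_key
)
# The removed helper returned the negation: no backend kernel means autograd
# would have fallen through to the CompositeImplicitAutograd decomposition.
print(not has_backend_kernel)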

View File

@@ -1024,22 +1024,8 @@ def gen_functionalization_registration(
 ) -> list[str]:
     @with_native_function
     def emit_registration_helper(f: NativeFunction) -> str:
-        if f.has_composite_implicit_autograd_kernel:
-            metadata = composite_implicit_autograd_index.get_kernel(f)
-            assert metadata is not None
-            native_api_name = metadata.kernel
-            sig = NativeSignature(f.func, symint=metadata.supports_symint())
-            # Note [Composite view ops in the functionalization pass]
-            # We don't need to worry about implemententing functionalization kernels for views with
-            # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
-            # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
-            # because we don't want to decompose non-view ops that are composite, like `at::ones`.
-            registration_str = (
-                f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
-            )
-        else:
-            # non-composite view ops (and inplace ops) get a normal registration.
-            registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
+        assert not f.has_composite_implicit_autograd_kernel
+        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
         return f'm.impl("{f.func.name}", {registration_str});'
 # Don't generate kernels in mobile build
@@ -1052,8 +1038,12 @@
         if str(g.view.func.name) == "lift_fresh":
             return []
         view_str = []
-        view_str.append(emit_registration_helper(g.view))
-        if g.view_inplace is not None:
+        if not g.view.has_composite_implicit_autograd_kernel:
+            view_str.append(emit_registration_helper(g.view))
+        if (
+            g.view_inplace is not None
+            and not g.view_inplace.has_composite_implicit_autograd_kernel
+        ):
             assert g.view_inplace.is_view_op
             view_str.append(emit_registration_helper(g.view_inplace))
         return view_str