Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit a6fa4f9c283971c0fb6f60a89674a1f35370ac79. Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/izaitsevfb because it introduces numeric issues internally, see [D84326613](https://www.internalfb.com/diff/D84326613) ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3392203314))
@@ -202,7 +202,6 @@ supported:
  - select_backward
  - _trilinear
  - linalg_pinv.atol_rtol_tensor
  - svd
  - logsumexp.out
symint:
  - empty.memory_format
@@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
    // where we would like to support composite implicit kernels but not
    // explicit kernels therefore we manually add the key to the
    // math_dispatch_keyset
    DispatchKeySet{DispatchKey::NestedTensor};
    DispatchKeySet{DispatchKey::NestedTensor} |
    // Functionalize should always reuse CompositeImplicit decomps.
    DispatchKeySet{DispatchKey::Functionalize};

constexpr DispatchKeySet nested_dispatch_keyset =
    DispatchKeySet(
@@ -7207,6 +7207,7 @@ metadata incorrectly.
        aot_eager = torch.compile(backend="aot_eager")(fn)(x)
        self.assertEqual(eager, aot_eager, atol=0, rtol=0)

    @unittest.expectedFailure
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_rms_norm(self):
        # Only CUDA rms norm fails to be decomposed
@@ -85,7 +85,6 @@ def init_lists():
        "linalg_inv_ex",
        "linalg_pinv.atol_rtol_tensor",
        "logsumexp",
        "svd",
    }
    # For some ops, we don't support all variants. Here we use formatted_name
    # to uniquely identify the variant.
@@ -221,15 +220,20 @@ class TestLazyOpInfo(TestCase):
        torch._lazy.wait_device_ops()
        prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
        symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
        metrics = remove_suffixes(torch._lazy.metrics.counter_names())
        cands = [f"{prefix}::{op.name}{symint_suffix}"]
        # check aliases
        for alias in op.aliases:
            cands.append(f"{prefix}::{alias.name}{symint_suffix}")

        self.assertTrue(
            any(c in metrics for c in cands), f"none of {cands} found in {metrics}"
        )
        found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(
            torch._lazy.metrics.counter_names()
        )
        # check aliases
        if not found:
            for alias in op.aliases:
                alias_found = (
                    f"{prefix}::{alias.name}{symint_suffix}"
                    in remove_suffixes(torch._lazy.metrics.counter_names())
                )
                found = found or alias_found
                if found:
                    break
        self.assertTrue(found)

    @ops(
        [
@@ -1255,10 +1255,11 @@ class DecompOneOffTests(TestCase):
        )

        # check RMSNorm was fused with sinh
        self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0])
        self.assertTrue(
            "triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul"
            in generated_codes[1]
            "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
        )
        self.assertTrue(
            "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
        )
@@ -404,7 +404,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
            aten.max_unpool3d,
            aten.mish,
            aten.mish_,
            aten.mish_backward,
            aten.mse_loss,
            aten.mse_loss_backward,
            aten.multi_margin_loss,
@@ -420,7 +419,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
            aten.native_dropout_backward,
            aten.native_group_norm_backward,
            aten.native_layer_norm_backward,
            aten._fused_rms_norm,
            aten._fused_rms_norm_backward,
            aten.new_empty,
            aten.new_full,
@@ -477,7 +475,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
            aten.silu,
            aten.silu_,
            aten.silu_backward.grad_input,
            aten.silu_backward,
            aten.sinc,
            aten.sinc_,
            aten.slice_backward,
@@ -1757,58 +1757,6 @@ def native_layer_norm_backward_out(
    return grad_input


@register_decomposition(aten._fused_rms_norm.default)
def _fused_rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Optional[Tensor],
    eps: Optional[float],
) -> tuple[Tensor, Tensor]:
    dims_to_reduce: list[int] = []
    for i in range(len(normalized_shape)):
        dims_to_reduce.append(input.dim() - i - 1)

    # upcast is needed for fp16 and bf16
    computation_dtype = utils.get_computation_dtype(input.dtype)
    upcasted_input = input.to(computation_dtype)

    # computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
    if eps is None:
        if computation_dtype in (torch.float32, torch.complex64):
            eps_val = sys.float_info.epsilon
        else:
            eps_val = sys.float_info.epsilon
    else:
        eps_val = eps

    rqrst_input = torch.rsqrt(
        # NB: don't inplace here, will violate functional IR invariant
        torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True).add(eps_val)
    )

    upcasted_result = upcasted_input.mul(rqrst_input)

    if weight is not None:
        upcasted_result = upcasted_result.mul(weight)

    # NB: nested should be dead here, just here for fidelity
    is_nested = input.is_nested or (weight is not None and weight.is_nested)
    memory_format = utils.suggest_memory_format(input)
    is_channels_last = memory_format in (
        torch.channels_last,
        torch.channels_last_3d,
    )

    if not is_nested and not is_channels_last:
        upcasted_result = upcasted_result.contiguous()
        rqrst_input = rqrst_input.contiguous()

    # Cast normalized result back to original input type
    result = upcasted_result.type_as(input)

    return result, rqrst_input


@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
    grad_out: Tensor,
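Aside (not part of the diff): a minimal sketch of how a decomposition like the one above can be sanity-checked against the fused eager kernel, assuming a PyTorch build that exposes torch.nn.functional.rms_norm; the helper name rms_norm_decomposed and the float32 upcast/eps defaults are illustrative assumptions, not the exact eager behavior.

import torch
import torch.nn.functional as F

# Illustrative re-implementation mirroring the decomposition above:
# upcast, mean of squares over the normalized dims, rsqrt, scale, cast back.
def rms_norm_decomposed(x, normalized_shape, weight=None, eps=None):
    dims = [x.dim() - i - 1 for i in range(len(normalized_shape))]
    upcast = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
    if eps is None:
        eps = torch.finfo(upcast.dtype).eps  # assumed default; eager may differ
    rstd = torch.rsqrt(upcast.pow(2).mean(dim=dims, keepdim=True).add(eps))
    out = upcast * rstd
    if weight is not None:
        out = out * weight
    return out.type_as(x)

x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(8, dtype=torch.bfloat16)
fused = F.rms_norm(x, [8], weight=w)
decomposed = rms_norm_decomposed(x, [8], weight=w)
# The gap is usually tiny but need not be exactly zero, which is one way that
# decomposing an op eager would have run fused can shift numerics.
print((fused.float() - decomposed.float()).abs().max())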
@@ -15,7 +15,6 @@ from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
    _detect_infra_mode,
    _disable_infra_mode,
    autograd_would_have_decomposed,
    return_and_correct_aliasing,
    TorchDispatchMode,
)
@@ -410,13 +409,8 @@ class FunctionalTensorMode(TorchDispatchMode):
                    return False
                return True

            # in normal torch.compile IR, we only decompose an op if autograd
            # would have decomposed it (NB: autograd may have been skipped if
            # we are in inference mode)
            # TODO: the flatten here can potentially be deduped with the
            # unwrapping pytree_map later
            flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs))
            return autograd_would_have_decomposed(func, flat_args_kwargs)
            # in normal torch.compile IR, we decompose functional composite ops
            return True

        if (
            func not in FunctionalTensor.metadata_fns
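Aside (not part of the diff): one way to observe this gate from the outside is to trace a small function and look at which aten ops survive in the captured graph. A minimal sketch using make_fx; treating aten.expand_as as a CompositeImplicitAutograd op with no backend kernel is an assumption, and the exact graph contents depend on the tracing configuration.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    # expand_as is assumed to be CompositeImplicitAutograd; whether the graph
    # records aten.expand_as itself or its decomposition (e.g. aten.expand)
    # reflects whether the composite decomposition ran during tracing.
    return x.expand_as(y)

gm = make_fx(f, tracing_mode="fake")(torch.randn(1, 4), torch.randn(3, 4))
print(gm.graph)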
@@ -466,14 +466,6 @@ at::Tensor LazyNativeFunctions::linalg_pinv(
      linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::svd(
    const at::Tensor& self,
    bool some,
    bool compute_uv) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(svd)>::call(
      self, some, compute_uv);
}

// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy any mutations
// back to the inputs.
@@ -21,10 +21,6 @@ backends are ready, this list allows opt-in one at a time.
PRESERVED_ATEN_CIA_OPS = {
    torch.ops.aten.upsample_bilinear2d.vec,
    torch.ops.aten.upsample_nearest2d.vec,
    # NB: don't use the C++ decomp, because it is not functional!
    torch.ops.aten.silu_backward.default,
    torch.ops.aten.mish_backward.default,
    torch.ops.aten._fused_rms_norm.default,
}
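Aside (not part of the diff): a hedged way to check whether an overload has a CompositeImplicitAutograd kernel at all, i.e. whether putting it on a preserve list like the one above is meaningful. The dispatcher query below mirrors what these internals rely on, but it is private API and shown only as an illustration.

import torch

def has_cia_kernel(op: torch._ops.OpOverload) -> bool:
    # True if a CompositeImplicitAutograd kernel is registered for this overload.
    return torch._C._dispatch_has_kernel_for_dispatch_key(
        op.name(), torch._C.DispatchKey.CompositeImplicitAutograd
    )

print(has_cia_kernel(torch.ops.aten.upsample_bilinear2d.vec))  # expected: True
print(has_cia_kernel(torch.ops.aten.add.Tensor))               # expected: False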
@@ -63,7 +63,6 @@ from torch.utils._python_dispatch import (
    _disable_infra_mode,
    _push_mode,
    _unset_infra_mode,
    autograd_would_have_decomposed,
    TorchDispatchMode,
)
from torch.utils._stats import count
@@ -909,16 +908,11 @@ def proxy_call(
                return r

    # For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
    if (
        not pre_dispatch
        and func
        not in [
            torch.ops.aten.size.default,
            torch.ops.aten.stride.default,
            torch.ops.aten.storage_offset.default,
        ]
        and autograd_would_have_decomposed(func, flat_args_kwargs)
    ):
    if not pre_dispatch and func not in [
        torch.ops.aten.size.default,
        torch.ops.aten.stride.default,
        torch.ops.aten.storage_offset.default,
    ]:
        with proxy_mode:
            r = func.decompose(*args, **kwargs)
            if r is not NotImplemented:
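Aside (not part of the diff): func.decompose is a regular OpOverload method, so the fallback path above can be exercised directly. A minimal sketch; which overloads return NotImplemented depends on their current kernel registrations, and the two ops below are only assumed to be representative.

import torch

x = torch.randn(1, 4)
y = torch.randn(3, 4)

# expand_as is assumed to carry a CompositeImplicitAutograd kernel, so
# decompose() runs that kernel and returns a concrete tensor.
r = torch.ops.aten.expand_as.default.decompose(x, y)
print(isinstance(r, torch.Tensor), getattr(r, "shape", None))

# add.Tensor has explicit backend kernels and no CompositeImplicitAutograd
# kernel (assumption), so decompose() reports there is nothing to decompose.
r = torch.ops.aten.add.Tensor.decompose(x, x)
print(r is NotImplemented)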
@@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import functools
import warnings
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
from typing import Optional, overload, Protocol, Union
from typing_extensions import TypeIs

import torch
@@ -21,10 +20,6 @@ from torch._C import (
)


if TYPE_CHECKING:
    from collections.abc import Sequence


# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
@@ -419,7 +414,7 @@ class TensorWithFlatten(Protocol):
    @overload
    def to(
        self,
        device: Optional[torch._prims_common.DeviceLikeType] = None,
        device: Optional["torch._prims_common.DeviceLikeType"] = None,
        dtype: Optional[torch.types._dtype] = None,
        non_blocking: bool = False,
        copy: bool = False,
@@ -687,61 +682,6 @@ def get_alias_info(func) -> SchemaInfo:
    return schema_info


def autograd_would_have_decomposed(
    func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
) -> bool:
    """
    Suppose that an operator has a CompositeImplicitAutograd decomp registered.
    Would autograd have used this decomposition? It will only use it if there
    isn't an explicit backend registration for the device as well. This function
    tells you whether that would have occurred.

    Why do we need to apply these decompositions later? When inference mode is
    on, the autograd key is bypassed entirely, so a lower-level mode cannot rely
    on the decomposition having been applied. It's easy to accidentally never apply
    the decomposition, resulting in an operator showing up in a graph where it
    is unexpected.

    Why do we need to AVOID applying the decomposition when autograd wouldn't
    have decomposed? If autograd doesn't decompose, this means that in eager mode
    we would have run the fused kernel. It must be possible to trace this
    fused kernel directly into the graph for fidelity with eager. (NB: a user
    has the option of then further decomposing at proxy tensor mode via the
    decomposition table, but we must preserve the op up to proxy mode to have that
    choice.)

    Why does functionalization need to also perform the test here? This is
    because some CompositeImplicitAutograd decompositions are not functional.
    If we are eventually going to decompose, we need to do so while we can
    still turn functionalization back on, so those decompositions get functionalized.
    So an early decomposition in functionalization may still be necessary. Note that
    if the proxy tensor decomposition process could turn functionalization back on, this
    wouldn't be necessary, and maybe that is a useful thing to do anyway, because
    the decomposition table is user specified and a user could violate the functional
    decomp requirement with a bad decomp. If that happened, you could always
    pass through functionalization again.
    """
    has_backend_registration = False
    for a in flat_args:
        if isinstance(a, torch.Tensor):
            backend_key = torch._C._parse_dispatch_key(
                torch._C._dispatch_key_for_device(a.device.type)
            )
            assert backend_key is not None
            # TODO: use func.has_kernel_for_dispatch_key(backend_key),
            # but that one also checks py_impl, and CompositeImplicitAutograd
            # incorrectly shows up as having a backend registration here
            has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
                func.name(), backend_key
            )

            # In theory we should take all backend keys and pick the highest-priority
            # one to properly mimic the dispatcher;
            # this just grabs the first tensor and takes its device key.
            break
    return not has_backend_registration


# See NOTE[SchemaInfo int_tags] above.
_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view)  # type: ignore[call-overload]
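Aside (not part of the diff): the condition the helper above evaluates can be probed directly with the same private dispatcher calls it uses. A minimal sketch; the claim that aten.expand_as has no CPU kernel while aten.add.Tensor does is an assumption about current registrations.

import torch

def has_explicit_backend_kernel(op: torch._ops.OpOverload, device_type: str) -> bool:
    # Mirrors the check in autograd_would_have_decomposed: map the device to its
    # backend dispatch key, then ask whether an explicit kernel is registered.
    key = torch._C._parse_dispatch_key(
        torch._C._dispatch_key_for_device(device_type)
    )
    assert key is not None
    return torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), key)

# Has a CPU kernel, so autograd would run the explicit kernel, not the decomp.
print(has_explicit_backend_kernel(torch.ops.aten.add.Tensor, "cpu"))
# Assumed CompositeImplicitAutograd only, so autograd would have decomposed it.
print(has_explicit_backend_kernel(torch.ops.aten.expand_as.default, "cpu"))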
@@ -1024,22 +1024,8 @@ def gen_functionalization_registration(
) -> list[str]:
    @with_native_function
    def emit_registration_helper(f: NativeFunction) -> str:
        if f.has_composite_implicit_autograd_kernel:
            metadata = composite_implicit_autograd_index.get_kernel(f)
            assert metadata is not None
            native_api_name = metadata.kernel
            sig = NativeSignature(f.func, symint=metadata.supports_symint())
            # Note [Composite view ops in the functionalization pass]
            # We don't need to worry about implementing functionalization kernels for views with
            # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
            # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
            # because we don't want to decompose non-view ops that are composite, like `at::ones`.
            registration_str = (
                f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
            )
        else:
            # non-composite view ops (and inplace ops) get a normal registration.
            registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
        assert not f.has_composite_implicit_autograd_kernel
        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
        return f'm.impl("{f.func.name}", {registration_str});'

    # Don't generate kernels in mobile build
@@ -1052,8 +1038,12 @@ def gen_functionalization_registration(
    if str(g.view.func.name) == "lift_fresh":
        return []
    view_str = []
    view_str.append(emit_registration_helper(g.view))
    if g.view_inplace is not None:
    if not g.view.has_composite_implicit_autograd_kernel:
        view_str.append(emit_registration_helper(g.view))
    if (
        g.view_inplace is not None
        and not g.view_inplace.has_composite_implicit_autograd_kernel
    ):
        assert g.view_inplace.is_view_op
        view_str.append(emit_registration_helper(g.view_inplace))
    return view_str