Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)

This fixes AOTAutograd rms_norm not being bitwise equivalent to
eager, because it avoids a decomposition.  You can force the
decomposition by having the decomposition in the dispatch table,
but if eager mode wouldn't have decomposed (because it went to the fused
one), we now default to preserving the fused call by default.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
This commit is contained in:
Edward Z. Yang
2025-10-09 13:07:34 -07:00
committed by PyTorch MergeBot
parent 344e6365a0
commit a6fa4f9c28
13 changed files with 181 additions and 39 deletions

View File

@ -1,11 +1,12 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
import functools
import warnings
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, overload, Protocol, Union
from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
from typing_extensions import TypeIs
import torch
@ -20,6 +21,10 @@ from torch._C import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
@ -414,7 +419,7 @@ class TensorWithFlatten(Protocol):
@overload
def to(
self,
device: Optional["torch._prims_common.DeviceLikeType"] = None,
device: Optional[torch._prims_common.DeviceLikeType] = None,
dtype: Optional[torch.types._dtype] = None,
non_blocking: bool = False,
copy: bool = False,
@ -682,6 +687,61 @@ def get_alias_info(func) -> SchemaInfo:
return schema_info
def autograd_would_have_decomposed(
func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
) -> bool:
"""
Suppose that an operator has CompositeImplicitAutograd decomp registered.
Would autograd have used this decomposition? It will only use it if there
isn't an explicit backend registration for the device as well. This function
will tell if this would have occurred.
Why do we need to apply these decompositions later? When inference mode is
on, the autograd key is bypassed entirely, so a lower level mode cannot rely
on the decomposition have been applied. It's easy to accidentally never apply
the decomposition, resulting in an operator showing up in a graph that
is unexpected.
Why do we need to AVOID applying the decomposition when autograd wouldn't
have decomposed? If autograd doesn't decompose, this means in eager mode
we would have run the fused kernel. It must be possible to trace this
fused kernel directly into the graph for fidelity with eager (NB: a user
has the option of then further decomposing at proxy tensor mode via
decomposition table, but we must preserve it to proxy mode to have the
choice.)
Why does functionalization need to also perform the test here? This is
because some CompositeImplicitAutograd decompositions are not functional.
If we are eventually going to decompose, we need to do this while we can
still turn functionalization back on, so those decompositions get functionalized.
So an early decomposition in functionalization may still be necessary. Note that
if proxy tensor decomposition process could turn functionalization back on, this
wouldn't be necessary, and maybe that is a useful thing to do anyway because
the decomposition table is user specified and a user could violate the functional
decomp requirement with a bad decomp. If this happened, then you could always
pass through functionalization.
"""
has_backend_registration = False
for a in flat_args:
if isinstance(a, torch.Tensor):
backend_key = torch._C._parse_dispatch_key(
torch._C._dispatch_key_for_device(a.device.type)
)
assert backend_key is not None
# TODO: use func.has_kernel_for_dispatch_key(backend_key)
# but this one checks py_impl and CompositeImplicitAutograd
# incorrectly shows up as has backend reg here
has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), backend_key
)
# in theory we should take all backend keys and take the highest priority one
# to properly mimic the dispatcher,
# this just grabs the first tensor and takes its device key
break
return not has_backend_registration
# See NOTE[SchemaInfo int_tags] above.
_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload]