Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can still force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it dispatched to the fused kernel), we now default to preserving the fused call.

This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel.

Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939
Approved by: https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 344e6365a0
Commit: a6fa4f9c28
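To illustrate the bitwise-equivalence claim in the commit message, here is a minimal sketch (not part of the commit) that compares an eager rms_norm call against a torch.compile'd one. It assumes a PyTorch build where torch.nn.functional.rms_norm and torch.compile are available; the exact-match assertion reflects the behavior this change targets rather than a blanket guarantee for every backend.

import torch
import torch.nn.functional as F

def rms(x, w):
    # Normalize over the last dimension; eps is arbitrary for the demo.
    return F.rms_norm(x, (x.shape[-1],), weight=w, eps=1e-6)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 8, device=device)
w = torch.randn(8, device=device)

eager_out = rms(x, w)
compiled_out = torch.compile(rms)(x, w)

# rtol=0/atol=0 demands bitwise-identical results, which is what preserving
# the fused rms_norm call (instead of decomposing it) is meant to deliver.
torch.testing.assert_close(compiled_out, eager_out, rtol=0, atol=0)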
@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
+from __future__ import annotations
+
 import contextlib
 import functools
 import warnings
 from collections import deque
-from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Optional, overload, Protocol, Union
+from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
 from typing_extensions import TypeIs

 import torch
@@ -20,6 +21,10 @@ from torch._C import (
 )


+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
 # - We need a better user-facing api for _DisableTorchDispatch that
 #   is able to selectively disable __torch_dispatch__ of a particular class.
@@ -414,7 +419,7 @@ class TensorWithFlatten(Protocol):
     @overload
     def to(
         self,
-        device: Optional["torch._prims_common.DeviceLikeType"] = None,
+        device: Optional[torch._prims_common.DeviceLikeType] = None,
         dtype: Optional[torch.types._dtype] = None,
         non_blocking: bool = False,
         copy: bool = False,
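The un-quoting in the hunk above is enabled by the `from __future__ import annotations` import added in the first hunk: under PEP 563, annotations are stored as strings rather than evaluated at definition time, so forward references such as torch._prims_common.DeviceLikeType no longer need explicit quotes. A tiny stand-alone illustration (the Widget/Gadget names are made up for the example):

from __future__ import annotations  # PEP 563: annotations are not evaluated at definition time

from typing import Optional


class Widget:
    # Gadget is defined later in the module, yet no string quotes are needed here.
    def partner(self, other: Optional[Gadget] = None) -> Gadget:
        return other if other is not None else Gadget()


class Gadget:
    pass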
@@ -682,6 +687,61 @@ def get_alias_info(func) -> SchemaInfo:
     return schema_info


+def autograd_would_have_decomposed(
+    func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
+) -> bool:
+    """
+    Suppose that an operator has a CompositeImplicitAutograd decomp registered.
+    Would autograd have used this decomposition? It will only use it if there
+    isn't an explicit backend registration for the device as well. This function
+    tells you whether that would have occurred.
+
+    Why do we need to apply these decompositions later? When inference mode is
+    on, the autograd key is bypassed entirely, so a lower level mode cannot rely
+    on the decomposition having been applied. It's easy to accidentally never apply
+    the decomposition, resulting in an operator showing up in a graph where it
+    is unexpected.
+
+    Why do we need to AVOID applying the decomposition when autograd wouldn't
+    have decomposed? If autograd doesn't decompose, it means that in eager mode
+    we would have run the fused kernel. It must be possible to trace this
+    fused kernel directly into the graph for fidelity with eager. (NB: a user
+    has the option of further decomposing at proxy tensor mode via the
+    decomposition table, but we must preserve the call up to proxy mode so they
+    have the choice.)
+
+    Why does functionalization need to also perform this test? Because some
+    CompositeImplicitAutograd decompositions are not functional. If we are
+    eventually going to decompose, we need to do so while we can still turn
+    functionalization back on, so those decompositions get functionalized.
+    So an early decomposition in functionalization may still be necessary. Note
+    that if the proxy tensor decomposition process could turn functionalization
+    back on, this wouldn't be necessary, and maybe that is a useful thing to do
+    anyway, because the decomposition table is user specified and a user could
+    violate the functional decomp requirement with a bad decomp. If that
+    happened, you could always pass through functionalization again.
+    """
+    has_backend_registration = False
+    for a in flat_args:
+        if isinstance(a, torch.Tensor):
+            backend_key = torch._C._parse_dispatch_key(
+                torch._C._dispatch_key_for_device(a.device.type)
+            )
+            assert backend_key is not None
+            # TODO: use func.has_kernel_for_dispatch_key(backend_key), but that
+            # one also checks py_impl, and CompositeImplicitAutograd then
+            # incorrectly shows up as a backend registration here
+            has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
+                func.name(), backend_key
+            )
+
+            # In theory we should collect all backend keys and take the highest
+            # priority one to properly mimic the dispatcher; this just grabs the
+            # first tensor and uses its device key.
+            break
+    return not has_backend_registration
+
+
 # See NOTE[SchemaInfo int_tags] above.
 _TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view)  # type: ignore[call-overload]
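As a rough sketch of what the new helper checks (not taken from the commit; the op and device below are only examples), the same dispatcher probes used in the function above can be exercised directly to see whether a backend provides its own kernel for an op, which is exactly the case where the fused call is now preserved:

import torch

def would_decompose(op: torch._ops.OpOverload, t: torch.Tensor) -> bool:
    # Mirror of the check above: is there an explicit backend kernel registered
    # for the dispatch key of this tensor's device?
    key = torch._C._parse_dispatch_key(
        torch._C._dispatch_key_for_device(t.device.type)
    )
    assert key is not None
    has_backend_kernel = torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), key)
    # No backend kernel means autograd would have fallen through to the
    # CompositeImplicitAutograd decomposition.
    return not has_backend_kernel

x = torch.randn(2, 8)
# True on backends with no fused rms_norm kernel (the decomposition applies);
# False where a fused kernel is registered, which this PR now keeps in the graph.
print(would_decompose(torch.ops.aten.rms_norm.default, x))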