Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit a6fa4f9c283971c0fb6f60a89674a1f35370ac79. Reverted https://github.com/pytorch/pytorch/pull/164939 on behalf of https://github.com/izaitsevfb because it introduces numeric issues internally; see [D84326613](https://www.internalfb.com/diff/D84326613) ([comment](https://github.com/pytorch/pytorch/pull/164939#issuecomment-3392203314))
torch/utils/_python_dispatch.py
@@ -1,12 +1,11 @@
 # mypy: allow-untyped-defs
-from __future__ import annotations
-
 import contextlib
 import functools
 import warnings
 from collections import deque
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Optional, overload, Protocol, TYPE_CHECKING, Union
+from typing import Optional, overload, Protocol, Union
 from typing_extensions import TypeIs
 
 import torch
@@ -21,10 +20,6 @@ from torch._C import (
 )
 
 
-if TYPE_CHECKING:
-    from collections.abc import Sequence
-
-
 # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
 # - We need a better user-facing api for _DisableTorchDispatch that
 # is able to selectively disable __torch_dispatch__ of a particular class.
@@ -419,7 +414,7 @@ class TensorWithFlatten(Protocol):
     @overload
     def to(
         self,
-        device: Optional[torch._prims_common.DeviceLikeType] = None,
+        device: Optional["torch._prims_common.DeviceLikeType"] = None,
         dtype: Optional[torch.types._dtype] = None,
         non_blocking: bool = False,
         copy: bool = False,
@@ -687,61 +682,6 @@ def get_alias_info(func) -> SchemaInfo:
     return schema_info
 
 
-def autograd_would_have_decomposed(
-    func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
-) -> bool:
-    """
-    Suppose that an operator has a CompositeImplicitAutograd decomp registered.
-    Would autograd have used this decomposition? It will only use it if there
-    isn't also an explicit backend registration for the device. This function
-    tells you whether that would have happened.
-
-    Why do we need to apply these decompositions later? When inference mode is
-    on, the autograd key is bypassed entirely, so a lower level mode cannot rely
-    on the decomposition having been applied. It's easy to accidentally never
-    apply the decomposition, resulting in an operator unexpectedly showing up
-    in a graph.
-
-    Why do we need to AVOID applying the decomposition when autograd wouldn't
-    have decomposed? If autograd doesn't decompose, eager mode would have run
-    the fused kernel, so it must be possible to trace that fused kernel
-    directly into the graph for fidelity with eager. (NB: a user still has
-    the option of further decomposing in proxy tensor mode via the
-    decomposition table, but we must preserve the operator up to proxy mode
-    for that choice to exist.)
-
-    Why does functionalization need to also perform this test? Because some
-    CompositeImplicitAutograd decompositions are not functional. If we are
-    eventually going to decompose, we need to do so while we can still turn
-    functionalization back on, so those decompositions get functionalized;
-    an early decomposition in functionalization may therefore be necessary.
-    Note that if the proxy tensor decomposition process could turn
-    functionalization back on, this wouldn't be necessary (and that might be
-    useful anyway: the decomposition table is user specified, and a user could
-    violate the functional-decomp requirement with a bad decomp; if that
-    happened, you could always pass back through functionalization).
-    """
-    has_backend_registration = False
-    for a in flat_args:
-        if isinstance(a, torch.Tensor):
-            backend_key = torch._C._parse_dispatch_key(
-                torch._C._dispatch_key_for_device(a.device.type)
-            )
-            assert backend_key is not None
-            # TODO: use func.has_kernel_for_dispatch_key(backend_key) instead;
-            # that variant also checks py_impl, so a CompositeImplicitAutograd
-            # registration incorrectly shows up as a backend registration here
-            has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key(
-                func.name(), backend_key
-            )
-
-            # In theory we should collect the backend keys of all tensor args and
-            # take the highest priority one to properly mimic the dispatcher;
-            # this just grabs the first tensor and uses its device's key.
-            break
-    return not has_backend_registration
-
-
 # See NOTE[SchemaInfo int_tags] above.
 _TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view)  # type: ignore[call-overload]
 
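A note on what the removed helper actually checks: it is a dispatcher-introspection probe, built entirely from the private `torch._C` helpers visible in the diff above. A minimal standalone sketch of the same check (`probe_would_decompose` is a hypothetical name; the `torch._C` functions are internal and may change between releases):

```python
import torch

def probe_would_decompose(op: torch._ops.OpOverload, device_type: str) -> bool:
    # Same probe the removed helper performs: an op with a
    # CompositeImplicitAutograd decomposition only decomposes under autograd
    # if the device's backend has no explicit kernel of its own.
    backend_key = torch._C._parse_dispatch_key(
        torch._C._dispatch_key_for_device(device_type)
    )
    assert backend_key is not None
    return not torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), backend_key)

# aten::linear has a CompositeImplicitAutograd decomposition (to matmul/addmm),
# so on a stock CPU build this is expected to print True.
print(probe_would_decompose(torch.ops.aten.linear.default, "cpu"))
```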
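The removed docstring's claim that "inference mode bypasses the autograd key entirely" is easy to observe with public API alone; because the autograd kernels never run, any decomposition that would normally happen at that key is skipped too:

```python
import torch

x = torch.ones(2, requires_grad=True)

y = x * 2                 # normal dispatch: the Autograd key runs and records the op
print(y.requires_grad)    # True

with torch.inference_mode():
    z = x * 2             # autograd dispatch keys are excluded entirely
print(z.requires_grad)    # False
print(z.is_inference())   # True: no autograd metadata was ever created
```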
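On the docstring's point that a user can still decompose further in proxy tensor mode "via decomposition table": that refers to tracing-time decompositions of the kind below. A sketch using the experimental `make_fx` API; `aten.silu` is just a convenient example of an op with a registered decomposition:

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.silu(x)

# Without a table the traced graph keeps aten.silu as a single node; with
# one, proxy tensor mode decomposes it (to x * sigmoid(x)) while tracing.
decomp_table = get_decompositions([torch.ops.aten.silu])
gm = make_fx(f, decomposition_table=decomp_table)(torch.randn(3))
print(gm.graph)
```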
|