Mirror of https://github.com/pytorch/pytorch.git
Commit: Update

[ghstack-poisoned]
@@ -562,7 +562,7 @@ _autograd_backward_strict_mode_conditional_banned_ops = [
 
 # Enables caching of dispatches to fake tensors.
 fake_tensor_cache_enabled = (
-    os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1"
+    os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "0") == "1"
 )
 
 # Enables cross checking between the fake tensor cache and dispatch.
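Note: the hunk above flips the default for the fake-tensor dispatch cache from on to off, and the flag is computed once from the environment when the config module is imported. A minimal sketch of opting back in, assuming the variable name shown above and assuming the flag still lives in torch._dynamo.config (not confirmed by this diff):

import os

# The flag is read from os.environ at import time, so the override must be set
# before torch (and its config modules) are imported.
os.environ["TORCH_FAKE_TENSOR_DISPATCH_CACHE"] = "1"

import torch  # noqa: E402  (imported after the environment override on purpose)

# Assumed location of the flag; adjust if the config module differs in your build.
print(torch._dynamo.config.fake_tensor_cache_enabled)  # expected: True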
@@ -1332,8 +1332,7 @@ def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
 
 def check_same_dtype(*args):
     """
-    Checks that all Tensors in args have the same device and that all Numbers have the
-    same corresponding Python type.
+    Checks that all Tensors in args have the same dtype.
 
     Raises a RuntimeError when:
       - args contains an object whose type is not Tensor or Number
@@ -1346,51 +1345,25 @@ def check_same_dtype(*args):
     scalar_type = None
 
     for arg in args:
-        if isinstance(arg, Number):
-            # Scalar type checking is disabled (and may be removed in the future)
-            continue
-            # if scalar_type is None:
-            #     scalar_type = type(arg)
-
-            # if scalar_type is not type(arg):
-            #     msg = (
-            #         "Scalar of type "
-            #         + str(type(arg))
-            #         + " is not the expected type of "
-            #         + str(scalar_type)
-            #         + "!"
-            #     )
-            #     raise RuntimeError(msg)
-        elif isinstance(arg, TensorLike):
+        if isinstance(arg, TensorLike):
             if full_dtype is None:
                 full_dtype = arg.dtype
             if scalar_type is None:
                 scalar_type = dtype_to_type(arg.dtype)
 
             if full_dtype is not arg.dtype:
-                msg = (
-                    "Tensor with dtype "
-                    + str(arg.dtype)
-                    + " is not the expected dtype of "
-                    + str(full_dtype)
-                    + "!"
-                )
+                msg = f"Tensor with dtype {arg.dtype} is not the expected dtype of {full_dtype}!"
                 raise RuntimeError(msg)
 
             arg_type = dtype_to_type(arg.dtype)
             if arg_type is not scalar_type:
-                msg = (
-                    "Tensor with corresponding Python type "
-                    + str(arg_type)
-                    + " is not the expected type of "
-                    + str(scalar_type)
-                    + "!"
-                )
+                msg = f"Tensor with corresponding Python type {arg_type} is not the expected type of {scalar_type}!"
                 raise RuntimeError(msg)
+        elif isinstance(arg, Number):
+            # Can't check a non-tensor dtype.
+            continue
         else:
-            msg = (
-                "Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
-            )
+            msg = f"Unexpected type when checking for same dtype, {type(arg)}!"
             raise RuntimeError(msg)
 
 
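Note: the rewrite above keeps the checking logic but builds the error messages with f-strings and moves the Number branch after the TensorLike branch. A standalone sketch of the same pattern, with simplified names (this is not the _prims_common implementation; TensorLike and dtype_to_type are replaced with plain torch.Tensor checks):

from numbers import Number

import torch


def check_same_dtype_sketch(*args):
    # The first tensor fixes the expected dtype; later tensors must match,
    # bare numbers are skipped, and anything else is rejected.
    expected = None
    for arg in args:
        if isinstance(arg, torch.Tensor):
            if expected is None:
                expected = arg.dtype
            elif arg.dtype is not expected:
                msg = f"Tensor with dtype {arg.dtype} is not the expected dtype of {expected}!"
                raise RuntimeError(msg)
        elif isinstance(arg, Number):
            # Can't check a non-tensor dtype.
            continue
        else:
            msg = f"Unexpected type when checking for same dtype, {type(arg)}!"
            raise RuntimeError(msg)


check_same_dtype_sketch(torch.ones(2), torch.zeros(3), 1.5)  # passes silently
# check_same_dtype_sketch(torch.ones(2), torch.zeros(3, dtype=torch.int64))  # raises RuntimeError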
@@ -15,12 +15,14 @@ import typing
 import weakref
 from collections import defaultdict
 from dataclasses import dataclass
+from functools import cached_property
 from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import Self, TypeGuard
 from weakref import ReferenceType
 
 import torch
 import torch._library.utils as library_utils
+import torch.utils._cxx_pytree as pytree
 from torch import SymBool, SymFloat, SymInt, Tensor
 from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
 from torch._library.fake_class_registry import FakeScriptObject
@@ -41,12 +43,19 @@ from torch.multiprocessing.reductions import StorageWeakRef
 from torch.overrides import TorchFunctionMode
 from torch.types import IntLikeType, py_sym_types
 from torch.utils._backport_slots import dataclass_slots
+from torch.utils._cxx_pytree import (
+    KeyPath,
+    keystr,
+    PyTree,
+    tree_map,
+    tree_map_,
+    TreeSpec,
+)
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
     TorchDispatchMode,
 )
-from torch.utils._pytree import KeyPath, keystr, PyTree, tree_map, tree_map_, TreeSpec
 from torch.utils._stats import count
 from torch.utils._traceback import CapturedTraceback
 
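Note: the two import hunks above switch the fake-tensor module from torch.utils._pytree to the optree-backed torch.utils._cxx_pytree, which exposes the same helper names. A small usage sketch; the try/except is an assumption about availability, since _cxx_pytree needs the optional optree dependency:

import torch

try:
    import torch.utils._cxx_pytree as pytree  # C++/optree-backed implementation
except ImportError:
    import torch.utils._pytree as pytree  # pure-Python module with the same surface

tree = {"a": torch.ones(2), "b": [torch.zeros(3), 4]}

leaves = pytree.tree_leaves(tree)                 # the two tensors and the int 4
doubled = pytree.tree_map(lambda x: x * 2, tree)  # same nesting, every leaf doubled
spec = pytree.tree_structure(tree)                # a TreeSpec describing the container shape
print(len(leaves), spec)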
@@ -77,7 +86,6 @@ except ValueError as e:
 
 DimList = list
 
-pytree = torch.utils._pytree
 T = TypeVar("T")
 
 aten = torch._ops.ops.aten
@@ -147,6 +155,17 @@ def ordered_set(*items: T) -> dict[T, Literal[True]]:
     return dict.fromkeys(items, True)
 
 
+# list of ops which can have args(tensor/tensorList) in mixed device
+_MIXED_DEVICE_FNS = ordered_set(
+    aten._foreach_copy.default,
+)
+
+# list of ops not using zero dim cpu tensor logic to align with the eager mode.
+_BYPASS_ZERO_DIM_CPU_TENSOR_CHECK_FNS = ordered_set(
+    aten.nextafter.default,
+)
+
+
 @contextlib.contextmanager
 def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, None]:
     old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
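Note: ordered_set, shown in the hunk context above, is just dict.fromkeys, so the new module-level constants give constant-time membership tests such as "func in _MIXED_DEVICE_FNS" while keeping insertion order. A pure-Python illustration with string stand-ins for the aten overloads (the real constants hold OpOverload objects):

from typing import Literal, TypeVar

T = TypeVar("T")


def ordered_set(*items: T) -> dict[T, Literal[True]]:
    # A dict preserves insertion order and supports O(1) "in" checks.
    return dict.fromkeys(items, True)


# Hypothetical string keys standing in for aten._foreach_copy.default etc.
_MIXED_DEVICE_FNS = ordered_set("aten::_foreach_copy")

print("aten::_foreach_copy" in _MIXED_DEVICE_FNS)  # True
print("aten::add" in _MIXED_DEVICE_FNS)            # False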
@@ -884,16 +903,6 @@ class FakeTensor(Tensor):
         has_scalar_only_inputs = False
         is_cpu_zero_dim = None
 
-        # list of ops which can have args(tensor/tensorList) in mixed device
-        mixed_device_fns = ordered_set(
-            aten._foreach_copy.default,
-        )
-
-        # list of ops not using zero dim cpu tensor logic to align with the eager mode.
-        bypass_zero_dim_cpu_tensor_check_ops = ordered_set(
-            aten.nextafter.default,
-        )
-
         def check_cpu_device(device: torch.device) -> bool:
             return device.type == "cpu"
 
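Note: the hunk above deletes the per-call ordered_set construction inside FakeTensor's device-merging helper; together with the earlier hunk that adds _MIXED_DEVICE_FNS and _BYPASS_ZERO_DIM_CPU_TENSOR_CHECK_FNS at module scope, the tables are now built once at import time instead of on every call. The performance motivation is an inference from the diff, not stated in it; a tiny self-contained comparison of the two patterns:

import timeit


def per_call_lookup(op: str) -> bool:
    # Old shape of the code: rebuild the lookup table on every invocation.
    table = dict.fromkeys(("aten::_foreach_copy", "aten::nextafter"), True)
    return op in table


_TABLE = dict.fromkeys(("aten::_foreach_copy", "aten::nextafter"), True)


def module_level_lookup(op: str) -> bool:
    # New shape of the code: the table is a module-level constant.
    return op in _TABLE


print(timeit.timeit(lambda: per_call_lookup("aten::add"), number=100_000))
print(timeit.timeit(lambda: module_level_lookup("aten::add"), number=100_000))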
@@ -918,7 +927,7 @@
                 return
 
             is_bypass_zero_dim_cpu_tensor_check_op = (
-                func in bypass_zero_dim_cpu_tensor_check_ops
+                func in _BYPASS_ZERO_DIM_CPU_TENSOR_CHECK_FNS
             )
 
             # mismatching devices !
@@ -936,7 +945,7 @@
             # on different devices for ex. _foreach_copy, and one of the
             # device must be cpu in this case we will return from here without
             # throwing an error
-            if func in mixed_device_fns:
+            if func in _MIXED_DEVICE_FNS:
                 if any(map(check_cpu_device, (common_device, t.device))):
                     return
 
@@ -1343,7 +1352,7 @@ class FakeTensorMode(TorchDispatchMode):
     # - We change the torch.tensor ctor contract to never materialize
     #   tensors on device
     #   (see NOTE: [torch.tensor, lift_fresh, and device movement])
-    @property
+    @cached_property
     def avoid_device_init(self) -> bool:
         if torch.xpu._is_compiled():
             assert not torch.cuda._is_compiled()
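Note: swapping @property for functools.cached_property means avoid_device_init is computed once per FakeTensorMode instance and then served from the instance dict, which is safe only as long as the inputs to that decision cannot change afterwards. A generic sketch of the behavioral difference (Probe is a made-up class, not FakeTensorMode):

from functools import cached_property


class Probe:
    def __init__(self) -> None:
        self.calls = 0

    @property
    def as_property(self) -> int:
        self.calls += 1  # runs on every access
        return self.calls

    @cached_property
    def as_cached(self) -> int:
        self.calls += 1  # runs only on the first access
        return self.calls


p = Probe()
print(p.as_property, p.as_property)  # 1 2  -> the body ran twice
print(p.as_cached, p.as_cached)      # 3 3  -> the body ran once and was cached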
@@ -1380,8 +1389,6 @@
 
     # No-op if FakeTensorMode is already in use
     def __enter__(self) -> Self:
-        import torch.nested._internal.nested_tensor
-
         prev_only_lift_cpu_tensors = None
         if self.avoid_device_init:
             # See NOTE: [torch.tensor, lift_fresh, and device movement]
@@ -446,6 +446,11 @@ def tree_leaves(
     )
 
 
+def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]:
+    """Get a flat list of arguments to this function."""
+    return tree_leaves((args, kwargs))
+
+
 def tree_structure(
     tree: PyTree,
     is_leaf: Optional[Callable[[PyTree], bool]] = None,
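Note: the new arg_tree_leaves flattens positional and keyword arguments in a single call by treating (args, kwargs) as one pytree. A usage sketch via torch.utils._pytree, which already ships the same helper in current sources (an assumption worth checking against your version):

import torch
import torch.utils._pytree as pytree

leaves = pytree.arg_tree_leaves(
    torch.ones(2),
    [torch.zeros(1), 3],
    scale=2.0,
    opts={"flag": True},
)
# Positional leaves come first, then the keyword leaves:
# the two tensors, the int 3, the float 2.0, and True.
print(leaves)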