[ghstack-poisoned]
This commit is contained in:
Benjamin Glass
2025-08-15 19:00:29 +00:00
parent 44b8ff2055
commit 7969c080f6
4 changed files with 38 additions and 53 deletions

View File

@ -562,7 +562,7 @@ _autograd_backward_strict_mode_conditional_banned_ops = [
# Enables caching of dispatches to fake tensors.
fake_tensor_cache_enabled = (
os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1"
os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "0") == "1"
)
# Enables cross checking between the fake tensor cache and dispatch.

View File

@ -1332,8 +1332,7 @@ def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
def check_same_dtype(*args):
"""
Checks that all Tensors in args have the same device and that all Numbers have the
same corresponding Python type.
Checks that all Tensors in args have the same dtype.
Raises a RuntimeError when:
- args contains an object whose type is not Tensor or Number
@ -1346,51 +1345,25 @@ def check_same_dtype(*args):
scalar_type = None
for arg in args:
if isinstance(arg, Number):
# Scalar type checking is disabled (and may be removed in the future)
continue
# if scalar_type is None:
# scalar_type = type(arg)
# if scalar_type is not type(arg):
# msg = (
# "Scalar of type "
# + str(type(arg))
# + " is not the expected type of "
# + str(scalar_type)
# + "!"
# )
# raise RuntimeError(msg)
elif isinstance(arg, TensorLike):
if isinstance(arg, TensorLike):
if full_dtype is None:
full_dtype = arg.dtype
if scalar_type is None:
scalar_type = dtype_to_type(arg.dtype)
if full_dtype is not arg.dtype:
msg = (
"Tensor with dtype "
+ str(arg.dtype)
+ " is not the expected dtype of "
+ str(full_dtype)
+ "!"
)
msg = f"Tensor with dtype {arg.dtype} is not the expected dtype of {full_dtype}!"
raise RuntimeError(msg)
arg_type = dtype_to_type(arg.dtype)
if arg_type is not scalar_type:
msg = (
"Tensor with corresponding Python type "
+ str(arg_type)
+ " is not the expected type of "
+ str(scalar_type)
+ "!"
)
msg = f"Tensor with corresponding Python type {arg_type} is not the expected type of {scalar_type}!"
raise RuntimeError(msg)
elif isinstance(arg, Number):
# Can't check a non-tensor dtype.
continue
else:
msg = (
"Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
)
msg = f"Unexpected type when checking for same dtype, {type(arg)}!"
raise RuntimeError(msg)

View File

@ -15,12 +15,14 @@ import typing
import weakref
from collections import defaultdict
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Self, TypeGuard
from weakref import ReferenceType
import torch
import torch._library.utils as library_utils
import torch.utils._cxx_pytree as pytree
from torch import SymBool, SymFloat, SymInt, Tensor
from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
from torch._library.fake_class_registry import FakeScriptObject
@ -41,12 +43,19 @@ from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
from torch.types import IntLikeType, py_sym_types
from torch.utils._backport_slots import dataclass_slots
from torch.utils._cxx_pytree import (
KeyPath,
keystr,
PyTree,
tree_map,
tree_map_,
TreeSpec,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
TorchDispatchMode,
)
from torch.utils._pytree import KeyPath, keystr, PyTree, tree_map, tree_map_, TreeSpec
from torch.utils._stats import count
from torch.utils._traceback import CapturedTraceback
@ -77,7 +86,6 @@ except ValueError as e:
DimList = list
pytree = torch.utils._pytree
T = TypeVar("T")
aten = torch._ops.ops.aten
@ -147,6 +155,17 @@ def ordered_set(*items: T) -> dict[T, Literal[True]]:
return dict.fromkeys(items, True)
# list of ops which can have args(tensor/tensorList) in mixed device
_MIXED_DEVICE_FNS = ordered_set(
aten._foreach_copy.default,
)
# list of ops not using zero dim cpu tensor logic to align with the eager mode.
_BYPASS_ZERO_DIM_CPU_TENSOR_CHECK_FNS = ordered_set(
aten.nextafter.default,
)
@contextlib.contextmanager
def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, None]:
old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
@ -884,16 +903,6 @@ class FakeTensor(Tensor):
has_scalar_only_inputs = False
is_cpu_zero_dim = None
# list of ops which can have args(tensor/tensorList) in mixed device
mixed_device_fns = ordered_set(
aten._foreach_copy.default,
)
# list of ops not using zero dim cpu tensor logic to align with the eager mode.
bypass_zero_dim_cpu_tensor_check_ops = ordered_set(
aten.nextafter.default,
)
def check_cpu_device(device: torch.device) -> bool:
return device.type == "cpu"
@ -918,7 +927,7 @@ class FakeTensor(Tensor):
return
is_bypass_zero_dim_cpu_tensor_check_op = (
func in bypass_zero_dim_cpu_tensor_check_ops
func in _BYPASS_ZERO_DIM_CPU_TENSOR_CHECK_FNS
)
# mismatching devices !
@ -936,7 +945,7 @@ class FakeTensor(Tensor):
# on different devices for ex. _foreach_copy, and one of the
# device must be cpu in this case we will return from here without
# throwing an error
if func in mixed_device_fns:
if func in _MIXED_DEVICE_FNS:
if any(map(check_cpu_device, (common_device, t.device))):
return
@ -1343,7 +1352,7 @@ class FakeTensorMode(TorchDispatchMode):
# - We change the torch.tensor ctor contract to never materialize
# tensors on device
# (see NOTE: [torch.tensor, lift_fresh, and device movement])
@property
@cached_property
def avoid_device_init(self) -> bool:
if torch.xpu._is_compiled():
assert not torch.cuda._is_compiled()
@ -1380,8 +1389,6 @@ class FakeTensorMode(TorchDispatchMode):
# No-op if FakeTensorMode is already in use
def __enter__(self) -> Self:
import torch.nested._internal.nested_tensor
prev_only_lift_cpu_tensors = None
if self.avoid_device_init:
# See NOTE: [torch.tensor, lift_fresh, and device movement]

View File

@ -446,6 +446,11 @@ def tree_leaves(
)
def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]:
"""Get a flat list of arguments to this function."""
return tree_leaves((args, kwargs))
def tree_structure(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,