diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index fee53c03601d..a894f16b757a 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1951,6 +1951,60 @@ class FakeTensorDispatchCache(TestCase):
             extract_tensor_metadata(res4),
         )
 
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_wrapper_tensor_subclass_different_device(self):
+        class DifferentDeviceTensor(torch.Tensor):
+            @staticmethod
+            def __new__(cls, a):
+                kwargs = {}
+                kwargs["strides"] = a.stride()
+                kwargs["storage_offset"] = a.storage_offset()
+                kwargs["device"] = torch.device("cpu")
+                kwargs["layout"] = a.layout
+                kwargs["requires_grad"] = a.requires_grad
+                kwargs["dtype"] = a.dtype
+                out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs)
+                return out
+
+            def __init__(self, a):
+                self.inner_tensor = a
+
+            def __repr__(self):
+                return f"DifferentDeviceTensor({repr(self.inner_tensor)})"
+
+            def __tensor_flatten__(self):
+                return ["inner_tensor"], None
+
+            @staticmethod
+            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+                assert meta is None
+                return DifferentDeviceTensor(inner_tensors["inner_tensor"])
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs):
+                if kwargs is None:
+                    kwargs = {}
+                args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args)
+                kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs)
+                # Returns unwrapped tensor
+                return func(*args, **kwargs)
+
+        a = torch.ones(2, 2, 768, device="cuda")
+        wrapped_a = DifferentDeviceTensor(a)
+
+        # Outer Tensor is on cpu, inner is on cuda
+        self.assertTrue(wrapped_a.is_cpu)
+        self.assertFalse(wrapped_a.inner_tensor.is_cpu)
+
+        with FakeTensorMode() as fake_mode:
+            fake_wrapped_a = fake_mode.from_tensor(wrapped_a)
+
+        self.assertTrue(fake_wrapped_a.is_cpu)
+        assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
+        self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)
+
     def test_cache_tuple_outputs(self):
         """
         Test to check that ops with tuple outputs work.
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 090a18ad9f55..1f6ad75cecdc 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -354,14 +354,20 @@ class FakeTensorConverter:
         maybe_memo = self._get_memo(t)
         if maybe_memo is not None:
             return maybe_memo
-        existing_device = t.device
         # not yet supported in metatensors
         if t.is_quantized:
             raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
         if type(t) is torch.nn.Parameter:
             assert not make_constant
 
-        def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
+        constant = t if make_constant else None
+
+        # This callback is used by both subclass and inner tensors. Require the
+        # caller to explicitly specify the device in case outer and inner tensors
+        # have different devices.
+        def mk_fake_tensor(
+            make_meta_t: Callable[[], object], device: torch.device
+        ) -> FakeTensor:
             # NB: don't use in_kernel_invocation_manager. to
             # ensure FakeTensor can internally do constant computation
             # as necessary.  Invocation manager is "more correct" as
@@ -373,16 +379,16 @@ class FakeTensorConverter:
                 return FakeTensor(
                     fake_mode,
                     make_meta_t(),
-                    existing_device,
+                    device,
                     # TODO: callback might be used in recursive contexts, in
                     # which case using t is wrong!  BUG!
-                    constant=t if make_constant else None,
+                    constant=constant,
                 )
 
         out = self.meta_converter(
             t,
             shape_env=shape_env,
-            callback=mk_fake_tensor,
+            callback=mk_fake_tensor,  # type: ignore[arg-type]
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 6815f3e5ef9a..cea2067db0b8 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import contextlib
 import dataclasses
+import functools
 import typing
 import warnings
 import weakref
@@ -709,7 +710,7 @@ class MetaConverter(Generic[_TensorT]):
     def meta_storage(
         self,
         s: MetaStorageDesc,
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
     ) -> torch.UntypedStorage:
         # If we are fakeifying a tensor that has a secretly-zero-sized storage,
         # Need to make sure to resize the meta storage too.
@@ -734,7 +735,9 @@ class MetaConverter(Generic[_TensorT]):
         return typing.cast(_TensorT, t)
 
     @classmethod
-    def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
+    def _identity_callable(
+        cls, t: Callable[[], torch.Tensor], device: torch.device
+    ) -> _TensorT:
         return cls._checked_cast_tensor_t(t())
 
     @classmethod
@@ -756,10 +759,11 @@ class MetaConverter(Generic[_TensorT]):
         self,
         t: MetaTensorDesc,
         shape_env: Optional[ShapeEnv],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
         source: Optional[Source],
         symbolic_context: Optional[SymbolicContext],
     ) -> _TensorT:
+        callback = functools.partial(callback, device=t.device)  # type: ignore[call-arg]
         if source is None:
             from torch._dynamo.source import ConstantSource
@@ -905,7 +909,7 @@ class MetaConverter(Generic[_TensorT]):
             symbolic_context: Optional[
                 torch.fx.experimental.symbolic_shapes.SymbolicContext
             ],
-            callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+            callback: Callable[[Callable], _TensorT],
             source: torch._guards.Source,
         ) -> _TensorT:
             # We are hitting plain meta_desc tensor so actually
@@ -933,12 +937,15 @@ class MetaConverter(Generic[_TensorT]):
                 )
                 current_source = AttrSource(source, attr)
+                inner_callback = functools.partial(
+                    callback, device=meta_tensor_desc.device  # type: ignore[call-arg]
+                )
                 new_empty_tensor = _empty_create_subclass(
                     meta_tensor_desc,
                     meta_tensor_desc.size,
                     meta_tensor_desc.stride,
                     current_context,
-                    callback,
+                    inner_callback,
                     current_source,
                 )
                 inner_tensors[attr] = new_empty_tensor
@@ -975,7 +982,7 @@ class MetaConverter(Generic[_TensorT]):
         t: MetaTensorDesc,
         source: torch._guards.Source,
         shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
     ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
         from torch._dynamo.source import AttrSource
         from torch.fx.experimental.symbolic_shapes import (
@@ -1137,7 +1144,7 @@ class MetaConverter(Generic[_TensorT]):
             shape_env: Optional[
                 torch.fx.experimental.symbolic_shapes.ShapeEnv
             ] = shape_env,
-            callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback,  # type: ignore[assignment]
+            callback: Callable[[Callable], _TensorT] = callback,  # type: ignore[assignment]
         ) -> torch.Tensor:
             # It's possible to close over an undefined tensor (e.g. NJT's lengths).
             if visited_t is None:
@@ -1723,7 +1730,7 @@ class MetaConverter(Generic[_TensorT]):
         t: torch.Tensor,
         shape_env: Optional[ShapeEnv] = None,
         *,
-        callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
+        callback: Optional[Callable[[Callable], _TensorT]] = None,
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         # Controls whether or not we should dump the tensor metadata to structured logs
         # we don't want to dump info again from AOTAutograd, it is redundant.
         trace: bool = True,
     ) -> _TensorT:
-        callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
+        callback_: Callable[[Callable], _TensorT]
         if callback is None:
-            callback_ = self._identity_callable
+            callback_ = self._identity_callable  # type: ignore[assignment]
         else:
             callback_ = callback
         # TODO: zero tensors?  We appear to have eliminated them by
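
The core of the change: `mk_fake_tensor` no longer closes over a single `existing_device` captured from the outer tensor; instead `MetaConverter` binds, via `functools.partial`, the device of whichever tensor it is currently converting (the wrapper or one of its inner constituents). Below is a minimal, self-contained sketch of that pattern in plain Python. The names (`mk_fake`, `convert`, the dict-based tensor descriptor) are illustrative stand-ins, not PyTorch APIs.

```python
import functools
from typing import Callable


def mk_fake(make_meta: Callable[[], str], device: str) -> str:
    # Stand-in for the patched mk_fake_tensor: the device is now an explicit
    # parameter instead of a closed-over `existing_device`.
    return f"FakeTensor({make_meta()}, device={device})"


def convert(desc: dict, callback: Callable[..., str]) -> list[str]:
    # Mirrors `callback = functools.partial(callback, device=t.device)` in
    # meta_utils.py: bind the device of the tensor being converted right now.
    bound = functools.partial(callback, device=desc["device"])
    out = [bound(lambda: desc["name"])]
    for inner in desc.get("inner", []):
        # Each inner tensor re-binds the unbound callback with its own
        # (possibly different) device, like `inner_callback` above.
        out.extend(convert(inner, callback))
    return out


# Outer wrapper reports cpu while its inner tensor lives on cuda, as in
# the DifferentDeviceTensor test above.
wrapper = {
    "name": "outer",
    "device": "cpu",
    "inner": [{"name": "inner", "device": "cuda:0"}],
}
print(convert(wrapper, mk_fake))
# ['FakeTensor(outer, device=cpu)', 'FakeTensor(inner, device=cuda:0)']
```

Before this change, the single captured `existing_device` meant inner tensors were fakeified with the outer wrapper's device; the new test exercises exactly that cpu-outer/cuda-inner mismatch.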