From e41a0b33ec464100d35e707e118673eda472a90c Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 2 Dec 2024 11:46:44 -0800 Subject: [PATCH] Allow Fakified subclass to have different device for inner and outer tensor (#141839) Previously if a wrapper tensor subclass is fakified, the inner tensors would end up having the same device as the outer tensor. This PR makes it so that inner and outer tensors can have different devices. See OffloadTensor PR https://github.com/pytorch/pytorch/pull/141840/files#diff-3bc0cf540b694f4ec0a3749f78b047456657a53a5657e495ffb68e5970c5fdaaR1955 for an application. A simpler test has been added in this PR. This is technically bc-breaking because now the callback passed to MetaConverter needs to accept an extra argument, but no one external should be using this anyway? Pull Request resolved: https://github.com/pytorch/pytorch/pull/141839 Approved by: https://github.com/bdhirsh ghstack dependencies: #141166 --- test/test_fake_tensor.py | 54 ++++++++++++++++++++++++++++++++ torch/_subclasses/fake_tensor.py | 16 +++++++--- torch/_subclasses/meta_utils.py | 27 ++++++++++------ 3 files changed, 82 insertions(+), 15 deletions(-) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index fee53c03601d..a894f16b757a 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1951,6 +1951,60 @@ class FakeTensorDispatchCache(TestCase): extract_tensor_metadata(res4), ) + + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_wrapper_tensor_subclass_different_device(self): + class DifferentDeviceTensor(torch.Tensor): + @staticmethod + def __new__(cls, a): + kwargs = {} + kwargs["strides"] = a.stride() + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = torch.device("cpu") + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs) + return out + + def __init__(self, a): + self.inner_tensor = a + + def __repr__(self): + return f"DifferentDeviceTensor({repr(self.inner_tensor)})" + + def __tensor_flatten__(self): + return ["inner_tensor"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None + return DifferentDeviceTensor(inner_tensors["inner_tensor"]) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args) + kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs) + # Returns unwrapped tensor + return func(*args, **kwargs) + + a = torch.ones(2, 2, 768, device="cuda") + wrapped_a = DifferentDeviceTensor(a) + + # Outer Tensor is on cpu, inner is on cuda + self.assertTrue(wrapped_a.is_cpu) + self.assertFalse(wrapped_a.inner_tensor.is_cpu) + + with FakeTensorMode() as fake_mode: + fake_wrapped_a = fake_mode.from_tensor(wrapped_a) + + self.assertTrue(fake_wrapped_a.is_cpu) + assert isinstance(fake_wrapped_a, DifferentDeviceTensor) + self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu) + + def test_cache_tuple_outputs(self): """ Test to check that ops with tuple outputs work. diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 090a18ad9f55..1f6ad75cecdc 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -354,14 +354,20 @@ class FakeTensorConverter: maybe_memo = self._get_memo(t) if maybe_memo is not None: return maybe_memo - existing_device = t.device # not yet supported in metatensors if t.is_quantized: raise UnsupportedFakeTensorException("quantized nyi in meta tensors") if type(t) is torch.nn.Parameter: assert not make_constant - def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor: + constant = t if make_constant else None + + # This callback is used by both subclass and inner tensors. Require the + # caller to explicitly specify the device in case outer and inner tensors + # have different devices. + def mk_fake_tensor( + make_meta_t: Callable[[], object], device: torch.device + ) -> FakeTensor: # NB: don't use in_kernel_invocation_manager. to # ensure FakeTensor can internally do constant computation # as necessary. Invocation manager is "more correct" as @@ -373,16 +379,16 @@ class FakeTensorConverter: return FakeTensor( fake_mode, make_meta_t(), - existing_device, + device, # TODO: callback might be used in recursive contexts, in # which case using t is wrong! BUG! - constant=t if make_constant else None, + constant=constant, ) out = self.meta_converter( t, shape_env=shape_env, - callback=mk_fake_tensor, + callback=mk_fake_tensor, # type: ignore[arg-type] source=source, symbolic_context=symbolic_context, trace=trace, diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 6815f3e5ef9a..cea2067db0b8 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations import contextlib import dataclasses +import functools import typing import warnings import weakref @@ -709,7 +710,7 @@ class MetaConverter(Generic[_TensorT]): def meta_storage( self, s: MetaStorageDesc, - callback: Callable[[Callable[[], torch.Tensor]], _TensorT], + callback: Callable[[Callable], _TensorT], ) -> torch.UntypedStorage: # If we are fakeifying a tensor that has a secretly-zero-sized storage, # Need to make sure to resize the meta storage too. @@ -734,7 +735,9 @@ class MetaConverter(Generic[_TensorT]): return typing.cast(_TensorT, t) @classmethod - def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT: + def _identity_callable( + cls, t: Callable[[], torch.Tensor], device: torch.device + ) -> _TensorT: return cls._checked_cast_tensor_t(t()) @classmethod @@ -756,10 +759,11 @@ class MetaConverter(Generic[_TensorT]): self, t: MetaTensorDesc, shape_env: Optional[ShapeEnv], - callback: Callable[[Callable[[], torch.Tensor]], _TensorT], + callback: Callable[[Callable], _TensorT], source: Optional[Source], symbolic_context: Optional[SymbolicContext], ) -> _TensorT: + callback = functools.partial(callback, device=t.device) # type: ignore[call-arg] if source is None: from torch._dynamo.source import ConstantSource @@ -905,7 +909,7 @@ class MetaConverter(Generic[_TensorT]): symbolic_context: Optional[ torch.fx.experimental.symbolic_shapes.SymbolicContext ], - callback: Callable[[Callable[[], torch.Tensor]], _TensorT], + callback: Callable[[Callable], _TensorT], source: torch._guards.Source, ) -> _TensorT: # We are hitting plain meta_desc tensor so actually @@ -933,12 +937,15 @@ class MetaConverter(Generic[_TensorT]): ) current_source = AttrSource(source, attr) + inner_callback = functools.partial( + callback, device=meta_tensor_desc.device # type: ignore[call-arg] + ) new_empty_tensor = _empty_create_subclass( meta_tensor_desc, meta_tensor_desc.size, meta_tensor_desc.stride, current_context, - callback, + inner_callback, current_source, ) inner_tensors[attr] = new_empty_tensor @@ -975,7 +982,7 @@ class MetaConverter(Generic[_TensorT]): t: MetaTensorDesc, source: torch._guards.Source, shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv], - callback: Callable[[Callable[[], torch.Tensor]], _TensorT], + callback: Callable[[Callable], _TensorT], ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext: from torch._dynamo.source import AttrSource from torch.fx.experimental.symbolic_shapes import ( @@ -1137,7 +1144,7 @@ class MetaConverter(Generic[_TensorT]): shape_env: Optional[ torch.fx.experimental.symbolic_shapes.ShapeEnv ] = shape_env, - callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback, # type: ignore[assignment] + callback: Callable[[Callable], _TensorT] = callback, # type: ignore[assignment] ) -> torch.Tensor: # It's possible to close over an undefined tensor (e.g. NJT's lengths). if visited_t is None: @@ -1723,7 +1730,7 @@ class MetaConverter(Generic[_TensorT]): t: torch.Tensor, shape_env: Optional[ShapeEnv] = None, *, - callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None, + callback: Optional[Callable[[Callable], _TensorT]] = None, source: Optional[Source] = None, symbolic_context: Optional[SymbolicContext] = None, # Controls whether or not we should dump the tensor metadata to structured logs @@ -1731,9 +1738,9 @@ class MetaConverter(Generic[_TensorT]): # we don't want to dump info again from AOTAutograd, it is redundant. trace: bool = True, ) -> _TensorT: - callback_: Callable[[Callable[[], torch.Tensor]], _TensorT] + callback_: Callable[[Callable], _TensorT] if callback is None: - callback_ = self._identity_callable + callback_ = self._identity_callable # type: ignore[assignment] else: callback_ = callback # TODO: zero tensors? We appear to have eliminated them by