Allow Fakified subclass to have different device for inner and outer tensor (#141839)

Previously, if a wrapper tensor subclass was fakified, its inner tensors would end up with the same device as the outer tensor. This PR makes it possible for the inner and outer tensors to have different devices.
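As a rough sketch of the behavior this enables (not part of the PR itself, and relying on the DifferentDeviceTensor wrapper subclass defined in the test added below), fakifying a wrapper whose outer device is cpu and whose inner tensor lives on cuda now preserves both devices:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# DifferentDeviceTensor (defined in the new test below) reports device "cpu"
# for the outer wrapper while its inner tensor stays on CUDA.
a = torch.ones(2, 2, 768, device="cuda")
wrapped_a = DifferentDeviceTensor(a)

with FakeTensorMode() as fake_mode:
    fake_wrapped_a = fake_mode.from_tensor(wrapped_a)

# Previously the fake inner tensor inherited the outer (cpu) device;
# now each level keeps its own device.
assert fake_wrapped_a.is_cpu
assert not fake_wrapped_a.inner_tensor.is_cpu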

See the OffloadTensor PR https://github.com/pytorch/pytorch/pull/141840/files#diff-3bc0cf540b694f4ec0a3749f78b047456657a53a5657e495ffb68e5970c5fdaaR1955 for an application; a simpler test has been added in this PR.

This is technically BC-breaking because the callback passed to MetaConverter now needs to accept an extra argument, but no external code should be relying on this anyway.
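Concretely, the breaking change is the callback's shape: MetaConverter now passes the target device explicitly instead of FakeTensorConverter closing over the real tensor's device. A sketch of the old versus new signature, based on the fake_tensor.py hunk below (the _old/_new suffixes are for illustration only):

from typing import Callable

import torch
from torch._subclasses.fake_tensor import FakeTensor

# Old: the device came from a closure (existing_device = t.device), so every
# fake tensor built through the callback landed on the outer tensor's device.
def mk_fake_tensor_old(make_meta_t: Callable[[], object]) -> FakeTensor: ...

# New: the device is an explicit argument, so the same callback can build the
# outer wrapper on one device and its inner tensors on another.
def mk_fake_tensor_new(
    make_meta_t: Callable[[], object], device: torch.device
) -> FakeTensor: ...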
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141839
Approved by: https://github.com/bdhirsh
ghstack dependencies: #141166
soulitzer
2024-12-02 11:46:44 -08:00
committed by PyTorch MergeBot
parent 9830e7b1e4
commit e41a0b33ec
3 changed files with 82 additions and 15 deletions

@@ -1951,6 +1951,60 @@ class FakeTensorDispatchCache(TestCase):
             extract_tensor_metadata(res4),
         )
 
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_wrapper_tensor_subclass_different_device(self):
+        class DifferentDeviceTensor(torch.Tensor):
+            @staticmethod
+            def __new__(cls, a):
+                kwargs = {}
+                kwargs["strides"] = a.stride()
+                kwargs["storage_offset"] = a.storage_offset()
+                kwargs["device"] = torch.device("cpu")
+                kwargs["layout"] = a.layout
+                kwargs["requires_grad"] = a.requires_grad
+                kwargs["dtype"] = a.dtype
+                out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs)
+                return out
+
+            def __init__(self, a):
+                self.inner_tensor = a
+
+            def __repr__(self):
+                return f"DifferentDeviceTensor({repr(self.inner_tensor)})"
+
+            def __tensor_flatten__(self):
+                return ["inner_tensor"], None
+
+            @staticmethod
+            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+                assert meta is None
+                return DifferentDeviceTensor(inner_tensors["inner_tensor"])
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs):
+                if kwargs is None:
+                    kwargs = {}
+                args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args)
+                kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs)
+                # Returns unwrapped tensor
+                return func(*args, **kwargs)
+
+        a = torch.ones(2, 2, 768, device="cuda")
+        wrapped_a = DifferentDeviceTensor(a)
+
+        # Outer Tensor is on cpu, inner is on cuda
+        self.assertTrue(wrapped_a.is_cpu)
+        self.assertFalse(wrapped_a.inner_tensor.is_cpu)
+
+        with FakeTensorMode() as fake_mode:
+            fake_wrapped_a = fake_mode.from_tensor(wrapped_a)
+
+        self.assertTrue(fake_wrapped_a.is_cpu)
+        assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
+        self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)
+
     def test_cache_tuple_outputs(self):
         """
         Test to check that ops with tuple outputs work.

@@ -354,14 +354,20 @@ class FakeTensorConverter:
         maybe_memo = self._get_memo(t)
         if maybe_memo is not None:
             return maybe_memo
-        existing_device = t.device
         # not yet supported in metatensors
         if t.is_quantized:
             raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
         if type(t) is torch.nn.Parameter:
             assert not make_constant
 
-        def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
+        constant = t if make_constant else None
+
+        # This callback is used by both subclass and inner tensors. Require the
+        # caller to explicitly specify the device in case outer and inner tensors
+        # have different devices.
+        def mk_fake_tensor(
+            make_meta_t: Callable[[], object], device: torch.device
+        ) -> FakeTensor:
             # NB: don't use in_kernel_invocation_manager. to
             # ensure FakeTensor can internally do constant computation
             # as necessary. Invocation manager is "more correct" as
@@ -373,16 +379,16 @@ class FakeTensorConverter:
             return FakeTensor(
                 fake_mode,
                 make_meta_t(),
-                existing_device,
+                device,
                 # TODO: callback might be used in recursive contexts, in
                 # which case using t is wrong! BUG!
-                constant=t if make_constant else None,
+                constant=constant,
             )
 
         out = self.meta_converter(
             t,
             shape_env=shape_env,
-            callback=mk_fake_tensor,
+            callback=mk_fake_tensor,  # type: ignore[arg-type]
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import contextlib
 import dataclasses
+import functools
 import typing
 import warnings
 import weakref
@@ -709,7 +710,7 @@ class MetaConverter(Generic[_TensorT]):
     def meta_storage(
         self,
         s: MetaStorageDesc,
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
     ) -> torch.UntypedStorage:
         # If we are fakeifying a tensor that has a secretly-zero-sized storage,
         # Need to make sure to resize the meta storage too.
@@ -734,7 +735,9 @@ class MetaConverter(Generic[_TensorT]):
         return typing.cast(_TensorT, t)
 
     @classmethod
-    def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
+    def _identity_callable(
+        cls, t: Callable[[], torch.Tensor], device: torch.device
+    ) -> _TensorT:
         return cls._checked_cast_tensor_t(t())
 
     @classmethod
@@ -756,10 +759,11 @@ class MetaConverter(Generic[_TensorT]):
         self,
         t: MetaTensorDesc,
         shape_env: Optional[ShapeEnv],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
         source: Optional[Source],
         symbolic_context: Optional[SymbolicContext],
     ) -> _TensorT:
+        callback = functools.partial(callback, device=t.device)  # type: ignore[call-arg]
         if source is None:
             from torch._dynamo.source import ConstantSource
@@ -905,7 +909,7 @@ class MetaConverter(Generic[_TensorT]):
         symbolic_context: Optional[
             torch.fx.experimental.symbolic_shapes.SymbolicContext
         ],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
         source: torch._guards.Source,
     ) -> _TensorT:
         # We are hitting plain meta_desc tensor so actually
@@ -933,12 +937,15 @@ class MetaConverter(Generic[_TensorT]):
                 )
                 current_source = AttrSource(source, attr)
+                inner_callback = functools.partial(
+                    callback, device=meta_tensor_desc.device  # type: ignore[call-arg]
+                )
                 new_empty_tensor = _empty_create_subclass(
                     meta_tensor_desc,
                     meta_tensor_desc.size,
                     meta_tensor_desc.stride,
                     current_context,
-                    callback,
+                    inner_callback,
                     current_source,
                 )
                 inner_tensors[attr] = new_empty_tensor
@@ -975,7 +982,7 @@ class MetaConverter(Generic[_TensorT]):
         t: MetaTensorDesc,
         source: torch._guards.Source,
         shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
     ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
         from torch._dynamo.source import AttrSource
         from torch.fx.experimental.symbolic_shapes import (
@@ -1137,7 +1144,7 @@ class MetaConverter(Generic[_TensorT]):
             shape_env: Optional[
                 torch.fx.experimental.symbolic_shapes.ShapeEnv
             ] = shape_env,
-            callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback,  # type: ignore[assignment]
+            callback: Callable[[Callable], _TensorT] = callback,  # type: ignore[assignment]
         ) -> torch.Tensor:
             # It's possible to close over an undefined tensor (e.g. NJT's lengths).
             if visited_t is None:
@@ -1723,7 +1730,7 @@ class MetaConverter(Generic[_TensorT]):
         t: torch.Tensor,
         shape_env: Optional[ShapeEnv] = None,
         *,
-        callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
+        callback: Optional[Callable[[Callable], _TensorT]] = None,
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         # Controls whether or not we should dump the tensor metadata to structured logs
@@ -1731,9 +1738,9 @@ class MetaConverter(Generic[_TensorT]):
         # we don't want to dump info again from AOTAutograd, it is redundant.
         trace: bool = True,
     ) -> _TensorT:
-        callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
+        callback_: Callable[[Callable], _TensorT]
         if callback is None:
-            callback_ = self._identity_callable
+            callback_ = self._identity_callable  # type: ignore[assignment]
         else:
             callback_ = callback
         # TODO: zero tensors? We appear to have eliminated them by