Allow Fakified subclass to have different device for inner and outer tensor (#141839)

Previously if a wrapper tensor subclass is fakified, the inner tensors would end up having the same device as the outer tensor. This PR makes it so that inner and outer tensors can have different devices.

See OffloadTensor PR https://github.com/pytorch/pytorch/pull/141840/files#diff-3bc0cf540b694f4ec0a3749f78b047456657a53a5657e495ffb68e5970c5fdaaR1955 for an application. A simpler test has been added in this PR.

This is technically bc-breaking because now the callback passed to MetaConverter needs to accept an extra argument, but no one external should be using this anyway?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141839
Approved by: https://github.com/bdhirsh
ghstack dependencies: #141166
This commit is contained in:
soulitzer
2024-12-02 11:46:44 -08:00
committed by PyTorch MergeBot
parent 9830e7b1e4
commit e41a0b33ec
3 changed files with 82 additions and 15 deletions

View File

@ -1951,6 +1951,60 @@ class FakeTensorDispatchCache(TestCase):
extract_tensor_metadata(res4),
)
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_wrapper_tensor_subclass_different_device(self):
class DifferentDeviceTensor(torch.Tensor):
@staticmethod
def __new__(cls, a):
kwargs = {}
kwargs["strides"] = a.stride()
kwargs["storage_offset"] = a.storage_offset()
kwargs["device"] = torch.device("cpu")
kwargs["layout"] = a.layout
kwargs["requires_grad"] = a.requires_grad
kwargs["dtype"] = a.dtype
out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs)
return out
def __init__(self, a):
self.inner_tensor = a
def __repr__(self):
return f"DifferentDeviceTensor({repr(self.inner_tensor)})"
def __tensor_flatten__(self):
return ["inner_tensor"], None
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert meta is None
return DifferentDeviceTensor(inner_tensors["inner_tensor"])
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if kwargs is None:
kwargs = {}
args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args)
kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs)
# Returns unwrapped tensor
return func(*args, **kwargs)
a = torch.ones(2, 2, 768, device="cuda")
wrapped_a = DifferentDeviceTensor(a)
# Outer Tensor is on cpu, inner is on cuda
self.assertTrue(wrapped_a.is_cpu)
self.assertFalse(wrapped_a.inner_tensor.is_cpu)
with FakeTensorMode() as fake_mode:
fake_wrapped_a = fake_mode.from_tensor(wrapped_a)
self.assertTrue(fake_wrapped_a.is_cpu)
assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)
def test_cache_tuple_outputs(self):
"""
Test to check that ops with tuple outputs work.

View File

@ -354,14 +354,20 @@ class FakeTensorConverter:
maybe_memo = self._get_memo(t)
if maybe_memo is not None:
return maybe_memo
existing_device = t.device
# not yet supported in metatensors
if t.is_quantized:
raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
if type(t) is torch.nn.Parameter:
assert not make_constant
def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
constant = t if make_constant else None
# This callback is used by both subclass and inner tensors. Require the
# caller to explicitly specify the device in case outer and inner tensors
# have different devices.
def mk_fake_tensor(
make_meta_t: Callable[[], object], device: torch.device
) -> FakeTensor:
# NB: don't use in_kernel_invocation_manager. to
# ensure FakeTensor can internally do constant computation
# as necessary. Invocation manager is "more correct" as
@ -373,16 +379,16 @@ class FakeTensorConverter:
return FakeTensor(
fake_mode,
make_meta_t(),
existing_device,
device,
# TODO: callback might be used in recursive contexts, in
# which case using t is wrong! BUG!
constant=t if make_constant else None,
constant=constant,
)
out = self.meta_converter(
t,
shape_env=shape_env,
callback=mk_fake_tensor,
callback=mk_fake_tensor, # type: ignore[arg-type]
source=source,
symbolic_context=symbolic_context,
trace=trace,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import contextlib
import dataclasses
import functools
import typing
import warnings
import weakref
@ -709,7 +710,7 @@ class MetaConverter(Generic[_TensorT]):
def meta_storage(
self,
s: MetaStorageDesc,
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
callback: Callable[[Callable], _TensorT],
) -> torch.UntypedStorage:
# If we are fakeifying a tensor that has a secretly-zero-sized storage,
# Need to make sure to resize the meta storage too.
@ -734,7 +735,9 @@ class MetaConverter(Generic[_TensorT]):
return typing.cast(_TensorT, t)
@classmethod
def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
def _identity_callable(
cls, t: Callable[[], torch.Tensor], device: torch.device
) -> _TensorT:
return cls._checked_cast_tensor_t(t())
@classmethod
@ -756,10 +759,11 @@ class MetaConverter(Generic[_TensorT]):
self,
t: MetaTensorDesc,
shape_env: Optional[ShapeEnv],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
callback: Callable[[Callable], _TensorT],
source: Optional[Source],
symbolic_context: Optional[SymbolicContext],
) -> _TensorT:
callback = functools.partial(callback, device=t.device) # type: ignore[call-arg]
if source is None:
from torch._dynamo.source import ConstantSource
@ -905,7 +909,7 @@ class MetaConverter(Generic[_TensorT]):
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
callback: Callable[[Callable], _TensorT],
source: torch._guards.Source,
) -> _TensorT:
# We are hitting plain meta_desc tensor so actually
@ -933,12 +937,15 @@ class MetaConverter(Generic[_TensorT]):
)
current_source = AttrSource(source, attr)
inner_callback = functools.partial(
callback, device=meta_tensor_desc.device # type: ignore[call-arg]
)
new_empty_tensor = _empty_create_subclass(
meta_tensor_desc,
meta_tensor_desc.size,
meta_tensor_desc.stride,
current_context,
callback,
inner_callback,
current_source,
)
inner_tensors[attr] = new_empty_tensor
@ -975,7 +982,7 @@ class MetaConverter(Generic[_TensorT]):
t: MetaTensorDesc,
source: torch._guards.Source,
shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
callback: Callable[[Callable], _TensorT],
) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import (
@ -1137,7 +1144,7 @@ class MetaConverter(Generic[_TensorT]):
shape_env: Optional[
torch.fx.experimental.symbolic_shapes.ShapeEnv
] = shape_env,
callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback, # type: ignore[assignment]
callback: Callable[[Callable], _TensorT] = callback, # type: ignore[assignment]
) -> torch.Tensor:
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
if visited_t is None:
@ -1723,7 +1730,7 @@ class MetaConverter(Generic[_TensorT]):
t: torch.Tensor,
shape_env: Optional[ShapeEnv] = None,
*,
callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
callback: Optional[Callable[[Callable], _TensorT]] = None,
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
# Controls whether or not we should dump the tensor metadata to structured logs
@ -1731,9 +1738,9 @@ class MetaConverter(Generic[_TensorT]):
# we don't want to dump info again from AOTAutograd, it is redundant.
trace: bool = True,
) -> _TensorT:
callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
callback_: Callable[[Callable], _TensorT]
if callback is None:
callback_ = self._identity_callable
callback_ = self._identity_callable # type: ignore[assignment]
else:
callback_ = callback
# TODO: zero tensors? We appear to have eliminated them by