Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Allow Fakified subclass to have different device for inner and outer tensor (#141839)
Previously, if a wrapper tensor subclass was fakified, the inner tensors would end up with the same device as the outer tensor. This PR makes it so that inner and outer tensors can have different devices. See the OffloadTensor PR (https://github.com/pytorch/pytorch/pull/141840/files#diff-3bc0cf540b694f4ec0a3749f78b047456657a53a5657e495ffb68e5970c5fdaaR1955) for an application; a simpler test has been added in this PR.

This is technically BC-breaking because the callback passed to MetaConverter now needs to accept an extra argument, but no one external should be using this anyway?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141839
Approved by: https://github.com/bdhirsh
ghstack dependencies: #141166
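As context for the BC-breaking note: the callback that FakeTensorConverter passes to MetaConverter previously took only a thunk producing the meta tensor; it now also takes the device on which to place the resulting FakeTensor, and MetaConverter binds the right device per tensor via functools.partial (see the diffs below). The following torch-free sketch illustrates the calling convention only; old_callback, new_callback, and make_meta_t are illustrative stand-ins, not the real implementation.

import functools
from typing import Callable

def make_meta_t() -> str:
    # Stand-in for the thunk that materializes a meta tensor.
    return "meta_tensor"

# Old convention: the converter captured a single device for every tensor it
# fakified, so inner tensors inherited the outer tensor's device.
def old_callback(make_meta: Callable[[], str]) -> str:
    return f"FakeTensor({make_meta()}, device=<captured outer device>)"

# New convention: the caller states explicitly which device this particular
# tensor (outer wrapper or inner constituent) lives on.
def new_callback(make_meta: Callable[[], str], device: str) -> str:
    return f"FakeTensor({make_meta()}, device={device})"

# Mirroring functools.partial(callback, device=t.device) in the diff below,
# the device is bound per tensor description before the callback is invoked.
outer_cb = functools.partial(new_callback, device="cpu")
inner_cb = functools.partial(new_callback, device="cuda:0")
print(outer_cb(make_meta_t))  # FakeTensor(meta_tensor, device=cpu)
print(inner_cb(make_meta_t))  # FakeTensor(meta_tensor, device=cuda:0)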
Committed by: PyTorch MergeBot
Parent: 9830e7b1e4
Commit: e41a0b33ec
@@ -1951,6 +1951,60 @@ class FakeTensorDispatchCache(TestCase):
             extract_tensor_metadata(res4),
         )
 
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_wrapper_tensor_subclass_different_device(self):
+        class DifferentDeviceTensor(torch.Tensor):
+            @staticmethod
+            def __new__(cls, a):
+                kwargs = {}
+                kwargs["strides"] = a.stride()
+                kwargs["storage_offset"] = a.storage_offset()
+                kwargs["device"] = torch.device("cpu")
+                kwargs["layout"] = a.layout
+                kwargs["requires_grad"] = a.requires_grad
+                kwargs["dtype"] = a.dtype
+                out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs)
+                return out
+
+            def __init__(self, a):
+                self.inner_tensor = a
+
+            def __repr__(self):
+                return f"DifferentDeviceTensor({repr(self.inner_tensor)})"
+
+            def __tensor_flatten__(self):
+                return ["inner_tensor"], None
+
+            @staticmethod
+            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+                assert meta is None
+                return DifferentDeviceTensor(inner_tensors["inner_tensor"])
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs):
+                if kwargs is None:
+                    kwargs = {}
+                args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args)
+                kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs)
+                # Returns unwrapped tensor
+                return func(*args, **kwargs)
+
+        a = torch.ones(2, 2, 768, device="cuda")
+        wrapped_a = DifferentDeviceTensor(a)
+
+        # Outer Tensor is on cpu, inner is on cuda
+        self.assertTrue(wrapped_a.is_cpu)
+        self.assertFalse(wrapped_a.inner_tensor.is_cpu)
+
+        with FakeTensorMode() as fake_mode:
+            fake_wrapped_a = fake_mode.from_tensor(wrapped_a)
+
+        self.assertTrue(fake_wrapped_a.is_cpu)
+        assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
+        self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)
+
+
     def test_cache_tuple_outputs(self):
         """
         Test to check that ops with tuple outputs work.
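Aside: the __torch_dispatch__ above unwraps arguments with pytree.tree_map_only, which applies a function only to leaves of a given type while preserving the container structure. A minimal illustration, assuming pytree here is torch.utils._pytree as used in the test file:

import torch
import torch.utils._pytree as pytree

# Only torch.Tensor leaves are transformed; other leaves and the nested
# structure pass through unchanged.
args = (torch.ones(2), "unchanged", [3, torch.zeros(1)])
doubled = pytree.tree_map_only(torch.Tensor, lambda x: x * 2, args)
print(doubled)  # (tensor([2., 2.]), 'unchanged', [3, tensor([0.])])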
@@ -354,14 +354,20 @@ class FakeTensorConverter:
         maybe_memo = self._get_memo(t)
         if maybe_memo is not None:
             return maybe_memo
-        existing_device = t.device
         # not yet supported in metatensors
         if t.is_quantized:
             raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
         if type(t) is torch.nn.Parameter:
             assert not make_constant
 
-        def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
+        constant = t if make_constant else None
+
+        # This callback is used by both subclass and inner tensors. Require the
+        # caller to explicitly specify the device in case outer and inner tensors
+        # have different devices.
+        def mk_fake_tensor(
+            make_meta_t: Callable[[], object], device: torch.device
+        ) -> FakeTensor:
             # NB: don't use in_kernel_invocation_manager. to
             # ensure FakeTensor can internally do constant computation
             # as necessary. Invocation manager is "more correct" as
@@ -373,16 +379,16 @@
             return FakeTensor(
                 fake_mode,
                 make_meta_t(),
-                existing_device,
+                device,
                 # TODO: callback might be used in recursive contexts, in
                 # which case using t is wrong! BUG!
-                constant=t if make_constant else None,
+                constant=constant,
             )
 
         out = self.meta_converter(
             t,
             shape_env=shape_env,
-            callback=mk_fake_tensor,
+            callback=mk_fake_tensor,  # type: ignore[arg-type]
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import contextlib
 import dataclasses
+import functools
 import typing
 import warnings
 import weakref
@@ -709,7 +710,7 @@ class MetaConverter(Generic[_TensorT]):
     def meta_storage(
         self,
         s: MetaStorageDesc,
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
     ) -> torch.UntypedStorage:
         # If we are fakeifying a tensor that has a secretly-zero-sized storage,
         # Need to make sure to resize the meta storage too.
@@ -734,7 +735,9 @@ class MetaConverter(Generic[_TensorT]):
         return typing.cast(_TensorT, t)
 
     @classmethod
-    def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
+    def _identity_callable(
+        cls, t: Callable[[], torch.Tensor], device: torch.device
+    ) -> _TensorT:
         return cls._checked_cast_tensor_t(t())
 
     @classmethod
@@ -756,10 +759,11 @@ class MetaConverter(Generic[_TensorT]):
         self,
         t: MetaTensorDesc,
         shape_env: Optional[ShapeEnv],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
         source: Optional[Source],
         symbolic_context: Optional[SymbolicContext],
     ) -> _TensorT:
+        callback = functools.partial(callback, device=t.device)  # type: ignore[call-arg]
         if source is None:
             from torch._dynamo.source import ConstantSource
 
@@ -905,7 +909,7 @@ class MetaConverter(Generic[_TensorT]):
         symbolic_context: Optional[
             torch.fx.experimental.symbolic_shapes.SymbolicContext
         ],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
         source: torch._guards.Source,
     ) -> _TensorT:
         # We are hitting plain meta_desc tensor so actually
@@ -933,12 +937,15 @@ class MetaConverter(Generic[_TensorT]):
                 )
 
                 current_source = AttrSource(source, attr)
+                inner_callback = functools.partial(
+                    callback, device=meta_tensor_desc.device  # type: ignore[call-arg]
+                )
                 new_empty_tensor = _empty_create_subclass(
                     meta_tensor_desc,
                     meta_tensor_desc.size,
                     meta_tensor_desc.stride,
                     current_context,
-                    callback,
+                    inner_callback,
                     current_source,
                 )
                 inner_tensors[attr] = new_empty_tensor
@@ -975,7 +982,7 @@ class MetaConverter(Generic[_TensorT]):
         t: MetaTensorDesc,
         source: torch._guards.Source,
         shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
     ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
         from torch._dynamo.source import AttrSource
         from torch.fx.experimental.symbolic_shapes import (
@@ -1137,7 +1144,7 @@ class MetaConverter(Generic[_TensorT]):
                 shape_env: Optional[
                     torch.fx.experimental.symbolic_shapes.ShapeEnv
                 ] = shape_env,
-                callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback,  # type: ignore[assignment]
+                callback: Callable[[Callable], _TensorT] = callback,  # type: ignore[assignment]
             ) -> torch.Tensor:
                 # It's possible to close over an undefined tensor (e.g. NJT's lengths).
                 if visited_t is None:
@@ -1723,7 +1730,7 @@ class MetaConverter(Generic[_TensorT]):
         t: torch.Tensor,
         shape_env: Optional[ShapeEnv] = None,
         *,
-        callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
+        callback: Optional[Callable[[Callable], _TensorT]] = None,
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         # Controls whether or not we should dump the tensor metadata to structured logs
@@ -1731,9 +1738,9 @@ class MetaConverter(Generic[_TensorT]):
         # we don't want to dump info again from AOTAutograd, it is redundant.
         trace: bool = True,
     ) -> _TensorT:
-        callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
+        callback_: Callable[[Callable], _TensorT]
         if callback is None:
-            callback_ = self._identity_callable
+            callback_ = self._identity_callable  # type: ignore[assignment]
         else:
             callback_ = callback
         # TODO: zero tensors? We appear to have eliminated them by