pytorch/torch/testing/_internal/two_tensor.py
Commit 6b1b95ad2a by Tugsbayasgalan Manlaibaatar: Support subclass constructor capturing in export (#147014)
Notable TODOs:
1. Need to implement AutogradHOP to get rid of subclasses before serializing
2. Need to implement mechanism to figure out what subclasses will be used in export when they are not expressed in the inputs

Differential Revision: [D69640673](https://our.internmc.facebook.com/intern/diff/D69640673)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147014
Approved by: https://github.com/bdhirsh
2025-03-16 18:19:19 +00:00
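
As a rough illustration of what "subclass constructor capturing" enables (a sketch only; the module M and the export call below are assumptions for illustration, not taken from the commit):

    import torch
    from torch.testing._internal.two_tensor import TwoTensor

    class M(torch.nn.Module):
        def forward(self, x):
            # With constructor capturing, this TwoTensor(...) call can be
            # traced into the exported program rather than rejected as an
            # opaque subclass constructor.
            return TwoTensor(x, x.clone()).sin()

    ep = torch.export.export(M(), (torch.randn(3),))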

# mypy: ignore-errors

import torch
import torch.utils._pytree as pytree
from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
from torch.utils._python_dispatch import return_and_correct_aliasing


# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()
        assert (
            a.device == b.device
            and a.layout == b.layout
            and a.requires_grad == b.requires_grad
            and a.dtype == b.dtype
        )
        # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        # Allocate a storage-less wrapper tensor whose metadata mirrors `a`.
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
        assert a.shape == b.shape
        assert a.stride() == b.stride()
        assert a.storage_offset() == b.storage_offset()
        return out

    @torch._disable_dynamo
    @mark_subclass_constructor_exportable_experimental
    def __init__(self, a, b, outer_size=None, outer_stride=None):
        self.a = a
        self.b = b
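
    # Usage sketch (illustrative, not part of the original helper):
    #   t = TwoTensor(torch.ones(2), torch.zeros(2))
    #   (t + 1).a  # tensor([2., 2.])
    #   (t + 1).b  # tensor([1., 1.])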

    def __repr__(self):
        a_repr = repr(self.a)
        b_repr = repr(self.b)
        return f"TwoTensor({a_repr}, {b_repr})"

    def __tensor_flatten__(self):
        # Traceable-subclass contract: name the inner-tensor attributes and
        # return any extra metadata (none here).
        return ["a", "b"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a, b = inner_tensors["a"], inner_tensors["b"]
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        # Return unconditionally; only the asserts above are gated on the
        # inner tensors being plain torch.Tensors.
        return TwoTensor(a, b, outer_size, outer_stride)
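
    # Round-trip sketch (illustrative): flatten and unflatten are inverses.
    #   attrs, ctx = t.__tensor_flatten__()              # (["a", "b"], None)
    #   inner = {name: getattr(t, name) for name in attrs}
    #   t2 = TwoTensor.__tensor_unflatten__(inner, ctx, t.size(), t.stride())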

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        # Unwrap: rebuild the args/kwargs trees with each TwoTensor replaced
        # by one of its inner tensors.
        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)
        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)
        # Run the op twice, once per inner tensor.
        out_a = func(*args_a, **kwargs_a)
        out_b = func(*args_b, **kwargs_b)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_b_flat = pytree.tree_leaves(out_b)
        # for aten ops that return non-tensors, just assume that
        # our two inner tensors return the same value
        out_flat = [
            cls(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
            for o_a, o_b in zip(out_a_flat, out_b_flat)
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        from torch._higher_order_ops.cond import cond_op

        if func is cond_op:
            return out
        else:
            return return_and_correct_aliasing(func, args, kwargs, out)
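
    # Dispatch sketch (illustrative): for t = TwoTensor(a, b), a call such as
    # torch.mul(t, 2) lands here with func=aten.mul.Tensor and is rerun as
    # func(a, 2) and func(b, 2); the two results are rewrapped as a TwoTensor.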

    def get_elem_a(self):
        return self.a


class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **kwargs)
        # Wrap the result of tensor factory ops (e.g. torch.randn) so that
        # fresh tensors created under this mode come out as TwoTensors.
        if torch._subclasses.fake_tensor._is_tensor_constructor(func):
            out = TwoTensor(out, out.clone())
        return out
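

if __name__ == "__main__":
    # Minimal demo (not part of the original test helper): exercise the
    # subclass and the mode. Expected outputs follow from the dispatch
    # fan-out above, assuming eager-mode execution.
    t = TwoTensor(torch.ones(2), torch.zeros(2))
    print(t + 1)  # TwoTensor(tensor([2., 2.]), tensor([1., 1.]))

    with TwoTensorMode():
        u = torch.randn(3)  # factory op -> wrapped as a TwoTensor
    print(type(u).__name__)  # TwoTensor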