# mypy: ignore-errors

import torch
import torch.utils._pytree as pytree
from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
from torch.utils._python_dispatch import return_and_correct_aliasing


# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b, outer_size=None, outer_stride=None):
        if outer_size is None:
            outer_size = a.size()
        if outer_stride is None:
            outer_stride = a.stride()

        assert (
            a.device == b.device
            and a.layout == b.layout
            and a.requires_grad == b.requires_grad
            and a.dtype == b.dtype
        )
        # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
        shape = outer_size
        kwargs = {}
        kwargs["strides"] = outer_stride
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

        assert a.shape == b.shape
        assert a.stride() == b.stride()
        assert a.storage_offset() == b.storage_offset()
        return out

    @torch._disable_dynamo
    @mark_subclass_constructor_exportable_experimental
    def __init__(self, a, b, outer_size=None, outer_stride=None):
        self.a = a
        self.b = b

    def __repr__(self):
        a_repr = repr(self.a)
        b_repr = repr(self.b)
        return f"TwoTensor({a_repr}, {b_repr})"

    def __tensor_flatten__(self):
        return ["a", "b"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a, b = inner_tensors["a"], inner_tensors["b"]
        if type(a) is torch.Tensor:
            assert outer_size is not None
            assert outer_stride is not None
        return TwoTensor(a, b, outer_size, outer_stride)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)

        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)

        out_a = func(*args_a, **kwargs_a)
        out_b = func(*args_b, **kwargs_b)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_b_flat = pytree.tree_leaves(out_b)
        # for aten ops that return non-tensors, just assume that
        # our two inner tensors return the same value
        out_flat = [
            cls(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
            for o_a, o_b in zip(out_a_flat, out_b_flat)
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        from torch._higher_order_ops.cond import cond_op

        if func is cond_op:
            return out
        else:
            return return_and_correct_aliasing(func, args, kwargs, out)

    def get_elem_a(self):
        return self.a


class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **kwargs)
        if torch._subclasses.fake_tensor._is_tensor_constructor(func):
            out = TwoTensor(out, out.clone())
        return out
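

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal demonstration of how TwoTensor and TwoTensorMode behave, assuming
# this module is run directly; the `__main__` guard and the example values are
# hypothetical and only meant to show the dispatch behavior defined above.
if __name__ == "__main__":
    # Every op on a TwoTensor runs on both inner tensors and is rewrapped.
    x = TwoTensor(torch.ones(3), torch.zeros(3))
    y = x + 1
    assert isinstance(y, TwoTensor)
    assert torch.equal(y.a, torch.full((3,), 2.0))
    assert torch.equal(y.b, torch.ones(3))

    # __tensor_flatten__ / __tensor_unflatten__ round-trip, as used by tracing.
    attrs, ctx = x.__tensor_flatten__()
    inner = {name: getattr(x, name) for name in attrs}
    x2 = TwoTensor.__tensor_unflatten__(inner, ctx, x.size(), x.stride())
    assert torch.equal(x2.a, x.a) and torch.equal(x2.b, x.b)

    # Under TwoTensorMode, tensor constructors produce TwoTensors automatically.
    with TwoTensorMode():
        z = torch.randn(2, 2)
    assert isinstance(z, TwoTensor)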