Update TwoTensor impl. to accept outer_size/outer_stride (#133337)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133337
Approved by: https://github.com/bdhirsh
This commit is contained in:
Guilherme Leobas
2024-10-28 15:26:12 +00:00
committed by PyTorch MergeBot
parent f4f0f2995d
commit 6baccb430b
2 changed files with 10 additions and 5 deletions

View File

@ -8,7 +8,12 @@ from torch.utils._python_dispatch import return_and_correct_aliasing
# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
@staticmethod
def __new__(cls, a, b):
def __new__(cls, a, b, outer_size=None, outer_stride=None):
if outer_size is None:
outer_size = a.size()
if outer_stride is None:
outer_stride = a.stride()
assert (
a.device == b.device
and a.layout == b.layout
@ -16,9 +21,9 @@ class TwoTensor(torch.Tensor):
and a.dtype == b.dtype
)
# I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
shape = a.shape
shape = outer_size
kwargs = {}
kwargs["strides"] = a.stride()
kwargs["strides"] = outer_stride
kwargs["storage_offset"] = a.storage_offset()
kwargs["device"] = a.device
kwargs["layout"] = a.layout
@ -31,7 +36,7 @@ class TwoTensor(torch.Tensor):
assert a.storage_offset() == b.storage_offset()
return out
def __init__(self, a, b):
def __init__(self, a, b, outer_size=None, outer_stride=None):
self.a = a
self.b = b
@ -47,7 +52,7 @@ class TwoTensor(torch.Tensor):
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert meta is None
a, b = inner_tensors["a"], inner_tensors["b"]
return TwoTensor(a, b)
return TwoTensor(a, b, outer_size, outer_stride)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):