mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Update TwoTensor impl. to accept outer_size/outer_stride
(#133337)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133337 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
f4f0f2995d
commit
6baccb430b
@ -8,7 +8,12 @@ from torch.utils._python_dispatch import return_and_correct_aliasing
|
||||
# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
|
||||
class TwoTensor(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, a, b):
|
||||
def __new__(cls, a, b, outer_size=None, outer_stride=None):
|
||||
if outer_size is None:
|
||||
outer_size = a.size()
|
||||
if outer_stride is None:
|
||||
outer_stride = a.stride()
|
||||
|
||||
assert (
|
||||
a.device == b.device
|
||||
and a.layout == b.layout
|
||||
@ -16,9 +21,9 @@ class TwoTensor(torch.Tensor):
|
||||
and a.dtype == b.dtype
|
||||
)
|
||||
# I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
|
||||
shape = a.shape
|
||||
shape = outer_size
|
||||
kwargs = {}
|
||||
kwargs["strides"] = a.stride()
|
||||
kwargs["strides"] = outer_stride
|
||||
kwargs["storage_offset"] = a.storage_offset()
|
||||
kwargs["device"] = a.device
|
||||
kwargs["layout"] = a.layout
|
||||
@ -31,7 +36,7 @@ class TwoTensor(torch.Tensor):
|
||||
assert a.storage_offset() == b.storage_offset()
|
||||
return out
|
||||
|
||||
def __init__(self, a, b):
|
||||
def __init__(self, a, b, outer_size=None, outer_stride=None):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
@ -47,7 +52,7 @@ class TwoTensor(torch.Tensor):
|
||||
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
|
||||
assert meta is None
|
||||
a, b = inner_tensors["a"], inner_tensors["b"]
|
||||
return TwoTensor(a, b)
|
||||
return TwoTensor(a, b, outer_size, outer_stride)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
|
Reference in New Issue
Block a user