# mypy: ignore-errors

from collections import namedtuple

import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing


FancyNamedTuple = namedtuple("FancyNamedTuple", ["foo", "bar"])


# A simple tensor subclass that holds a tensor with custom metadata and a custom method
class ConstantExtraMetadataTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        shape = elem.shape
        kwargs = {}
        kwargs["strides"] = elem.stride()
        kwargs["storage_offset"] = elem.storage_offset()
        kwargs["device"] = elem.device
        kwargs["layout"] = elem.layout
        kwargs["requires_grad"] = elem.requires_grad
        kwargs["dtype"] = elem.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, elem):
        self.elem = elem
        self.constant_attribute = 4

    def __repr__(self):
        inner_repr = repr(self.elem)
        return f"ConstantExtraMetadataTensor({inner_repr})"

    def get_complicated_metadata(self):
        return FancyNamedTuple(self.constant_attribute, self.constant_attribute)

    def __tensor_flatten__(self):
        # Return the names of the inner tensor attributes, plus the constant
        # metadata so that __tensor_unflatten__ can restore it.
        return ["elem"], self.constant_attribute

    def add_constant(self, a):
        self.constant_attribute += a

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is not None
        elem = inner_tensors["elem"]
        out = ConstantExtraMetadataTensor(elem)
        out.constant_attribute = meta
        return out

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_inner = pytree.tree_map_only(
            ConstantExtraMetadataTensor, lambda x: x.elem, args
        )
        kwargs_inner = pytree.tree_map_only(
            ConstantExtraMetadataTensor, lambda x: x.elem, kwargs
        )
        out_inner = func(*args_inner, **kwargs_inner)
        out_inner_flat, spec = pytree.tree_flatten(out_inner)
        # For aten ops that return non-tensors, just pass the inner op's
        # value through unchanged.
        out_flat = [
            ConstantExtraMetadataTensor(o_inner)
            if isinstance(o_inner, torch.Tensor)
            else o_inner
            for o_inner in out_inner_flat
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        return return_and_correct_aliasing(func, args, kwargs, out)
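

# A minimal usage sketch (illustrative only, not part of the original file):
# wrap a plain tensor, bump the extra metadata, round-trip it through the
# flatten/unflatten contract, and run an op through __torch_dispatch__.
#
#     t = ConstantExtraMetadataTensor(torch.randn(3))
#     t.add_constant(2)                     # constant_attribute: 4 -> 6
#     attrs, meta = t.__tensor_flatten__()  # (["elem"], 6)
#     t2 = ConstantExtraMetadataTensor.__tensor_unflatten__(
#         {"elem": t.elem}, meta, t.shape, t.stride()
#     )
#     out = t + 1  # routed through __torch_dispatch__, stays wrapped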


# A tensor subclass whose __torch_dispatch__ returns a plain tensor (the sum of
# its two inner tensors) for non-view ops. It is similar to TwoTensor and is
# used to simulate torchao quantized tensors.
class CustomTensorPlainOut(torch.Tensor):
    @staticmethod
    def __new__(cls, elem1, elem2):
        shape = elem1.shape
        kwargs = {}
        kwargs["strides"] = elem1.stride()
        kwargs["storage_offset"] = elem1.storage_offset()
        kwargs["device"] = elem1.device
        kwargs["layout"] = elem1.layout
        kwargs["requires_grad"] = elem1.requires_grad
        kwargs["dtype"] = elem1.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, elem1, elem2):
        self.elem1 = elem1
        self.elem2 = elem2

    def get_elem(self):
        return self.elem1

    def __repr__(self):
        inner_repr_1 = repr(self.elem1)
        inner_repr_2 = repr(self.elem2)
        return f"CustomTensorPlainOut({inner_repr_1}, {inner_repr_2})"

    def __tensor_flatten__(self):
        return ["elem1", "elem2"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        elem1 = inner_tensors["elem1"]
        elem2 = inner_tensors["elem2"]
        out = CustomTensorPlainOut(elem1, elem2)
        return out

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_inner_1 = pytree.tree_map_only(
            CustomTensorPlainOut, lambda x: x.elem1, args
        )
        kwargs_inner_1 = pytree.tree_map_only(
            CustomTensorPlainOut, lambda x: x.elem1, kwargs
        )
        args_inner_2 = pytree.tree_map_only(
            CustomTensorPlainOut, lambda x: x.elem2, args
        )
        kwargs_inner_2 = pytree.tree_map_only(
            CustomTensorPlainOut, lambda x: x.elem2, kwargs
        )

        out_inner_1 = func(*args_inner_1, **kwargs_inner_1)
        out_inner_2 = func(*args_inner_2, **kwargs_inner_2)

        out_inner_flat_1, spec = pytree.tree_flatten(out_inner_1)
        out_inner_flat_2, _ = pytree.tree_flatten(out_inner_2)

        # View ops must stay wrapped so that aliasing is tracked correctly.
        if func.is_view:
            new_out = pytree.tree_unflatten(
                [
                    CustomTensorPlainOut(tensor1, tensor2)
                    for tensor1, tensor2 in zip(out_inner_flat_1, out_inner_flat_2)
                ],
                spec,
            )
            return return_and_correct_aliasing(func, args, kwargs, new_out)

        # All other ops return a plain tensor: the sum of the two inner results.
        out_new = [
            out_inner_flat_1[ix] + out_inner_flat_2[ix]
            for ix in range(len(out_inner_flat_1))
        ]
        return pytree.tree_unflatten(out_new, spec)
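

# A minimal usage sketch (illustrative only, not part of the original file):
#
#     t = CustomTensorPlainOut(torch.ones(2, 2), torch.ones(2, 2))
#     out = t * 3                   # non-view op: plain torch.Tensor, 3 + 3 = 6
#     assert type(out) is torch.Tensor and bool((out == 6).all())
#     v = t.view(4)                 # view op: result stays wrapped
#     assert isinstance(v, CustomTensorPlainOut)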