vfdev
2022-06-02 17:24:00 +00:00
committed by PyTorch MergeBot
parent 79ddc32b6a
commit 642fc94501

View File

@ -643,8 +643,8 @@ implementation more permissive about what operations are allowed::
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
args = [a._t if hasattr(a, '_t') else a for a in args]
metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
args = [a._t if hasattr(a, '_t') else a for a in args]
assert len(metadatas) > 0
ret = func(*args, **kwargs)
return MetadataTensor(ret, metadata=metadatas[0])