Delete Python reference implementation from torchdim, as it is untested (#160115)

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160115
Approved by: https://github.com/albanD
This commit is contained in:
Edward Yang
2025-08-10 00:17:46 -04:00
committed by PyTorch MergeBot
parent af10f1f86c
commit c9671dc865
6 changed files with 15 additions and 909 deletions

View File

@ -24,10 +24,6 @@ from . import op_properties
# use dict to avoid writing C++ bindings for set
pointwise = dict.fromkeys(op_properties.pointwise, True)
use_c = True
if not use_c:
from . import reference
class _Tensor:
# fast path around slow wrapping/unwrapping logic for simply queries used
@ -40,12 +36,8 @@ class _Tensor:
def dim(self):
return self.ndim
if use_c:
__torch_function__ = classmethod(_C.__torch_function__)
expand = _C._instancemethod(_C.expand)
else:
__torch_function__ = reference.__torch_function__
expand = reference.expand
__torch_function__ = classmethod(_C.__torch_function__)
expand = _C._instancemethod(_C.expand)
index = _C._instancemethod(_C.index)
@ -64,8 +56,6 @@ class Dim(_C.Dim, _Tensor):
class Tensor(_Tensor, _C.Tensor):
if not use_c:
from_batched = staticmethod(_C.Tensor_from_batched)
from_positional = staticmethod(_C.Tensor_from_positional)
sum = _C._instancemethod(_C.Tensor_sum)
@ -75,21 +65,17 @@ def cat(tensors, dim, new_dim):
return stack(tensors, n, dim).index([n, dim], new_dim)
if use_c:
_wrap = _C._wrap
_wrap = _C._wrap
def _def(name, *args, **kwargs):
orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))
t__getitem__ = _C._instancemethod(_C.__getitem__)
stack = _C.stack
split = _C._instancemethod(_C.split)
else:
_wrap, _def = reference._wrap, reference._def
t__getitem__ = reference.t__getitem__
stack = reference.stack
split = reference.split
def _def(name, *args, **kwargs):
orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))
t__getitem__ = _C._instancemethod(_C.__getitem__)
stack = _C.stack
split = _C._instancemethod(_C.split)
# note: there is no python reference
t__setitem__ = _C._instancemethod(_C.__setitem__)
@ -105,13 +91,10 @@ torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim
if use_c:
_Tensor.order = _C._instancemethod(_C.order)
else:
_Tensor.order = reference.positional
_Tensor.order = _C._instancemethod(_C.order)
_def("mean")
_def("sum")