add strides to slow path

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78610

Approved by: https://github.com/ezyang
This commit is contained in:
George Qi
2022-06-10 03:02:28 +00:00
committed by PyTorch MergeBot
parent 1d7627955b
commit a90f006fe5
8 changed files with 125 additions and 30 deletions

View File

@ -16,6 +16,7 @@ from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_
import logging
from functools import partial
class TestPythonRegistration(TestCase):
def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
@ -277,6 +278,7 @@ class TestPythonRegistration(TestCase):
test_helper("CONSERVATIVE")
class TestPythonDispatch(TestCase):
def test_basic(self) -> None:
with capture_logs() as logs:
@ -320,7 +322,6 @@ $0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.abs.out($0, out=$1)''')
def test_kwarg_only(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1))
@ -1392,7 +1393,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
return not_contiguous_data.is_contiguous()
return NotImplemented
err_msg = "no implementation found for 'torch.ops.aten.is_contiguous'"
e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
@ -1481,7 +1481,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
return data.dim()
return NotImplemented
err_msg = "no implementation found for 'torch.ops.aten.dim'"
e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
@ -1503,5 +1502,49 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
# https://github.com/pytorch/pytorch/issues/79079
self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0)))
def test_strides_slow_path(self):
for use_wrapper_subclass in [True, False]:
class StridesNotImplemented(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class StridesCustomReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func == torch.ops.aten.stride:
return (4, 2)
return NotImplemented
class StridesDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func == torch.ops.aten.stride:
return None
return NotImplemented
err_msg = "no implementation found for 'torch.ops.aten.stride'"
e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.stride()
e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.stride(), (4, 2))
e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
self.assertEqual(e.stride(), (2, 1))
if __name__ == '__main__':
run_tests()