mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add strides to slow path
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78610 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
1d7627955b
commit
a90f006fe5
@ -16,6 +16,7 @@ from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
|
||||
class TestPythonRegistration(TestCase):
|
||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||
x = torch.tensor([1, 2])
|
||||
@ -277,6 +278,7 @@ class TestPythonRegistration(TestCase):
|
||||
|
||||
test_helper("CONSERVATIVE")
|
||||
|
||||
|
||||
class TestPythonDispatch(TestCase):
|
||||
def test_basic(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
@ -320,7 +322,6 @@ $0 = input('x')
|
||||
$1 = input('y')
|
||||
$2 = torch._ops.aten.abs.out($0, out=$1)''')
|
||||
|
||||
|
||||
def test_kwarg_only(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
x = LoggingTensor(torch.ones(1))
|
||||
@ -1392,7 +1393,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
return not_contiguous_data.is_contiguous()
|
||||
return NotImplemented
|
||||
|
||||
|
||||
err_msg = "no implementation found for 'torch.ops.aten.is_contiguous'"
|
||||
e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
@ -1481,7 +1481,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
return data.dim()
|
||||
return NotImplemented
|
||||
|
||||
|
||||
err_msg = "no implementation found for 'torch.ops.aten.dim'"
|
||||
e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
@ -1503,5 +1502,49 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
# https://github.com/pytorch/pytorch/issues/79079
|
||||
self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0)))
|
||||
|
||||
def test_strides_slow_path(self):
|
||||
for use_wrapper_subclass in [True, False]:
|
||||
class StridesNotImplemented(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, data, wrapper):
|
||||
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
return NotImplemented
|
||||
|
||||
class StridesCustomReturn(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, data, wrapper):
|
||||
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
if func == torch.ops.aten.stride:
|
||||
return (4, 2)
|
||||
return NotImplemented
|
||||
|
||||
class StridesDefaultReturn(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, data, wrapper):
|
||||
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
if func == torch.ops.aten.stride:
|
||||
return None
|
||||
return NotImplemented
|
||||
|
||||
err_msg = "no implementation found for 'torch.ops.aten.stride'"
|
||||
e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
e.stride()
|
||||
|
||||
e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
|
||||
self.assertEqual(e.stride(), (4, 2))
|
||||
|
||||
e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
|
||||
self.assertEqual(e.stride(), (2, 1))
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user