add slow path for is_contiguous

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77396

Approved by: https://github.com/ezyang, https://github.com/cpuhrsch
This commit is contained in:
George Qi
2022-05-17 18:35:12 +00:00
committed by PyTorch MergeBot
parent ee080918df
commit f6beda89c6
6 changed files with 166 additions and 35 deletions

View File

@ -1163,6 +1163,72 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''')
# - More steps....
y.exp()
def test_is_contiguous_slow_path(self):
data = torch.randn(3, 3)
contiguous_data = data.clone()
not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))
def subclass_helper(cls, data, use_wrapper_subclass):
if use_wrapper_subclass:
kwargs = {}
kwargs["device"] = data.device
kwargs["dtype"] = data.dtype
kwargs["layout"] = data.layout
kwargs["requires_grad"] = True
kwargs['dispatch_strides'] = True
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
else:
return torch.Tensor._make_subclass(cls, data, True, dispatch_strides=True)
for use_wrapper_subclass in [True, False]:
class ExampleTensor1(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return subclass_helper(cls, data, wrapper)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class ExampleTensor2(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return subclass_helper(cls, data, wrapper)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return contiguous_data.is_contiguous()
return NotImplemented
class ExampleTensor3(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return subclass_helper(cls, data, wrapper)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return not_contiguous_data.is_contiguous()
return NotImplemented
err_msg = "no implementation found for 'torch.ops.aten.is_contiguous'"
e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.is_contiguous()
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.is_contiguous(), True)
e.contiguous() # this will just return the original TensorImpl since is_contiguous = True
err_msg = "no implementation found for"
e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.is_contiguous(), False)
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
if __name__ == '__main__':