Revert "add slow path for is_contiguous"

This reverts commit f6beda89c6acbb92ff7f82699b9ea4c5c7428a19.

Reverted https://github.com/pytorch/pytorch/pull/77396 on behalf of https://github.com/malfet
This commit is contained in:
PyTorch MergeBot
2022-05-19 17:07:54 +00:00
parent 2d2b9f9980
commit 00a187c373
6 changed files with 35 additions and 166 deletions

View File

@ -1163,72 +1163,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''')
# - More steps....
y.exp()
def test_is_contiguous_slow_path(self):
data = torch.randn(3, 3)
contiguous_data = data.clone()
not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))
def subclass_helper(cls, data, use_wrapper_subclass):
if use_wrapper_subclass:
kwargs = {}
kwargs["device"] = data.device
kwargs["dtype"] = data.dtype
kwargs["layout"] = data.layout
kwargs["requires_grad"] = True
kwargs['dispatch_strides'] = True
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
else:
return torch.Tensor._make_subclass(cls, data, True, dispatch_strides=True)
for use_wrapper_subclass in [True, False]:
class ExampleTensor1(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return subclass_helper(cls, data, wrapper)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class ExampleTensor2(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return subclass_helper(cls, data, wrapper)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return contiguous_data.is_contiguous()
return NotImplemented
class ExampleTensor3(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return subclass_helper(cls, data, wrapper)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return not_contiguous_data.is_contiguous()
return NotImplemented
err_msg = "no implementation found for 'torch.ops.aten.is_contiguous'"
e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.is_contiguous()
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.is_contiguous(), True)
e.contiguous() # this will just return the original TensorImpl since is_contiguous = True
err_msg = "no implementation found for"
e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.is_contiguous(), False)
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
if __name__ == '__main__':