mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add a slow path for `is_contiguous`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77396. Approved by: https://github.com/ezyang, https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
ee080918df
commit
f6beda89c6
@ -1163,6 +1163,72 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''')
|
||||
# - More steps....
|
||||
y.exp()
|
||||
|
||||
def test_is_contiguous_slow_path(self):
    """Exercise the dispatch-based ("slow path") handling of ``is_contiguous``
    for tensor subclasses created with ``dispatch_strides=True``.

    Three subclass variants are checked:
      * ExampleTensor1 — never answers ``is_contiguous`` → TypeError expected.
      * ExampleTensor2 — reports contiguous → ``contiguous()`` is a no-op.
      * ExampleTensor3 — reports non-contiguous → ``contiguous()`` would need a
        real implementation, so it raises.
    Each variant is tested both as a wrapper subclass and as a plain subclass.
    """
    data = torch.randn(3, 3)
    # Reference tensors whose is_contiguous() answers are echoed back from
    # __torch_dispatch__ below: one contiguous clone, one strided (non-contiguous) view.
    contiguous_data = data.clone()
    not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))

    def subclass_helper(cls, data, use_wrapper_subclass):
        # Build either a wrapper subclass (no storage aliasing) or a regular
        # subclass; in both cases dispatch_strides=True routes stride/contiguity
        # queries through __torch_dispatch__ instead of the default fast path.
        if use_wrapper_subclass:
            kwargs = {}
            kwargs["device"] = data.device
            kwargs["dtype"] = data.dtype
            kwargs["layout"] = data.layout
            kwargs["requires_grad"] = True
            kwargs['dispatch_strides'] = True
            return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)  # type: ignore[attr-defined]
        else:
            return torch.Tensor._make_subclass(cls, data, True, dispatch_strides=True)

    for use_wrapper_subclass in [True, False]:
        class ExampleTensor1(torch.Tensor):
            # Implements nothing: every op, including aten.is_contiguous,
            # falls through to NotImplemented.
            @staticmethod
            def __new__(cls, data, wrapper):
                return subclass_helper(cls, data, wrapper)

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                return NotImplemented

        class ExampleTensor2(torch.Tensor):
            # Answers aten.is_contiguous by deferring to the contiguous
            # reference tensor (i.e. always True here).
            @staticmethod
            def __new__(cls, data, wrapper):
                return subclass_helper(cls, data, wrapper)

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.is_contiguous:
                    return contiguous_data.is_contiguous()
                return NotImplemented

        class ExampleTensor3(torch.Tensor):
            # Answers aten.is_contiguous by deferring to the strided
            # reference tensor (i.e. always False here); everything else
            # is unimplemented.
            @staticmethod
            def __new__(cls, data, wrapper):
                return subclass_helper(cls, data, wrapper)

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.is_contiguous:
                    return not_contiguous_data.is_contiguous()
                return NotImplemented

        # No handler at all: querying or materializing contiguity must fail
        # with the "no implementation found" TypeError.
        err_msg = "no implementation found for 'torch.ops.aten.is_contiguous'"
        e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
        with self.assertRaisesRegex(TypeError, err_msg):
            e.is_contiguous()
        with self.assertRaisesRegex(TypeError, err_msg):
            e.contiguous()

        # Handler reports contiguous: is_contiguous() is True and
        # contiguous() can succeed without any copy kernel.
        e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass)
        self.assertEqual(e.is_contiguous(), True)
        e.contiguous()  # this will just return the original TensorImpl since is_contiguous = True

        # Handler reports non-contiguous: is_contiguous() is False, but
        # contiguous() then needs a copy op the subclass doesn't provide,
        # so it raises (looser regex — the failing op isn't is_contiguous).
        err_msg = "no implementation found for"
        e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass)
        self.assertEqual(e.is_contiguous(), False)
        with self.assertRaisesRegex(TypeError, err_msg):
            e.contiguous()
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user