mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Bugfix] Match eager stride semantics for cloned tensors with preserve_format in compile (#163017)
Fixes #161010 by making `clone_meta` match the semantics of strides for eager mode. This is: * Case 1: Tensor is_non_overlapping_and_dense; in this case, stride should match input tensor stride * Case 2: Otherwise, stride should be contiguous computed from input tensor using `compute_elementwise_output_strides` Pull Request resolved: https://github.com/pytorch/pytorch/pull/163017 Approved by: https://github.com/williamwen42, https://github.com/xmfan Co-authored-by: morrison-turnansky <mturnans@redhat.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
bc7b17a36d
commit
979e10f7d6
@ -342,6 +342,16 @@ $1: f32[2] = torch._ops.prims.sin.default($0)""")
|
||||
x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
|
||||
x + 1
|
||||
|
||||
def test_clone_meta_stride_preservation_dense(self):
|
||||
tensor = torch.randn(1, 5).t()
|
||||
meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
|
||||
self.assertEqual(tensor.stride(), meta_clone.stride())
|
||||
|
||||
def test_clone_meta_stride_preservation_sparse(self):
|
||||
tensor = torch.arange(12).float().view(3, 4)[1:, ::2]
|
||||
meta_clone = prims._clone_meta(tensor, memory_format=torch.preserve_format)
|
||||
self.assertEqual(tensor.contiguous().stride(), meta_clone.stride())
|
||||
|
||||
def test_check_deprecation_warning(self):
|
||||
with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
|
||||
torch._prims_common.check(True, lambda: 'message')
|
||||
|
Reference in New Issue
Block a user