mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fix] Contiguity of torch.ravel
!
Hi! The PR aims to fix #70657. The objective was to ensure that `torch.ravel()` returns contiguous outputs for non-contiguous inputs. It also adds the test verifying the contiguity of the `torch.ravel`, which was missing. I am looking forward to your viewpoints. Thanks :) Thank you so much, @kshitij12345, for helping me clear up the concepts! :) cc: @mruberry @kshitij12345 Pull Request resolved: https://github.com/pytorch/pytorch/pull/71771 Approved by: https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
5f94eea495
commit
f1af4dbed0
@ -2221,7 +2221,7 @@ Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) {
|
||||
}
|
||||
|
||||
Tensor ravel(const Tensor& self) {
|
||||
return self.reshape(-1);
|
||||
return self.contiguous().view(-1);
|
||||
}
|
||||
|
||||
static inline void handle_unflatten_exception(const std::runtime_error &e,
|
||||
|
@ -917,29 +917,38 @@ class TestOldViewOps(TestCase):
|
||||
flat = src.ravel()
|
||||
self.assertEqual(flat.shape, torch.Size([size]))
|
||||
self.assertEqual(src.view(-1), flat)
|
||||
self.assertEqual(flat._base, src)
|
||||
self.assertIs(flat._base, src)
|
||||
self.assertTrue(flat.is_contiguous())
|
||||
|
||||
# Non-continuous Tensor -> Copy
|
||||
if nc:
|
||||
nc_src = src.t()
|
||||
nc_flat = nc_src.ravel()
|
||||
self.assertEqual(nc_flat.shape, torch.Size([size]))
|
||||
self.assertEqual(nc_src.reshape(-1), nc_flat)
|
||||
self.assertTrue(nc_flat._base != nc_src)
|
||||
self.assertEqual(nc_src.contiguous().view(-1), nc_flat)
|
||||
self.assertIsNot(nc_flat._base, src)
|
||||
self.assertTrue(nc_flat.is_contiguous())
|
||||
|
||||
# Test that flatten returns 1-dim tensor when given a 0-dim tensor
|
||||
zero_dim_tensor = torch.tensor(123, device=device)
|
||||
flat0 = zero_dim_tensor.ravel()
|
||||
one_dim_tensor = torch.tensor([123], device=device)
|
||||
flat1 = zero_dim_tensor.ravel()
|
||||
nc_ones_tensor = torch.ones(10, device=device)[::2]
|
||||
flat2 = nc_ones_tensor.ravel()
|
||||
|
||||
self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
|
||||
self.assertEqual(flat0.shape, torch.Size([1]))
|
||||
self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
|
||||
self.assertEqual(flat1.shape, torch.Size([1]))
|
||||
self.assertEqual(nc_ones_tensor.shape, torch.Size([5]))
|
||||
self.assertEqual(flat2.shape, torch.Size([5]))
|
||||
self.assertEqual(flat0, one_dim_tensor)
|
||||
self.assertEqual(flat0, flat1)
|
||||
self.assertEqual(flat0.shape, flat1.shape)
|
||||
self.assertTrue(flat0.is_contiguous())
|
||||
self.assertTrue(flat1.is_contiguous())
|
||||
self.assertTrue(flat2.is_contiguous())
|
||||
|
||||
# Test both float tensor and quantized tensor
|
||||
tensors = [torch.randn(5, 5, 5, 5, device=device),
|
||||
|
@ -13286,6 +13286,7 @@ op_db: List[OpInfo] = [
|
||||
# polygamma functions have multiple singularities at x <= 0
|
||||
reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
|
||||
OpInfo('ravel',
|
||||
ref=np.ravel,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
|
Reference in New Issue
Block a user