[fix] Contiguity of torch.ravel!

Hi!
This PR fixes #70657 by ensuring that `torch.ravel()` returns a contiguous output for non-contiguous inputs. It also adds the previously missing tests verifying the contiguity of `torch.ravel`'s output.
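A minimal sketch of the behavior this change targets (my own illustration, not part of the PR), using a transposed tensor as a typical non-contiguous input:

    import torch

    src = torch.arange(6).reshape(2, 3).t()  # transpose -> non-contiguous
    assert not src.is_contiguous()

    flat = src.ravel()
    # With this fix, ravel always returns a contiguous result; for
    # non-contiguous inputs it copies rather than returning a view.
    assert flat.is_contiguous()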
I am looking forward to your feedback. Thanks :)

Thank you so much, @kshitij12345, for helping me clear up the concepts! :)

cc: @mruberry @kshitij12345
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71771
Approved by: https://github.com/mruberry
Author: Khushi Agrawal
Date: 2022-03-28 16:41:37 +00:00
Committed by: PyTorch MergeBot
Commit: f1af4dbed0 (parent 5f94eea495)
3 changed files with 14 additions and 4 deletions


@@ -2221,7 +2221,7 @@ Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) {
 }
 
 Tensor ravel(const Tensor& self) {
-  return self.reshape(-1);
+  return self.contiguous().view(-1);
 }
 
 static inline void handle_unflatten_exception(const std::runtime_error &e,
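For context on the one-line change above: `reshape(-1)` returns a view whenever the strides permit, so a strided slice could previously come back non-contiguous. A hedged sketch of the difference, assuming standard `reshape`/`view`/`contiguous` semantics:

    import torch

    x = torch.ones(10)[::2]  # shape (5,), stride (2,): non-contiguous
    assert not x.is_contiguous()

    # reshape(-1) may return a view; flattening an already 1-D slice
    # needs no data movement, so the non-contiguous strides survive.
    assert not x.reshape(-1).is_contiguous()

    # contiguous() first materializes a dense copy when needed, so
    # view(-1) on it always yields a contiguous result.
    assert x.contiguous().view(-1).is_contiguous()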


@@ -917,29 +917,38 @@ class TestOldViewOps(TestCase):
         flat = src.ravel()
         self.assertEqual(flat.shape, torch.Size([size]))
         self.assertEqual(src.view(-1), flat)
-        self.assertEqual(flat._base, src)
+        self.assertIs(flat._base, src)
+        self.assertTrue(flat.is_contiguous())
 
         # Non-contiguous Tensor -> Copy
         if nc:
             nc_src = src.t()
             nc_flat = nc_src.ravel()
             self.assertEqual(nc_flat.shape, torch.Size([size]))
-            self.assertEqual(nc_src.reshape(-1), nc_flat)
-            self.assertTrue(nc_flat._base != nc_src)
+            self.assertEqual(nc_src.contiguous().view(-1), nc_flat)
+            self.assertIsNot(nc_flat._base, src)
+            self.assertTrue(nc_flat.is_contiguous())
 
         # Test that flatten returns 1-dim tensor when given a 0-dim tensor
         zero_dim_tensor = torch.tensor(123, device=device)
         flat0 = zero_dim_tensor.ravel()
         one_dim_tensor = torch.tensor([123], device=device)
         flat1 = one_dim_tensor.ravel()
+        nc_ones_tensor = torch.ones(10, device=device)[::2]
+        flat2 = nc_ones_tensor.ravel()
         self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
         self.assertEqual(flat0.shape, torch.Size([1]))
         self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
         self.assertEqual(flat1.shape, torch.Size([1]))
+        self.assertEqual(nc_ones_tensor.shape, torch.Size([5]))
+        self.assertEqual(flat2.shape, torch.Size([5]))
         self.assertEqual(flat0, one_dim_tensor)
         self.assertEqual(flat0, flat1)
         self.assertEqual(flat0.shape, flat1.shape)
+        self.assertTrue(flat0.is_contiguous())
+        self.assertTrue(flat1.is_contiguous())
+        self.assertTrue(flat2.is_contiguous())
 
         # Test both float tensor and quantized tensor
         tensors = [torch.randn(5, 5, 5, 5, device=device),
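An aside on the assertion changes above: `assertEqual` compares tensor values, so it cannot prove aliasing; `assertIs` checks object identity. A small sketch of the distinction (my own example):

    import torch

    src = torch.arange(4.)
    flat = src.ravel()  # contiguous input: ravel returns a view

    # torch.equal would also hold for an unrelated copy with the same
    # values; `is` proves that flat actually shares src's storage.
    assert flat._base is src
    assert flat.data_ptr() == src.data_ptr()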


@@ -13286,6 +13286,7 @@ op_db: List[OpInfo] = [
            # polygamma functions have multiple singularities at x <= 0
            reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
     OpInfo('ravel',
+           ref=np.ravel,
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
            supports_forward_ad=True,
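The added `ref=np.ravel` points the OpInfo reference checks at NumPy, whose `ravel` has the same copy-when-needed semantics; a quick sanity check of that correspondence (my own example):

    import numpy as np

    a = np.arange(6).reshape(2, 3).T  # transposed -> not C-contiguous
    r = np.ravel(a)

    # Like the fixed torch.ravel, NumPy copies when flattening cannot
    # be expressed as a view, so the result is contiguous here.
    assert r.flags['C_CONTIGUOUS']
    assert not np.shares_memory(r, a)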