only compare attributes for meta tensors (#72508)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72508

Todo:

- [x] document this behavior
- [x] add tests

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D34262452

Pulled By: ezyang

fbshipit-source-id: bc5c9653d5c3ad5c6efccc9c8e0efc0d28e15104
(cherry picked from commit 233142c88e4cff02825c7e233aba9411a6df3e9f)
This commit is contained in:
Philip Meier
2022-02-16 18:25:35 -08:00
committed by PyTorch MergeBot
parent b5f2574f36
commit 1f74e082e2
9 changed files with 252 additions and 211 deletions

View File

@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import (
torch_to_numpy_dtype_dict,
)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes)
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
from torch.testing._internal.common_dtype import (
get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
)
@ -729,6 +729,7 @@ class TestViewOps(TestCase):
s = t.contiguous()
self.assertTrue(s is t)
@skipMeta
def test_contiguous_nonview(self, device):
t = torch.ones(5, 5, device=device)
nv = t.t().contiguous()
@ -754,6 +755,7 @@ class TestViewOps(TestCase):
v[6] = 0
self.assertEqual(t[1, 1], v[6])
@skipMeta
def test_reshape_nonview(self, device):
t = torch.ones(5, 5, device=device)
nv = torch.reshape(t.t(), (25,))
@ -806,7 +808,8 @@ class TestViewOps(TestCase):
idx_nv = (0,) * nv.ndim
self.assertTrue(not nv._is_view())
nv[idx_nv] = 0
self.assertNotEqual(t[idx_t], nv[idx_nv])
if device != "meta":
self.assertNotEqual(t[idx_t], nv[idx_nv])
t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
nv = t.flatten(1, 3)
assert_is_nonview(t, nv)
@ -1027,7 +1030,9 @@ class TestOldViewOps(TestCase):
self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))
y = torch.randn(4, 4, 4, device=device)[:, 0, :]
self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
# .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
if device != "meta":
self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())