mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
only compare attributes for meta tensors (#72508)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72508 Todo: - [x] document this behavior - [x] add tests Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D34262452 Pulled By: ezyang fbshipit-source-id: bc5c9653d5c3ad5c6efccc9c8e0efc0d28e15104 (cherry picked from commit 233142c88e4cff02825c7e233aba9411a6df3e9f)
This commit is contained in:
committed by
PyTorch MergeBot
parent
b5f2574f36
commit
1f74e082e2
@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import (
|
||||
torch_to_numpy_dtype_dict,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes)
|
||||
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
|
||||
from torch.testing._internal.common_dtype import (
|
||||
get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
|
||||
)
|
||||
@ -729,6 +729,7 @@ class TestViewOps(TestCase):
|
||||
s = t.contiguous()
|
||||
self.assertTrue(s is t)
|
||||
|
||||
@skipMeta
|
||||
def test_contiguous_nonview(self, device):
|
||||
t = torch.ones(5, 5, device=device)
|
||||
nv = t.t().contiguous()
|
||||
@ -754,6 +755,7 @@ class TestViewOps(TestCase):
|
||||
v[6] = 0
|
||||
self.assertEqual(t[1, 1], v[6])
|
||||
|
||||
@skipMeta
|
||||
def test_reshape_nonview(self, device):
|
||||
t = torch.ones(5, 5, device=device)
|
||||
nv = torch.reshape(t.t(), (25,))
|
||||
@ -806,7 +808,8 @@ class TestViewOps(TestCase):
|
||||
idx_nv = (0,) * nv.ndim
|
||||
self.assertTrue(not nv._is_view())
|
||||
nv[idx_nv] = 0
|
||||
self.assertNotEqual(t[idx_t], nv[idx_nv])
|
||||
if device != "meta":
|
||||
self.assertNotEqual(t[idx_t], nv[idx_nv])
|
||||
t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
|
||||
nv = t.flatten(1, 3)
|
||||
assert_is_nonview(t, nv)
|
||||
@ -1027,7 +1030,9 @@ class TestOldViewOps(TestCase):
|
||||
self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))
|
||||
|
||||
y = torch.randn(4, 4, 4, device=device)[:, 0, :]
|
||||
self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
|
||||
# .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
|
||||
if device != "meta":
|
||||
self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
|
||||
self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
|
||||
self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())
|
||||
|
||||
|
Reference in New Issue
Block a user