mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Output of nonzero is transposed, fix fake tensor (#144695)
Needs this companion executorch PR: https://github.com/pytorch/executorch/pull/7657 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/144695 Approved by: https://github.com/bobrenjc93, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
323fb4dad0
commit
693d8c7e94
@ -1597,6 +1597,17 @@ class FakeTensorPropTest(TestCase):
|
||||
self.assertIsNot(u0, u1)
|
||||
self.assertTrue(statically_known_true(u0 == u1))
|
||||
|
||||
def test_nonzero_stride(self):
|
||||
shape_env = ShapeEnv()
|
||||
fake_mode = FakeTensorMode(shape_env=shape_env)
|
||||
with fake_mode:
|
||||
value = torch.ones(5)
|
||||
fake_r = value.nonzero()
|
||||
|
||||
r = torch.ones(5).nonzero()
|
||||
|
||||
self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())
|
||||
|
||||
def test_torch_load_with_fake_mode(self):
|
||||
class TheModelClass(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
Reference in New Issue
Block a user