Output of nonzero is transposed, fix fake tensor (#144695)

Needs this companion executorch PR: https://github.com/pytorch/executorch/pull/7657

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144695
Approved by: https://github.com/bobrenjc93, https://github.com/albanD
This commit is contained in:
Edward Z. Yang
2025-01-21 07:35:42 -08:00
committed by PyTorch MergeBot
parent 323fb4dad0
commit 693d8c7e94
5 changed files with 47 additions and 22 deletions

View File

@ -1597,6 +1597,17 @@ class FakeTensorPropTest(TestCase):
self.assertIsNot(u0, u1)
self.assertTrue(statically_known_true(u0 == u1))
def test_nonzero_stride(self):
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
with fake_mode:
value = torch.ones(5)
fake_r = value.nonzero()
r = torch.ones(5).nonzero()
self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())
def test_torch_load_with_fake_mode(self):
class TheModelClass(torch.nn.Module):
def __init__(self) -> None: