[inductor] Replace torch.allclose with torch.testing.assert_close in test_fx_fusion (#130618)

Preventative fix for a test failure with the oneDNN v3.5 upgrade, where the order of float32 arithmetic in torch.addmm may change (the bias term can be accumulated at the start or the end of the computation), resulting in slightly different output due to float32 rounding.
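For illustration only (a minimal sketch, not the actual oneDNN kernel or test inputs), float32 addition is not associative, so changing where the bias is folded into the accumulation can perturb the result:

```python
import torch

# Illustrative values chosen to make the rounding visible; they are not
# related to the failing test.
a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)

print(((a + b) + c).item())  # 1.0
print((a + (b + c)).item())  # 0.0 -- regrouping changed the rounded result
```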

Replaced occurrences of torch.allclose with ~~torch._dynamo.testing.same~~ torch.testing.assert_close, which is the recommended approach per https://github.com/pytorch/pytorch/issues/56544. Its default tolerances are more relaxed than those of torch.allclose, which lets the test pass with the upcoming oneDNN change.
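As a rough sketch of the tolerance difference (the values below are illustrative, not taken from the failing test): torch.allclose defaults to rtol=1e-5, atol=1e-8 regardless of dtype, while torch.testing.assert_close selects dtype-aware defaults (rtol=1.3e-6, atol=1e-5 for float32), so a small absolute float32 mismatch passes assert_close but fails allclose:

```python
import torch

expected = torch.tensor([1e-4], dtype=torch.float32)
actual = expected + 1e-6  # ~1e-6 absolute mismatch

print(torch.allclose(actual, expected))      # False: atol=1e-8 is too tight
torch.testing.assert_close(actual, expected) # passes silently: float32 atol=1e-5
```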

This should fix the aarch64 CI failures in #129932.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130618
Approved by: https://github.com/jgong5, https://github.com/malfet
Robert Hardwick
2024-08-06 03:58:43 +00:00
committed by PyTorch MergeBot
parent 4e610924d4
commit 4260f365ba


```diff
@@ -67,7 +67,7 @@ class TestFxFusion(TestCase):
         ]
         for f in [test_kwarg, test_arg, test_arg2, test_kwarg2, test_kwarg3]:
             traced = trace_func(f, inputs)
-            self.assertTrue(torch.allclose(f(*inputs), traced(*inputs)))
+            torch.testing.assert_close(f(*inputs), traced(*inputs))
             self.assertEqual(count_call_method(traced, "tanh"), 2)

     def test_linear_permute_fusion(self):
@@ -98,7 +98,7 @@ class TestFxFusion(TestCase):
         self.assertEqual(num_linear, 0)
         self.assertEqual(num_linear_transpose, 1)

-        self.assertTrue(torch.allclose(module(input), traced(input)))
+        torch.testing.assert_close(module(input), traced(input))

     def test_permute_linear_fusion(self):
         class TestModule(torch.nn.Module):
@@ -127,7 +127,7 @@ class TestFxFusion(TestCase):
         self.assertEqual(num_linear, 0)
         self.assertEqual(num_transpose_linear, 1)

-        self.assertTrue(torch.allclose(module(input), traced(input)))
+        torch.testing.assert_close(module(input), traced(input))

     def test_permute_bmm_fusion(self):
         class TestModule(torch.nn.Module):
@@ -151,7 +151,7 @@ class TestFxFusion(TestCase):
         self.assertEqual(num_bmm, 0)
         self.assertEqual(num_transpose_matmul, 1)

-        self.assertTrue(torch.allclose(module(input), traced(input)))
+        torch.testing.assert_close(module(input), traced(input))


 if __name__ == "__main__":
```