mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Replace torch.allclose with torch.testing.assert_close in test_fx_fusion (#130618)
Preventative fix of a test failure with oneDNN v3.5 upgrade where order of float32 arithmetic may change in torch.admm ( bias term can be at the start or end of the arithmetic ) resulting in slightly different output due to float32 precision loss. Replaced occurrences of torch.allclose with ~~torch._dynamo.testing.same~~ torch.testing.assert_close which is the recommended approach as per this issue https://github.com/pytorch/pytorch/issues/56544 ,the default tolerance is more relaxed than torch.allclose which satisfies the test with upcoming oneDNN change. This should fix aarch64 ci failures in #129932 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130618 Approved by: https://github.com/jgong5, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
4e610924d4
commit
4260f365ba
@ -67,7 +67,7 @@ class TestFxFusion(TestCase):
|
||||
]
|
||||
for f in [test_kwarg, test_arg, test_arg2, test_kwarg2, test_kwarg3]:
|
||||
traced = trace_func(f, inputs)
|
||||
self.assertTrue(torch.allclose(f(*inputs), traced(*inputs)))
|
||||
torch.testing.assert_close(f(*inputs), traced(*inputs))
|
||||
self.assertEqual(count_call_method(traced, "tanh"), 2)
|
||||
|
||||
def test_linear_permute_fusion(self):
|
||||
@ -98,7 +98,7 @@ class TestFxFusion(TestCase):
|
||||
self.assertEqual(num_linear, 0)
|
||||
self.assertEqual(num_linear_transpose, 1)
|
||||
|
||||
self.assertTrue(torch.allclose(module(input), traced(input)))
|
||||
torch.testing.assert_close(module(input), traced(input))
|
||||
|
||||
def test_permute_linear_fusion(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
@ -127,7 +127,7 @@ class TestFxFusion(TestCase):
|
||||
self.assertEqual(num_linear, 0)
|
||||
self.assertEqual(num_transpose_linear, 1)
|
||||
|
||||
self.assertTrue(torch.allclose(module(input), traced(input)))
|
||||
torch.testing.assert_close(module(input), traced(input))
|
||||
|
||||
def test_permute_bmm_fusion(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
@ -151,7 +151,7 @@ class TestFxFusion(TestCase):
|
||||
self.assertEqual(num_bmm, 0)
|
||||
self.assertEqual(num_transpose_matmul, 1)
|
||||
|
||||
self.assertTrue(torch.allclose(module(input), traced(input)))
|
||||
torch.testing.assert_close(module(input), traced(input))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user