diff --git a/test/jit/test_models.py b/test/jit/test_models.py index c6364f10197d..4dd099dbaad5 100644 --- a/test/jit/test_models.py +++ b/test/jit/test_models.py @@ -7,6 +7,7 @@ import unittest import torch import torch.nn as nn import torch.nn.functional as F +from torch.testing._internal.common_cuda import tf32_on_and_off from torch.testing._internal.common_utils import ( enable_profiling_mode_for_profiling_tests, GRAPH_EXECUTOR, @@ -482,6 +483,7 @@ class TestModels(JitTestCase): self._test_super_resolution(self, device="cpu") @unittest.skipIf(not RUN_CUDA, "no CUDA") + @tf32_on_and_off(0.02) def test_super_resolution_cuda(self): # XXX: export_import on CUDA modules doesn't work (#11480) self._test_super_resolution(self, device="cuda", check_export_import=False)