mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[cuDNN][TF32] Account for TF32 in test_super_resolution_cuda
(#161662)
cuDNN seems to be dispatching to TF32 kernels on B200 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161662 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
196232bb93
commit
2e77a08b95
@ -7,6 +7,7 @@ import unittest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.testing._internal.common_cuda import tf32_on_and_off
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
enable_profiling_mode_for_profiling_tests,
|
enable_profiling_mode_for_profiling_tests,
|
||||||
GRAPH_EXECUTOR,
|
GRAPH_EXECUTOR,
|
||||||
@ -482,6 +483,7 @@ class TestModels(JitTestCase):
|
|||||||
self._test_super_resolution(self, device="cpu")
|
self._test_super_resolution(self, device="cpu")
|
||||||
|
|
||||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||||
|
@tf32_on_and_off(0.02)
|
||||||
def test_super_resolution_cuda(self):
|
def test_super_resolution_cuda(self):
|
||||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||||
self._test_super_resolution(self, device="cuda", check_export_import=False)
|
self._test_super_resolution(self, device="cuda", check_export_import=False)
|
||||||
|
Reference in New Issue
Block a user