Skip tests that don't call gradcheck in slow gradcheck CI (#82117)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82117
Approved by: https://github.com/kit1980, https://github.com/albanD
This commit is contained in:
soulitzer
2022-07-25 11:47:44 -04:00
committed by PyTorch MergeBot
parent ce92c1cfe9
commit 0fcdf936e7
4 changed files with 11 additions and 4 deletions

View File

@ -15,7 +15,7 @@ import torch
import torch.backends.cudnn
import torch.utils.cpp_extension
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
from torch.testing._internal.common_utils import gradcheck
from torch.testing._internal.common_utils import gradcheck, skipIfSlowGradcheckEnv
TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
@ -37,7 +37,8 @@ def remove_build_path():
if os.path.exists(default_build_root):
shutil.rmtree(default_build_root)
# There's only one test that runs gracheck, run slow mode manually
@skipIfSlowGradcheckEnv
class TestCppExtensionJIT(common.TestCase):
"""Tests just-in-time cpp extensions.
Don't confuse this with the PyTorch JIT (aka TorchScript).
@ -864,7 +865,8 @@ class TestCppExtensionJIT(common.TestCase):
a = torch.randn(5, 5, requires_grad=True)
b = torch.randn(5, 5, requires_grad=True)
gradcheck(torch.ops.my.add, [a, b], eps=1e-2)
for fast_mode in (True, False):
gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode)
if __name__ == "__main__":