mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Skip tests that don't call gradcheck in slow gradcheck CI (#82117)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82117 Approved by: https://github.com/kit1980, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
ce92c1cfe9
commit
0fcdf936e7
@ -15,7 +15,7 @@ import torch
|
||||
import torch.backends.cudnn
|
||||
import torch.utils.cpp_extension
|
||||
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
||||
from torch.testing._internal.common_utils import gradcheck
|
||||
from torch.testing._internal.common_utils import gradcheck, skipIfSlowGradcheckEnv
|
||||
|
||||
|
||||
TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
|
||||
@ -37,7 +37,8 @@ def remove_build_path():
|
||||
if os.path.exists(default_build_root):
|
||||
shutil.rmtree(default_build_root)
|
||||
|
||||
|
||||
# There's only one test that runs gracheck, run slow mode manually
|
||||
@skipIfSlowGradcheckEnv
|
||||
class TestCppExtensionJIT(common.TestCase):
|
||||
"""Tests just-in-time cpp extensions.
|
||||
Don't confuse this with the PyTorch JIT (aka TorchScript).
|
||||
@ -864,7 +865,8 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
a = torch.randn(5, 5, requires_grad=True)
|
||||
b = torch.randn(5, 5, requires_grad=True)
|
||||
|
||||
gradcheck(torch.ops.my.add, [a, b], eps=1e-2)
|
||||
for fast_mode in (True, False):
|
||||
gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user