load inline user overridable gencode (#156850)
Fixes https://github.com/pytorch/pytorch/issues/156815

As far as testing goes:
* I first tried to use cuobjdump (bccd9393a5), but that was awkward: the name of the cubin always contains a single gencode.
* Another idea was to read stderr and check that the right number of gencodes is present (0beadc01b3). This did a lot to convince me locally that the test works; it passed on my dev GPU but failed in CI, and I suspect a bad interaction with subprocesses.
* The last approach was a simpler unit test that checks which flags get added by default. It is not as comprehensive as the previous ideas, but it works and is fast, so I opted for it, since my own experiments and customer reports already convinced me the behavior is correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156850
Approved by: https://github.com/malfet
committed by PyTorch MergeBot
parent bbf1a6feac
commit 18b01afa9e
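For context, a minimal sketch of the behavior the new tests in the diff below exercise: `_get_cuda_arch_flags` produces the default `-gencode` flags when the caller supplies none, and returns an empty list when the caller already passes a `-gencode` flag, so `load_inline` compiles only for the architectures the user asked for. The extension name, kernel body, and the sm_86 target are illustrative assumptions and not part of this PR; running it requires a local CUDA toolchain.

```python
# Sketch only: assumes a CUDA toolchain and an Ampere (sm_86) GPU are available.
from torch.utils.cpp_extension import _get_cuda_arch_flags, load_inline

# No user arch flags -> defaults are generated (from detected GPUs or TORCH_CUDA_ARCH_LIST).
print(_get_cuda_arch_flags())
# e.g. ['-gencode=arch=compute_86,code=sm_86', ...]

# A user-supplied -gencode flag suppresses the defaults entirely (what the new test asserts).
print(_get_cuda_arch_flags(["-gencode=arch=compute_86,code=sm_86"]))
# []

# Hypothetical inline extension compiled only for the user's chosen architecture.
cuda_source = """
torch::Tensor add_one(torch::Tensor x) {
    return x + 1;
}
"""

module = load_inline(
    name="gencode_demo",  # illustrative name, not from the PR
    cpp_sources="torch::Tensor add_one(torch::Tensor x);",
    cuda_sources=cuda_source,
    functions=["add_one"],
    extra_cuda_cflags=["-gencode=arch=compute_86,code=sm_86"],  # overrides the default gencodes
    verbose=True,  # show the build commands so the gencode flags are visible
)
```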
@@ -21,6 +21,7 @@ import torch.utils.cpp_extension
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
 from torch.testing._internal.common_utils import gradcheck, TEST_XPU
 from torch.utils.cpp_extension import (
+    _get_cuda_arch_flags,
     _TORCH_PATH,
     check_compiler_is_gcc,
     CUDA_HOME,
@@ -347,6 +348,35 @@ class TestCppExtensionJIT(common.TestCase):
             # to avoid errors from here leaking into other tests
             pass
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_cuda_arch_flags_non_default_gencode(self):
+        user_arch_flags = ["-gencode=arch=compute_86,code=sm_86"]
+        result = _get_cuda_arch_flags(user_arch_flags)
+
+        self.assertEqual(
+            len(result),
+            0,
+            f"User arch flags should prevent default generation. "
+            f"Expected: [], Got: {result}",
+        )
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_cuda_arch_flags_default_gencode(self):
+        default_flags = _get_cuda_arch_flags()
+        self.assertGreater(
+            len(default_flags), 0, "No args should generate default flags"
+        )
+
+        non_arch_flags = _get_cuda_arch_flags(["-O2", "--use-fast-math"])
+        self.assertGreater(
+            len(non_arch_flags), 0, "Non-arch flags should still generate defaults"
+        )
+
+        empty_flags = _get_cuda_arch_flags([])
+        self.assertGreater(
+            len(empty_flags), 0, "Empty list should generate default flags"
+        )
+
     @unittest.skipIf(not TEST_CUDNN, "CuDNN not found")
     @unittest.skipIf(TEST_ROCM, "Not supported on ROCm")
     def test_jit_cudnn_extension(self):