load inline user overridable gencode (#156850)

Fixes https://github.com/pytorch/pytorch/issues/156815

As far as testing goes:
* I tried to use cuobjdump (bccd9393a5), but that turned out to be awkward: the cubin name always contains only a single gencode, so it can't verify the full set.
* Another idea was to read stderr and check that the right number of gencodes is present (0beadc01b3). This did a lot to convince me locally that the test works; it passed on my dev GPU but failed in CI, and I suspect a bad interaction with subprocesses.
* The last approach is a simpler unit test that checks which flags get added by default. It is not as comprehensive as the previous ideas, but it works and is fast, so I opted for it; I'm already convinced the behavior is correct from my own experiments and from customers.
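For context, a rough sketch of the usage this change enables, following the `load_inline` pattern from the `torch.utils.cpp_extension` docs (the module name, sources, and the `sm_86` arch are illustrative, not taken from this PR):

```python
from torch.utils.cpp_extension import load_inline

# Illustrative sources; load_inline prepends the needed torch includes itself.
cpp_source = "at::Tensor sin_add(at::Tensor x, at::Tensor y);"
cuda_source = """
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
  return x.sin() + y.sin();
}
"""

module = load_inline(
    name="gencode_override_demo",  # hypothetical module name
    cpp_sources=[cpp_source],
    cuda_sources=[cuda_source],
    functions=["sin_add"],
    # With this change, a user-supplied -gencode should be respected as-is,
    # and the default gencode flags are not generated (see the new test below).
    extra_cuda_cflags=["-gencode=arch=compute_86,code=sm_86"],  # illustrative arch
    verbose=True,
)
```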

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156850
Approved by: https://github.com/malfet
Mark Saroufim
2025-06-26 10:15:03 +00:00
committed by PyTorch MergeBot
parent bbf1a6feac
commit 18b01afa9e
2 changed files with 31 additions and 1 deletion


@@ -21,6 +21,7 @@ import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_utils import gradcheck, TEST_XPU
from torch.utils.cpp_extension import (
    _get_cuda_arch_flags,
    _TORCH_PATH,
    check_compiler_is_gcc,
    CUDA_HOME,
@@ -347,6 +348,35 @@ class TestCppExtensionJIT(common.TestCase):
            # to avoid errors from here leaking into other tests
            pass

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_arch_flags_non_default_gencode(self):
        user_arch_flags = ["-gencode=arch=compute_86,code=sm_86"]
        result = _get_cuda_arch_flags(user_arch_flags)
        self.assertEqual(
            len(result),
            0,
            f"User arch flags should prevent default generation. "
            f"Expected: [], Got: {result}",
        )

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_arch_flags_default_gencode(self):
        default_flags = _get_cuda_arch_flags()
        self.assertGreater(
            len(default_flags), 0, "No args should generate default flags"
        )

        non_arch_flags = _get_cuda_arch_flags(["-O2", "--use-fast-math"])
        self.assertGreater(
            len(non_arch_flags), 0, "Non-arch flags should still generate defaults"
        )

        empty_flags = _get_cuda_arch_flags([])
        self.assertGreater(
            len(empty_flags), 0, "Empty list should generate default flags"
        )

    @unittest.skipIf(not TEST_CUDNN, "CuDNN not found")
    @unittest.skipIf(TEST_ROCM, "Not supported on ROCm")
    def test_jit_cudnn_extension(self):