[nativert][triton] improve hardware registration (#162499)

Summary: att

Test Plan:
ci

Rollback Plan:

Differential Revision: D82031814

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162499
Approved by: https://github.com/angelayi
Author: dolpm
Date: 2025-09-10 04:52:57 +00:00
Committed by: PyTorch MergeBot
Parent: 96ef26f71a
Commit: 1c16c18a53
8 changed files with 142 additions and 88 deletions


@@ -40,21 +40,24 @@ set(NATIVERT_TEST_SRCS
   ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
   ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
   ${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
-  ${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
   ${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
+  ${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
+  ${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
 )
 if(USE_CUDA)
   list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
-endif(MSVC)
+endif()
 add_executable(test_nativert
   ${TORCH_ROOT}/test/cpp/common/main.cpp
   ${NATIVERT_TEST_SRCS}
 )
 if(MSVC)
   target_compile_definitions(test_nativert PRIVATE NATIVERT_MSVC_TEST)
 endif()
 # TODO temporary until we can delete the old gtest polyfills.
 target_compile_definitions(test_nativert PRIVATE USE_GTEST)
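
The build change above is also what determines which backends the registry exposes at runtime: with self-registration at static-initialization time, a translation unit that is not compiled (here, CudaTritonKernelManager.cpp without USE_CUDA) never registers its backend. The following is a minimal, self-contained sketch of that idiom; every type and function name in it is illustrative, not the actual nativert implementation.

// Sketch: file-scope registrar objects register a backend as a side effect of
// static initialization, so "compile the .cpp" is equivalent to "register it".
#include <cstdio>
#include <functional>
#include <memory>
#include <unordered_map>

enum class Device { CPU, CUDA };                       // stand-in for a device-type key
struct KernelManager { virtual ~KernelManager() = default; };
using Factory = std::function<std::unique_ptr<KernelManager>()>;

// Meyers singleton so registration order is independent of link order.
std::unordered_map<Device, Factory>& registry() {
  static std::unordered_map<Device, Factory> r;
  return r;
}

struct Registrar {
  Registrar(Device d, Factory f) { registry().emplace(d, std::move(f)); }
};

struct CpuKernelManager : KernelManager {};
static Registrar cpu_reg{Device::CPU, [] { return std::make_unique<CpuKernelManager>(); }};

#ifdef USE_CUDA
// In the real tree this registration would live in CudaTritonKernelManager.cpp,
// which the CMake hunk above only adds to NATIVERT_TEST_SRCS when USE_CUDA is set.
struct CudaKernelManager : KernelManager {};
static Registrar cuda_reg{Device::CUDA, [] { return std::make_unique<CudaKernelManager>(); }};
#endif

int main() {
  std::printf("CPU registered:  %d\n", (int)(registry().count(Device::CPU) != 0));
  std::printf("CUDA registered: %d\n", (int)(registry().count(Device::CUDA) != 0));
}

Built with -DUSE_CUDA the sketch reports both backends; without it, only CPU — the same behavior the test below checks through TritonKernelManagerRegistry()->Has().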


@@ -6,9 +6,20 @@ using namespace ::testing;
 using namespace torch::nativert;
 TEST(TritonKernelManagerRegistrationTests, TestRegister) {
-#ifndef USE_CUDA
-  EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr);
+  EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCPU));
+#ifdef USE_CUDA
+#ifdef USE_ROCM
+  EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kHIP));
+  EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
 #else
-  EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr);
+  EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCUDA));
+  EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
+#endif // USE_ROCM
+#else
+  EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
+  EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
 #endif // USE_CUDA
 }
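
Read as a matrix, the assertions above say: the CPU manager is always registered, a ROCm build registers the HIP key instead of the CUDA key, and a plain CUDA build registers only CUDA. Below is a standalone restatement of that matrix with plain static_asserts; nothing in it is nativert API, and the two booleans stand for the USE_CUDA / USE_ROCM flags in the test's nested #ifdefs.

// Sketch: expected registry contents per build configuration.
#include <cassert>

struct Expected { bool cpu, cuda, hip; };

constexpr Expected expected_backends(bool cuda_build, bool rocm_build) {
  if (!cuda_build) return {true, false, false};   // CPU-only build
  if (rocm_build)  return {true, false, true};    // ROCm: HIP key, no CUDA key
  return {true, true, false};                     // plain CUDA build
}

int main() {
  static_assert(expected_backends(false, false).cpu, "CPU manager is always registered");
  static_assert(!expected_backends(false, false).cuda, "no CUDA entry in CPU-only builds");
  static_assert(expected_backends(true, true).hip, "ROCm builds register the HIP key");
  static_assert(!expected_backends(true, true).cuda, "ROCm builds do not register CUDA");
  static_assert(expected_backends(true, false).cuda, "plain CUDA builds register CUDA");
  return 0;
}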