mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nativert][triton] improve hardware registration (#162499)
Summary: att Test Plan: ci Rollback Plan: Differential Revision: D82031814 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162499 Approved by: https://github.com/angelayi
This commit is contained in:
@ -40,21 +40,24 @@ set(NATIVERT_TEST_SRCS
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
|
||||
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
|
||||
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
|
||||
${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
|
||||
)
|
||||
|
||||
if(USE_CUDA)
|
||||
list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
|
||||
endif(MSVC)
|
||||
|
||||
endif()
|
||||
|
||||
add_executable(test_nativert
|
||||
${TORCH_ROOT}/test/cpp/common/main.cpp
|
||||
${NATIVERT_TEST_SRCS}
|
||||
)
|
||||
|
||||
if(MSVC)
|
||||
target_compile_definitions(test_nativert PRIVATE NATIVERT_MSVC_TEST)
|
||||
endif()
|
||||
|
||||
# TODO temporary until we can delete the old gtest polyfills.
|
||||
target_compile_definitions(test_nativert PRIVATE USE_GTEST)
|
||||
|
||||
|
@ -6,9 +6,20 @@ using namespace ::testing;
|
||||
using namespace torch::nativert;
|
||||
|
||||
TEST(TritonKernelManagerRegistrationTests, TestRegister) {
|
||||
#ifndef USE_CUDA
|
||||
EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr);
|
||||
EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCPU));
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#ifdef USE_ROCM
|
||||
EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kHIP));
|
||||
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
|
||||
|
||||
#else
|
||||
EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr);
|
||||
EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCUDA));
|
||||
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
|
||||
|
||||
#endif // USE_ROCM
|
||||
#else
|
||||
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
|
||||
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
|
||||
#endif // USE_CUDA
|
||||
}
|
||||
|
Reference in New Issue
Block a user