From 1c16c18a534d320d101ebb10c88bdf57cf84b3b1 Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Wed, 10 Sep 2025 04:52:57 +0000 Subject: [PATCH] [nativert][triton] improve hardware registration (#162499) Summary: As the title: Triton kernel managers are now registered per device type (CPU, CUDA, HIP) through the typed registry TritonKernelManagerRegistry, replacing the global create_cuda_triton_kernel_manager function pointer and the direct construction of CpuTritonKernelManager in TritonKernel.cpp. Test Plan: ci Rollback Plan: Differential Revision: D82031814 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162499 Approved by: https://github.com/angelayi --- caffe2/CMakeLists.txt | 5 ++ test/cpp/nativert/CMakeLists.txt | 9 ++- ...est_triton_kernel_manager_registration.cpp | 17 +++++- .../triton/CpuTritonKernelManager.cpp | 57 ++++++++++++++++++- .../executor/triton/CpuTritonKernelManager.h | 51 ----------------- .../triton/CudaTritonKernelManager.cpp | 36 ++++++++---- .../executor/triton/TritonKernelManager.h | 12 +++- torch/nativert/kernels/TritonKernel.cpp | 43 ++++++++------ 8 files changed, 142 insertions(+), 88 deletions(-) delete mode 100644 torch/nativert/executor/triton/CpuTritonKernelManager.h diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 4623fec08fe3..99d4b2cd5aa9 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -552,6 +552,11 @@ if(USE_CUDA OR USE_ROCM) append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS) endif() +if(USE_CUDA) + # eventually do rocm + append_filelist("libtorch_nativert_cuda_sources" Caffe2_GPU_SRCS) +endif() + if(USE_CUDA) list(APPEND Caffe2_GPU_CU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index 1b4752ed9089..91605c0933d2 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -40,21 +40,24 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp ${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp - 
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp ${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp + ${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp ${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp ) if(USE_CUDA) list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp) -endif(MSVC) - +endif() add_executable(test_nativert ${TORCH_ROOT}/test/cpp/common/main.cpp ${NATIVERT_TEST_SRCS} ) +if(MSVC) + target_compile_definitions(test_nativert PRIVATE NATIVERT_MSVC_TEST) +endif() + # TODO temporary until we can delete the old gtest polyfills. target_compile_definitions(test_nativert PRIVATE USE_GTEST) diff --git a/test/cpp/nativert/test_triton_kernel_manager_registration.cpp b/test/cpp/nativert/test_triton_kernel_manager_registration.cpp index ca864158e312..8cedb84abf21 100644 --- a/test/cpp/nativert/test_triton_kernel_manager_registration.cpp +++ b/test/cpp/nativert/test_triton_kernel_manager_registration.cpp @@ -6,9 +6,20 @@ using namespace ::testing; using namespace torch::nativert; TEST(TritonKernelManagerRegistrationTests, TestRegister) { -#ifndef USE_CUDA - EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr); + EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCPU)); + +#ifdef USE_CUDA +#ifdef USE_ROCM + EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kHIP)); + EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA)); + #else - EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr); + EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCUDA)); + EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP)); + +#endif // USE_ROCM +#else + EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA)); + EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP)); #endif // USE_CUDA } diff --git a/torch/nativert/executor/triton/CpuTritonKernelManager.cpp b/torch/nativert/executor/triton/CpuTritonKernelManager.cpp index 1f8d394ecf39..c212539e4930 100644 --- 
a/torch/nativert/executor/triton/CpuTritonKernelManager.cpp +++ b/torch/nativert/executor/triton/CpuTritonKernelManager.cpp @@ -1,5 +1,6 @@ -#include +#include +#include #include #ifndef _WIN32 @@ -35,6 +36,43 @@ char* _dlerror() { } // namespace +typedef void* kernel_ptr_t; +typedef void ( + *launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t); + +struct DlcloseDeleter { + void operator()(void* p) const { + if (p) { +#if defined(_WIN32) + TORCH_CHECK(false, "Windows is not supported"); +#else + dlclose(p); +#endif + } + } +}; + +class CpuTritonKernelManager final : public TritonKernelManager { + public: + CpuTritonKernelManager( + std::string kernel_name, + std::string kernel_bin_path, + std::string kernel_launcher_bin_path); + ~CpuTritonKernelManager() final = default; + void launch(const LaunchParams& launch_params, void** args) final; + + private: + void load(); + + kernel_ptr_t kernel_fn_{nullptr}; + launcher_ptr_t launcher_fn_{nullptr}; + + std::unique_ptr kernel_handle_{nullptr}; + std::unique_ptr launcher_handle_{nullptr}; + + std::string kernel_launcher_bin_path_; +}; + CpuTritonKernelManager::CpuTritonKernelManager( std::string kernel_name, std::string kernel_bin_path, @@ -88,4 +126,21 @@ void CpuTritonKernelManager::launch( kernel_fn_); } +namespace { +std::unique_ptr create_cpu_triton_kernel_manager( + std::string kernel_name, + std::string kernel_bin_path, + std::string kernel_launcher_bin_path) { + return std::make_unique( + std::move(kernel_name), + std::move(kernel_bin_path), + std::move(kernel_launcher_bin_path)); +} +} // namespace + +C10_REGISTER_TYPED_CREATOR( + TritonKernelManagerRegistry, + at::kCPU, + create_cpu_triton_kernel_manager) + } // namespace torch::nativert diff --git a/torch/nativert/executor/triton/CpuTritonKernelManager.h b/torch/nativert/executor/triton/CpuTritonKernelManager.h deleted file mode 100644 index 45b3327c878e..000000000000 --- a/torch/nativert/executor/triton/CpuTritonKernelManager.h +++ /dev/null @@ 
-1,51 +0,0 @@ -#pragma once - -#include - -#include -#include - -#ifndef _WIN32 -#include -#endif - -typedef void* kernel_ptr_t; -typedef void ( - *launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t); - -namespace torch::nativert { - -struct DlcloseDeleter { - void operator()(void* p) const { - if (p) { -#if defined(_WIN32) - TORCH_CHECK(false, "Windows is not supported"); -#else - dlclose(p); -#endif - } - } -}; - -class CpuTritonKernelManager final : public TritonKernelManager { - public: - CpuTritonKernelManager( - std::string kernel_name, - std::string kernel_bin_path, - std::string kernel_launcher_bin_path); - ~CpuTritonKernelManager() final = default; - void launch(const LaunchParams& launch_params, void** args) final; - - private: - void load(); - - kernel_ptr_t kernel_fn_{nullptr}; - launcher_ptr_t launcher_fn_{nullptr}; - - std::unique_ptr kernel_handle_{nullptr}; - std::unique_ptr launcher_handle_{nullptr}; - - std::string kernel_launcher_bin_path_; -}; - -} // namespace torch::nativert diff --git a/torch/nativert/executor/triton/CudaTritonKernelManager.cpp b/torch/nativert/executor/triton/CudaTritonKernelManager.cpp index 47f72ce0c5e3..d18efcc178f4 100644 --- a/torch/nativert/executor/triton/CudaTritonKernelManager.cpp +++ b/torch/nativert/executor/triton/CudaTritonKernelManager.cpp @@ -29,7 +29,7 @@ namespace torch::nativert { class CudaKernelInputs final : public KernelInputs { public: CudaKernelInputs(size_t num_args, size_t num_attrs) - : KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {}; + : KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {} ~CudaKernelInputs() final = default; void add_arg(void* arg) override { @@ -73,7 +73,7 @@ CudaTritonKernelManager::CudaTritonKernelManager( TORCH_CHECK( at::globalContext().hasCUDA() || at::globalContext().hasHIP(), "cuda or hip required"); -}; +} CudaTritonKernelManager::~CudaTritonKernelManager() { const auto& nvrtc = get_nvrtc(); @@ -137,19 +137,31 @@ void 
CudaTritonKernelManager::launch( nullptr)); } -static std::unique_ptr _create_cuda_triton_kernel_manager( +namespace { +std::unique_ptr create_cuda_triton_kernel_manager( std::string kernel_name, - std::string kernel_bin_path) { + std::string kernel_bin_path, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + [[maybe_unused]] std::string kernel_launcher_bin_path) { return std::make_unique( std::move(kernel_name), std::move(kernel_bin_path)); } +} // namespace + +#ifdef USE_ROCM + +C10_REGISTER_TYPED_CREATOR( + TritonKernelManagerRegistry, + at::kHIP, + create_cuda_triton_kernel_manager) + +#else + +C10_REGISTER_TYPED_CREATOR( + TritonKernelManagerRegistry, + at::kCUDA, + create_cuda_triton_kernel_manager) + +#endif // USE_ROCM } // namespace torch::nativert - -namespace { -static bool _initialized_cuda_triton_kernel_manager = []() { - torch::nativert::create_cuda_triton_kernel_manager = - &torch::nativert::_create_cuda_triton_kernel_manager; - return true; -}(); -} // namespace diff --git a/torch/nativert/executor/triton/TritonKernelManager.h b/torch/nativert/executor/triton/TritonKernelManager.h index ffa8e2573bc0..976fb3921f0a 100644 --- a/torch/nativert/executor/triton/TritonKernelManager.h +++ b/torch/nativert/executor/triton/TritonKernelManager.h @@ -2,7 +2,9 @@ #include +#include #include +#include namespace torch::nativert { @@ -69,7 +71,13 @@ class TritonKernelManager { std::string kernel_name_, kernel_bin_path_; }; -inline std::unique_ptr ( - *create_cuda_triton_kernel_manager)(std::string, std::string) = nullptr; +C10_DECLARE_TYPED_REGISTRY( + TritonKernelManagerRegistry, + c10::DeviceType, + TritonKernelManager, + std::unique_ptr, + std::string /* kernel_name */, + std::string /* kernel_bin_path */, + std::string /* kernel_launcher_bin_path */); } // namespace torch::nativert diff --git a/torch/nativert/kernels/TritonKernel.cpp b/torch/nativert/kernels/TritonKernel.cpp index 84fbf09a37f4..3843036aead9 100644 --- 
a/torch/nativert/kernels/TritonKernel.cpp +++ b/torch/nativert/kernels/TritonKernel.cpp @@ -16,10 +16,20 @@ #include #endif -#include - namespace torch::nativert { +// in this case, we want to use the symbol from torch_cpu.dll +#ifndef NATIVERT_MSVC_TEST +C10_DEFINE_TYPED_REGISTRY( + TritonKernelManagerRegistry, + c10::DeviceType, + TritonKernelManager, + std::unique_ptr, + std::string /* kernel_name */, + std::string /* kernel_bin_path */, + std::string /* kernel_launcher_bin_path */) +#endif + TritonKernel::TritonKernel( const Node* node, caffe2::serialize::PyTorchStreamReader* reader) @@ -74,27 +84,28 @@ TritonKernel::TritonKernel( auto tmp_dir = extractToTemporaryFolder(*reader, kernel_prefix) + "/"; if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) { + loader_ = TritonKernelManagerRegistry()->Create( + at::kCUDA, kernel_name, tmp_dir + kernel_name + ".cubin", ""); TORCH_CHECK( - create_cuda_triton_kernel_manager != nullptr, + loader_ != nullptr, "couldn't find cuda loader -- is this a gpu build?"); - loader_ = create_cuda_triton_kernel_manager( - kernel_name, tmp_dir + kernel_name + ".cubin"); - } - - if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) { + } else if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) { + loader_ = TritonKernelManagerRegistry()->Create( + at::kHIP, kernel_name, tmp_dir + kernel_name + ".hsaco", ""); TORCH_CHECK( - create_cuda_triton_kernel_manager != nullptr, + loader_ != nullptr, "couldn't find cuda loader -- is this a gpu build?"); - loader_ = create_cuda_triton_kernel_manager( - kernel_name, tmp_dir + kernel_name + ".hsaco"); - } - - if (loader_ == nullptr) { - loader_ = std::unique_ptr(new CpuTritonKernelManager( + } else { + loader_ = TritonKernelManagerRegistry()->Create( + at::kCPU, kernel_name, tmp_dir + kernel_name + ".so", - tmp_dir + kernel_name + ".launcher.so")); + tmp_dir + kernel_name + ".launcher.so"); } + + TORCH_CHECK( + loader_ != nullptr, + "couldn't find 
triton kernel loader -- are you trying to run gpu kernels on a cpu build?"); } TritonKernel::~TritonKernel() = default;