[nativert][triton] improve hardware registration (#162499)

Summary: As titled. Replace the ad-hoc create_cuda_triton_kernel_manager function pointer with a typed c10 registry (TritonKernelManagerRegistry) keyed on c10::DeviceType. The CPU, CUDA, and HIP Triton kernel managers now register their own creators, and TritonKernel looks up the right one by device type instead of special-casing CUDA.
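
At a glance, the change swaps the old function-pointer hook for a typed registry; both snippets below are condensed verbatim from the TritonKernelManager.h diff in this commit:

// before: a bare function pointer, assigned by a static initializer in CUDA builds
inline std::unique_ptr<TritonKernelManager> (
    *create_cuda_triton_kernel_manager)(std::string, std::string) = nullptr;

// after: a typed c10 registry keyed on device type; creators take the kernel
// name, the kernel binary path, and a (CPU-only) launcher binary path
C10_DECLARE_TYPED_REGISTRY(
    TritonKernelManagerRegistry,
    c10::DeviceType,
    TritonKernelManager,
    std::unique_ptr,
    std::string /* kernel_name */,
    std::string /* kernel_bin_path */,
    std::string /* kernel_launcher_bin_path */);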

Test Plan:
CI

Rollback Plan:

Differential Revision: D82031814

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162499
Approved by: https://github.com/angelayi
dolpm
2025-09-10 04:52:57 +00:00
committed by PyTorch MergeBot
parent 96ef26f71a
commit 1c16c18a53
8 changed files with 142 additions and 88 deletions

View File

@@ -552,6 +552,11 @@ if(USE_CUDA OR USE_ROCM)
append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS)
endif()
if(USE_CUDA)
# eventually do rocm
append_filelist("libtorch_nativert_cuda_sources" Caffe2_GPU_SRCS)
endif()
if(USE_CUDA)
list(APPEND Caffe2_GPU_CU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS})
add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})

View File

@@ -40,21 +40,24 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
)
if(USE_CUDA)
list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
endif(MSVC)
endif()
add_executable(test_nativert
${TORCH_ROOT}/test/cpp/common/main.cpp
${NATIVERT_TEST_SRCS}
)
if(MSVC)
target_compile_definitions(test_nativert PRIVATE NATIVERT_MSVC_TEST)
endif()
# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_nativert PRIVATE USE_GTEST)

View File

@@ -6,9 +6,20 @@ using namespace ::testing;
using namespace torch::nativert;
TEST(TritonKernelManagerRegistrationTests, TestRegister) {
#ifndef USE_CUDA
EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr);
EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCPU));
#ifdef USE_CUDA
#ifdef USE_ROCM
EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kHIP));
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
#else
EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr);
EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCUDA));
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
#endif // USE_ROCM
#else
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
#endif // USE_CUDA
}

View File

@@ -1,5 +1,6 @@
#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
#include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <c10/util/FbcodeMaps.h>
#include <c10/util/Logging.h>
#ifndef _WIN32
@@ -35,6 +36,43 @@ char* _dlerror() {
} // namespace
typedef void* kernel_ptr_t;
typedef void (
*launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);
struct DlcloseDeleter {
void operator()(void* p) const {
if (p) {
#if defined(_WIN32)
TORCH_CHECK(false, "Windows is not supported");
#else
dlclose(p);
#endif
}
}
};
class CpuTritonKernelManager final : public TritonKernelManager {
public:
CpuTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path);
~CpuTritonKernelManager() final = default;
void launch(const LaunchParams& launch_params, void** args) final;
private:
void load();
kernel_ptr_t kernel_fn_{nullptr};
launcher_ptr_t launcher_fn_{nullptr};
std::unique_ptr<void, DlcloseDeleter> kernel_handle_{nullptr};
std::unique_ptr<void, DlcloseDeleter> launcher_handle_{nullptr};
std::string kernel_launcher_bin_path_;
};
CpuTritonKernelManager::CpuTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path,
@@ -88,4 +126,21 @@ void CpuTritonKernelManager::launch(
kernel_fn_);
}
namespace {
std::unique_ptr<TritonKernelManager> create_cpu_triton_kernel_manager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path) {
return std::make_unique<CpuTritonKernelManager>(
std::move(kernel_name),
std::move(kernel_bin_path),
std::move(kernel_launcher_bin_path));
}
} // namespace
C10_REGISTER_TYPED_CREATOR(
TritonKernelManagerRegistry,
at::kCPU,
create_cpu_triton_kernel_manager)
} // namespace torch::nativert
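
With the registry in place, adding a backend is self-contained: define a creator and register it against a device type, exactly as the CPU creator above does. A minimal sketch under assumed names — MyTritonKernelManager and the at::kXPU slot are illustrative, not part of this commit:

namespace {
// hypothetical backend: any TritonKernelManager subclass works here
std::unique_ptr<TritonKernelManager> create_my_triton_kernel_manager(
    std::string kernel_name,
    std::string kernel_bin_path,
    [[maybe_unused]] std::string kernel_launcher_bin_path) {
  return std::make_unique<MyTritonKernelManager>( // illustrative class
      std::move(kernel_name), std::move(kernel_bin_path));
}
} // namespace

// claims the (illustrative) kXPU slot in the same registry used above
C10_REGISTER_TYPED_CREATOR(
    TritonKernelManagerRegistry,
    at::kXPU,
    create_my_triton_kernel_manager)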

View File

@ -1,51 +0,0 @@
#pragma once
#include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <c10/core/Device.h>
#include <c10/util/FbcodeMaps.h>
#ifndef _WIN32
#include <dlfcn.h>
#endif
typedef void* kernel_ptr_t;
typedef void (
*launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);
namespace torch::nativert {
struct DlcloseDeleter {
void operator()(void* p) const {
if (p) {
#if defined(_WIN32)
TORCH_CHECK(false, "Windows is not supported");
#else
dlclose(p);
#endif
}
}
};
class CpuTritonKernelManager final : public TritonKernelManager {
public:
CpuTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path);
~CpuTritonKernelManager() final = default;
void launch(const LaunchParams& launch_params, void** args) final;
private:
void load();
kernel_ptr_t kernel_fn_{nullptr};
launcher_ptr_t launcher_fn_{nullptr};
std::unique_ptr<void, DlcloseDeleter> kernel_handle_{nullptr};
std::unique_ptr<void, DlcloseDeleter> launcher_handle_{nullptr};
std::string kernel_launcher_bin_path_;
};
} // namespace torch::nativert

View File

@@ -29,7 +29,7 @@ namespace torch::nativert {
class CudaKernelInputs final : public KernelInputs {
public:
CudaKernelInputs(size_t num_args, size_t num_attrs)
: KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {};
: KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {}
~CudaKernelInputs() final = default;
void add_arg(void* arg) override {
@@ -73,7 +73,7 @@ CudaTritonKernelManager::CudaTritonKernelManager(
TORCH_CHECK(
at::globalContext().hasCUDA() || at::globalContext().hasHIP(),
"cuda or hip required");
};
}
CudaTritonKernelManager::~CudaTritonKernelManager() {
const auto& nvrtc = get_nvrtc();
@@ -137,19 +137,31 @@ void CudaTritonKernelManager::launch(
nullptr));
}
static std::unique_ptr<TritonKernelManager> _create_cuda_triton_kernel_manager(
namespace {
std::unique_ptr<TritonKernelManager> create_cuda_triton_kernel_manager(
std::string kernel_name,
std::string kernel_bin_path) {
std::string kernel_bin_path,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
[[maybe_unused]] std::string kernel_launcher_bin_path) {
return std::make_unique<CudaTritonKernelManager>(
std::move(kernel_name), std::move(kernel_bin_path));
}
} // namespace
#ifdef USE_ROCM
C10_REGISTER_TYPED_CREATOR(
TritonKernelManagerRegistry,
at::kHIP,
create_cuda_triton_kernel_manager)
#else
C10_REGISTER_TYPED_CREATOR(
TritonKernelManagerRegistry,
at::kCUDA,
create_cuda_triton_kernel_manager)
#endif // USE_ROCM
} // namespace torch::nativert
namespace {
static bool _initialized_cuda_triton_kernel_manager = []() {
torch::nativert::create_cuda_triton_kernel_manager =
&torch::nativert::_create_cuda_triton_kernel_manager;
return true;
}();
} // namespace

View File

@@ -2,7 +2,9 @@
#include <string>
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
namespace torch::nativert {
@@ -69,7 +71,13 @@ class TritonKernelManager {
std::string kernel_name_, kernel_bin_path_;
};
inline std::unique_ptr<TritonKernelManager> (
*create_cuda_triton_kernel_manager)(std::string, std::string) = nullptr;
C10_DECLARE_TYPED_REGISTRY(
TritonKernelManagerRegistry,
c10::DeviceType,
TritonKernelManager,
std::unique_ptr,
std::string /* kernel_name */,
std::string /* kernel_bin_path */,
std::string /* kernel_launcher_bin_path */);
} // namespace torch::nativert
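
Callers go through the registry instead of the removed function pointer; Create returns nullptr when no creator is registered for the device type, which is what the null checks in the TritonKernel.cpp diff below rely on. A minimal lookup sketch, where kernel_name and bin_path stand in for real values:

// nullptr here means no CUDA creator was linked in (e.g. a CPU-only build)
std::unique_ptr<TritonKernelManager> loader =
    TritonKernelManagerRegistry()->Create(
        at::kCUDA, kernel_name, bin_path, /*kernel_launcher_bin_path=*/"");
TORCH_CHECK(loader != nullptr, "no Triton kernel manager registered for CUDA");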

View File

@@ -16,10 +16,20 @@
#include <ATen/ops/empty.h>
#endif
#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
namespace torch::nativert {
// in this case, we want to use the symbol from torch_cpu.dll
#ifndef NATIVERT_MSVC_TEST
C10_DEFINE_TYPED_REGISTRY(
TritonKernelManagerRegistry,
c10::DeviceType,
TritonKernelManager,
std::unique_ptr,
std::string /* kernel_name */,
std::string /* kernel_bin_path */,
std::string /* kernel_launcher_bin_path */)
#endif
TritonKernel::TritonKernel(
const Node* node,
caffe2::serialize::PyTorchStreamReader* reader)
@@ -74,27 +84,28 @@ TritonKernel::TritonKernel(
auto tmp_dir = extractToTemporaryFolder(*reader, kernel_prefix) + "/";
if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) {
loader_ = TritonKernelManagerRegistry()->Create(
at::kCUDA, kernel_name, tmp_dir + kernel_name + ".cubin", "");
TORCH_CHECK(
create_cuda_triton_kernel_manager != nullptr,
loader_ != nullptr,
"couldn't find cuda loader -- is this a gpu build?");
loader_ = create_cuda_triton_kernel_manager(
kernel_name, tmp_dir + kernel_name + ".cubin");
}
if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) {
} else if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) {
loader_ = TritonKernelManagerRegistry()->Create(
at::kHIP, kernel_name, tmp_dir + kernel_name + ".hsaco", "");
TORCH_CHECK(
create_cuda_triton_kernel_manager != nullptr,
loader_ != nullptr,
"couldn't find cuda loader -- is this a gpu build?");
loader_ = create_cuda_triton_kernel_manager(
kernel_name, tmp_dir + kernel_name + ".hsaco");
}
if (loader_ == nullptr) {
loader_ = std::unique_ptr<TritonKernelManager>(new CpuTritonKernelManager(
} else {
loader_ = TritonKernelManagerRegistry()->Create(
at::kCPU,
kernel_name,
tmp_dir + kernel_name + ".so",
tmp_dir + kernel_name + ".launcher.so"));
tmp_dir + kernel_name + ".launcher.so");
}
TORCH_CHECK(
loader_ != nullptr,
"couldn't find triton kernel loader -- are you trying to run gpu kernels on a cpu build?");
}
TritonKernel::~TritonKernel() = default;