From 3dde5d7f9bf80dd6623a712bc429e9e4302464b5 Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Thu, 4 Sep 2025 19:00:11 +0000 Subject: [PATCH] [nativert] triton runtime implementation (#161798) Summary: att Test Plan: ci Rollback Plan: Reviewed By: minjang Differential Revision: D80828148 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161798 Approved by: https://github.com/minjang, https://github.com/SherlockNoMad --- aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h | 2 + build_variables.bzl | 8 +- test/cpp/nativert/CMakeLists.txt | 8 + ...est_triton_kernel_manager_registration.cpp | 14 ++ torch/nativert/executor/OpKernelKind.h | 1 + .../triton/CpuTritonKernelManager.cpp | 91 ++++++++++ .../executor/triton/CpuTritonKernelManager.h | 51 ++++++ .../triton/CudaTritonKernelManager.cpp | 155 ++++++++++++++++++ .../executor/triton/TritonKernelManager.h | 75 +++++++++ torch/nativert/kernels/KernelFactory.cpp | 6 + torch/nativert/kernels/TritonKernel.cpp | 137 ++++++++++++++++ torch/nativert/kernels/TritonKernel.h | 31 ++++ 12 files changed, 578 insertions(+), 1 deletion(-) create mode 100644 test/cpp/nativert/test_triton_kernel_manager_registration.cpp create mode 100644 torch/nativert/executor/triton/CpuTritonKernelManager.cpp create mode 100644 torch/nativert/executor/triton/CpuTritonKernelManager.h create mode 100644 torch/nativert/executor/triton/CudaTritonKernelManager.cpp create mode 100644 torch/nativert/executor/triton/TritonKernelManager.h create mode 100644 torch/nativert/kernels/TritonKernel.cpp create mode 100644 torch/nativert/kernels/TritonKernel.h diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index d89875865b88..aca83386ad42 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -117,6 +117,8 @@ namespace at::cuda { _(nvrtcGetPTXSize) \ _(nvrtcGetPTX) \ _(cuModuleLoadData) \ + _(cuModuleLoad) \ + _(cuGetErrorString) \ _(cuModuleGetFunction) \ _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR) \ _(nvrtcGetErrorString) \ diff --git a/build_variables.bzl b/build_variables.bzl index fd53c9e8aa12..990385da2362 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -635,6 +635,12 @@ libtorch_nativert_sources = [ "torch/nativert/graph/passes/pass_manager/GraphPasses.cpp", "torch/nativert/graph/passes/pass_manager/PassManager.cpp", "torch/nativert/kernels/KernelHandlerRegistry.cpp", + "torch/nativert/kernels/TritonKernel.cpp", + "torch/nativert/executor/triton/CpuTritonKernelManager.cpp", +] + +libtorch_nativert_cuda_sources = [ + "torch/nativert/executor/triton/CudaTritonKernelManager.cpp", ] torch_mobile_tracer_sources = [ @@ -770,7 +776,7 @@ libtorch_cuda_distributed_sources = libtorch_cuda_distributed_base_sources + lib libtorch_cuda_sources = libtorch_cuda_core_sources + libtorch_cuda_distributed_sources + [ "torch/csrc/cuda/nccl.cpp", -] +] + libtorch_nativert_cuda_sources torch_cpp_srcs = [ "torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index 1b7024f75488..1b4752ed9089 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -40,8 +40,16 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp ${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp + ${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp + ${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp + ${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp ) +if(USE_CUDA) + list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp) +endif(MSVC) + + add_executable(test_nativert ${TORCH_ROOT}/test/cpp/common/main.cpp ${NATIVERT_TEST_SRCS} diff --git a/test/cpp/nativert/test_triton_kernel_manager_registration.cpp b/test/cpp/nativert/test_triton_kernel_manager_registration.cpp new file mode 100644 index 000000000000..ca864158e312 --- /dev/null +++ b/test/cpp/nativert/test_triton_kernel_manager_registration.cpp @@ -0,0 +1,14 @@ +#include + +#include + +using namespace ::testing; +using namespace torch::nativert; + +TEST(TritonKernelManagerRegistrationTests, TestRegister) { +#ifndef USE_CUDA + EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr); +#else + EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr); +#endif // USE_CUDA +} diff --git a/torch/nativert/executor/OpKernelKind.h b/torch/nativert/executor/OpKernelKind.h index 045664cfdee1..5a8ba38316f6 100644 --- a/torch/nativert/executor/OpKernelKind.h +++ b/torch/nativert/executor/OpKernelKind.h @@ -11,6 +11,7 @@ enum class OpKernelKind : uint8_t { // static dispatch kernels that don't reuse // out TensorImpl kNativeStaticDispatchKernel, + kTritonKernel, }; } // namespace torch::nativert diff --git a/torch/nativert/executor/triton/CpuTritonKernelManager.cpp b/torch/nativert/executor/triton/CpuTritonKernelManager.cpp new file mode 100644 index 000000000000..1f8d394ecf39 --- /dev/null +++ b/torch/nativert/executor/triton/CpuTritonKernelManager.cpp @@ -0,0 +1,91 @@ +#include + +#include + +#ifndef _WIN32 +#include +#endif // _WIN32 + +namespace torch::nativert { + +namespace { +void* _dlopen(const char* filename) { +#if defined(_WIN32) + return nullptr; +#else + return dlopen(filename, RTLD_NOW | RTLD_LOCAL); +#endif +} + +void* _dlsym(void* handle, const char* name) { +#if defined(_WIN32) + return nullptr; +#else + return dlsym(handle, name); +#endif +} + +char* _dlerror() { +#if defined(_WIN32) + throw std::runtime_error("dlerror not supported on Windows"); +#else + return dlerror(); +#endif +} + +} // namespace + +CpuTritonKernelManager::CpuTritonKernelManager( + std::string kernel_name, + std::string kernel_bin_path, + std::string kernel_launcher_bin_path) + : TritonKernelManager(std::move(kernel_name), std::move(kernel_bin_path)), + kernel_launcher_bin_path_(std::move(kernel_launcher_bin_path)) {} + +void CpuTritonKernelManager::load() { + if (C10_LIKELY(kernel_fn_ != nullptr)) { + return; + } + + kernel_handle_.reset(_dlopen(kernel_bin_path_.c_str())); + TORCH_CHECK( + kernel_handle_ != nullptr, + "could not dlopen ", + kernel_bin_path_, + ": ", + _dlerror()); + + launcher_handle_.reset(_dlopen(kernel_launcher_bin_path_.c_str())); + TORCH_CHECK( + launcher_handle_ != nullptr, + "could not dlopen ", + kernel_launcher_bin_path_, + ": ", + _dlerror()); + + kernel_fn_ = _dlsym(kernel_handle_.get(), kernel_name_.c_str()); + TORCH_CHECK( + kernel_fn_ != nullptr, + "could not dlsym ", + kernel_name_, + ": ", + _dlerror()); + + launcher_fn_ = + reinterpret_cast(_dlsym(launcher_handle_.get(), "run")); + TORCH_CHECK(launcher_fn_ != nullptr, "could not dlsym run: ", _dlerror()); +} + +void CpuTritonKernelManager::launch( + const LaunchParams& launch_params, + void** args /* { ...inputs, output }*/) { + load(); + launcher_fn_( + launch_params.grid_dims.x, + launch_params.grid_dims.y, + launch_params.grid_dims.z, + args, + kernel_fn_); +} + +} // namespace torch::nativert diff --git a/torch/nativert/executor/triton/CpuTritonKernelManager.h b/torch/nativert/executor/triton/CpuTritonKernelManager.h new file mode 100644 index 000000000000..6eff0a6fd0d0 --- /dev/null +++ b/torch/nativert/executor/triton/CpuTritonKernelManager.h @@ -0,0 +1,51 @@ +#pragma once + +#include + +#include +#include + +#ifndef _WIN32 +#include +#endif + +typedef void* kernel_ptr_t; +typedef void ( + *launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t); + +namespace torch::nativert { + +struct DlcloseDeleter { + void operator()(void* p) const { + if (p) { +#if defined(_WIN32) + TORCH_CHECK(false, "Windows is not supported"); +#else + dlclose(p); +#endif + } + } +}; + +class CpuTritonKernelManager final : public TritonKernelManager { + public: + CpuTritonKernelManager( + std::string kernel_name, + std::string kernel_bin_path, + std::string kernel_launcher_bin_path); + ~CpuTritonKernelManager() final {} + void launch(const LaunchParams& launch_params, void** args) final; + + private: + void load(); + + kernel_ptr_t kernel_fn_{nullptr}; + launcher_ptr_t launcher_fn_{nullptr}; + + std::unique_ptr kernel_handle_{nullptr}; + std::unique_ptr launcher_handle_{nullptr}; + + std::string kernel_launcher_bin_path_; +}; + +} // namespace torch::nativert diff --git a/torch/nativert/executor/triton/CudaTritonKernelManager.cpp b/torch/nativert/executor/triton/CudaTritonKernelManager.cpp new file mode 100644 index 000000000000..9bacb5a82269 --- /dev/null +++ b/torch/nativert/executor/triton/CudaTritonKernelManager.cpp @@ -0,0 +1,155 @@ +#include + +#include +#include +#include +#include + +#include +#include + +namespace { +const at::cuda::NVRTC& get_nvrtc() { + return at::globalContext().getNVRTC(); +} +} // namespace + +#define CU_LOG_ERROR(fn, result, ...) \ + { \ + LOG(ERROR) << #fn << " returned error: " << result; \ + const char* errMsg = nullptr; \ + get_nvrtc().cuGetErrorString(result, &errMsg); \ + LOG(ERROR) << "cuGetErrorString: " << errMsg; \ + } + +namespace torch::nativert { + +// cuda kernels require an extra level of indirection +// for who knows what reason. +class CudaKernelInputs final : public KernelInputs { + public: + CudaKernelInputs(size_t num_args, size_t num_attrs) + : KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {}; + ~CudaKernelInputs() final = default; + + void add_arg(void* arg) override { + TORCH_CHECK(arg_idx_ < num_args_, "Too many args"); + arg_ptrs_[arg_idx_] = arg; + inputs_[arg_idx_] = reinterpret_cast(&arg_ptrs_[arg_idx_]); + arg_idx_++; + } + + private: + std::vector arg_ptrs_; +}; + +class CudaTritonKernelManager final : public TritonKernelManager { + public: + CudaTritonKernelManager(std::string kernel_name, std::string kernel_bin_path); + ~CudaTritonKernelManager() final; + + CudaTritonKernelManager(const CudaTritonKernelManager& other); + CudaTritonKernelManager& operator=(const CudaTritonKernelManager& other); + CudaTritonKernelManager(CudaTritonKernelManager&& other) noexcept; + CudaTritonKernelManager& operator=(CudaTritonKernelManager&& other) noexcept; + + void launch(const LaunchParams& launch_params, void** args) final; + std::unique_ptr create_inputs(size_t num_args, size_t num_attrs) + const final { + return std::unique_ptr( + new CudaKernelInputs(num_args, num_attrs)); + } + + private: + CUfunction load(); + c10::FastMap cache_; + std::vector loaded_modules_; +}; + +CudaTritonKernelManager::CudaTritonKernelManager( + std::string kernel_name, + std::string kernel_bin_path) + : TritonKernelManager(std::move(kernel_name), std::move(kernel_bin_path)) { + TORCH_CHECK( + at::globalContext().hasCUDA() || at::globalContext().hasHIP(), + "cuda or hip required"); +}; + +CudaTritonKernelManager::~CudaTritonKernelManager() { + const auto& nvrtc = get_nvrtc(); + for (auto& mod : loaded_modules_) { + if (CUresult err = nvrtc.cuModuleUnload(mod); err != 0) { + CU_LOG_ERROR(nvrtc.cuModuleUnload, err); + } + } +} + +CUfunction CudaTritonKernelManager::load() { + const auto idx = c10::cuda::current_device(); + if (const auto res = cache_.find(idx); res != cache_.end()) { + return res->second; + } + + const auto& nvrtc = get_nvrtc(); + + CUmodule mod_ptr = nullptr; + + if (CUresult err = nvrtc.cuModuleLoad(&mod_ptr, kernel_bin_path_.c_str()); + err != 0) { + CU_LOG_ERROR(nvrtc.cuModuleLoad, err); + return nullptr; + } + + CUfunction func = nullptr; + + if (CUresult err = + nvrtc.cuModuleGetFunction(&func, mod_ptr, kernel_name_.c_str()); + err != 0) { + CU_LOG_ERROR(nvrtc.cuModuleGetFunction, err); + return nullptr; + } + + loaded_modules_.emplace_back(mod_ptr); + return cache_.emplace(idx, func).first->second; +} + +void CudaTritonKernelManager::launch( + const LaunchParams& launch_params, + void** args /* { ...inputs, output }*/) { + const constexpr int kThreadsPerWarp = 2 << 4; + + auto kernel_fn = load(); + TORCH_CHECK( + kernel_fn != nullptr, "failed to load triton kernel: ", kernel_name_); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream(); + + AT_CUDA_DRIVER_CHECK(get_nvrtc().cuLaunchKernel( + kernel_fn, + launch_params.grid_dims.x, + launch_params.grid_dims.y, + launch_params.grid_dims.z, + /* blockDimX = */ kThreadsPerWarp * launch_params.num_warps, + /* blockDimY = */ 1, + /* blockDimZ = */ 1, + /* sharedMemBytes = */ launch_params.shared_memory_bytes, + stream, + args, + nullptr)); +} + +static std::unique_ptr _create_cuda_triton_kernel_manager( + std::string kernel_name, + std::string kernel_bin_path) { + return std::unique_ptr(new CudaTritonKernelManager( + std::move(kernel_name), std::move(kernel_bin_path))); +} + +} // namespace torch::nativert + +namespace { +static bool _initialized_cuda_triton_kernel_manager = []() { + torch::nativert::create_cuda_triton_kernel_manager = + &torch::nativert::_create_cuda_triton_kernel_manager; + return true; +}(); +} // namespace diff --git a/torch/nativert/executor/triton/TritonKernelManager.h b/torch/nativert/executor/triton/TritonKernelManager.h new file mode 100644 index 000000000000..ffa8e2573bc0 --- /dev/null +++ b/torch/nativert/executor/triton/TritonKernelManager.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +#include + +namespace torch::nativert { + +struct GridDims { + public: + GridDims(int x = 1, int y = 1, int z = 1) : x(x), y(y), z(z) {} + int x; + int y; + int z; +}; + +struct LaunchParams { + int num_warps = 4; + int shared_memory_bytes = 0; + GridDims grid_dims; +}; + +class KernelInputs { + public: + KernelInputs(size_t num_args, size_t num_attrs) + : num_args_(num_args), + inputs_(num_args + num_attrs), + num_attrs_(num_attrs) {} + virtual ~KernelInputs() = default; + + virtual void add_arg(void* arg) { + TORCH_CHECK(arg_idx_ < num_args_, "Too many args"); + inputs_[arg_idx_++] = arg; + } + + void add_attribute(void* attr) { + TORCH_CHECK(attr_idx_ < num_attrs_, "Too many attributes"); + inputs_[num_args_ + attr_idx_++] = attr; + } + + void** as_void() { + return inputs_.data(); + } + + protected: + size_t num_args_; + size_t arg_idx_ = 0; + std::vector inputs_; + + private: + size_t num_attrs_; + size_t attr_idx_ = 0; +}; + +class TritonKernelManager { + public: + TritonKernelManager(std::string kernel_name, std::string kernel_bin_path) + : kernel_name_(std::move(kernel_name)), + kernel_bin_path_(std::move(kernel_bin_path)) {} + virtual ~TritonKernelManager() = default; + virtual std::unique_ptr create_inputs( + size_t num_args, + size_t num_attrs) const { + return std::make_unique(num_args, num_attrs); + } + virtual void launch(const LaunchParams& launch_params, void** args) = 0; + + protected: + std::string kernel_name_, kernel_bin_path_; +}; + +inline std::unique_ptr ( + *create_cuda_triton_kernel_manager)(std::string, std::string) = nullptr; + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp index 9e31a93a58c8..3fc4f2bcdc53 100644 --- a/torch/nativert/kernels/KernelFactory.cpp +++ b/torch/nativert/kernels/KernelFactory.cpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace torch::nativert { @@ -130,6 +131,11 @@ ExecutionKernels KernelFactory::initializeNodeKernels( } else if (c10::starts_with( node.target(), "torch.ops.higher_order.call_torchbind")) { nodeKernels.push_back(std::make_unique(&node)); + } else if (c10::starts_with( + node.target(), + "torch.ops.higher_order.triton_kernel_wrapper_functional")) { + nodeKernels.push_back( + std::make_unique(&node, pytorchStreamReader.get())); } else if ( c10::starts_with( node.target(), diff --git a/torch/nativert/kernels/TritonKernel.cpp b/torch/nativert/kernels/TritonKernel.cpp new file mode 100644 index 000000000000..84fbf09a37f4 --- /dev/null +++ b/torch/nativert/kernels/TritonKernel.cpp @@ -0,0 +1,137 @@ +#include + +#include + +#include +#include + +#include +#include + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include + +namespace torch::nativert { + +TritonKernel::TritonKernel( + const Node* node, + caffe2::serialize::PyTorchStreamReader* reader) + : OpKernel(node, OpKernelKind::kTritonKernel) { + TORCH_CHECK(reader != nullptr, "reader is null"); + + std::string kernel_name{}; + bool found_grid = false; + for (const auto& attr : node_->attributes()) { + if (attr.name.empty()) { + attr_ptrs_.emplace_back(std::visit( + [](auto&& arg) -> void* { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return nullptr; + } + return static_cast(const_cast(&arg)); + }, + attr.value)); + } else if (attr.name == "name") { + kernel_name = std::get(attr.value); + } else if (attr.name == "grid") { + found_grid = true; + auto grid = std::get>(attr.value); + TORCH_CHECK(grid.size() == 3, "grid must be a 3D vector"); + launch_params_.grid_dims = GridDims( + static_cast(grid[0]), + static_cast(grid[1]), + static_cast(grid[2])); + } else if (attr.name == "num_warps") { + if (const int num_warps = static_cast(std::get(attr.value)); + num_warps > 0) { + launch_params_.num_warps = num_warps; + } + } else if (attr.name == "shared_memory_bytes") { + if (const int shared_memory_bytes = + static_cast(std::get(attr.value)); + shared_memory_bytes > 0) { + launch_params_.shared_memory_bytes = shared_memory_bytes; + } + } else if (attr.name == "output_indices") { + output_indices_ = std::get>(attr.value); + } + } + + TORCH_CHECK(!kernel_name.empty(), "kernel name not found"); + TORCH_CHECK(found_grid, "grid attribute not found"); + TORCH_CHECK(!output_indices_.empty(), "output_indices attribute not found"); + + auto kernel_prefix = std::string("data/triton") + "/" + kernel_name; + + auto tmp_dir = extractToTemporaryFolder(*reader, kernel_prefix) + "/"; + + if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) { + TORCH_CHECK( + create_cuda_triton_kernel_manager != nullptr, + "couldn't find cuda loader -- is this a gpu build?"); + loader_ = create_cuda_triton_kernel_manager( + kernel_name, tmp_dir + kernel_name + ".cubin"); + } + + if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) { + TORCH_CHECK( + create_cuda_triton_kernel_manager != nullptr, + "couldn't find cuda loader -- is this a gpu build?"); + loader_ = create_cuda_triton_kernel_manager( + kernel_name, tmp_dir + kernel_name + ".hsaco"); + } + + if (loader_ == nullptr) { + loader_ = std::unique_ptr(new CpuTritonKernelManager( + kernel_name, + tmp_dir + kernel_name + ".so", + tmp_dir + kernel_name + ".launcher.so")); + } +} + +TritonKernel::~TritonKernel() = default; + +void TritonKernel::computeInternal(ExecutionFrame& executionFrame) const { + const auto num_inputs = node_->inputs().size(); + const auto num_attrs = attr_ptrs_.size(); + + auto* loader = const_cast(loader_.get()); + + auto inputs = loader->create_inputs(num_inputs, num_attrs); + + for (const auto i : c10::irange(num_inputs)) { + inputs->add_arg(input(i, executionFrame).toTensor().data_ptr()); + } + + for (const auto i : c10::irange(num_attrs)) { + inputs->add_attribute(attr_ptrs_[i]); + } + + loader->launch(launch_params_, inputs->as_void()); + + auto& out = output(0, executionFrame); + if (out.isNone()) { + auto list = c10::List(); + for (const auto& i : output_indices_) { + list.emplace_back(input(i, executionFrame).toTensor()); + } + out = c10::IValue(std::move(list)); + return; + } + + // todo: check if this is redundant + auto out_t = out.toTensorList(); + for (const auto& i : output_indices_) { + out_t[i] = input(i, executionFrame).toTensor(); + } +} + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/TritonKernel.h b/torch/nativert/kernels/TritonKernel.h new file mode 100644 index 000000000000..4f9f0e47b00c --- /dev/null +++ b/torch/nativert/kernels/TritonKernel.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace torch::nativert { + +class TritonKernel : public OpKernel { + public: + TritonKernel() = delete; + TritonKernel( + const Node* node, + caffe2::serialize::PyTorchStreamReader* reader); + ~TritonKernel() override; + + void computeInternal(ExecutionFrame& executionFrame) const override; + + private: + std::unique_ptr loader_; + + // unnamed node attributes will be passed as arguments to the kernel + std::vector attr_ptrs_; + std::vector output_indices_; + LaunchParams launch_params_; +}; + +} // namespace torch::nativert