re-land "triton runtime implementation" (#162217)

Summary: original pr - https://github.com/pytorch/pytorch/pull/161798

Test Plan:
ci

Rollback Plan:

Differential Revision: D81724234

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162217
Approved by: https://github.com/SherlockNoMad
dolpm
2025-09-06 00:52:29 +00:00
committed by PyTorch MergeBot
parent 1463714833
commit 4f72d932fe
13 changed files with 579 additions and 1 deletion

View File: aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h

@@ -117,6 +117,8 @@ namespace at::cuda {
_(nvrtcGetPTXSize) \
_(nvrtcGetPTX) \
_(cuModuleLoadData) \
_(cuModuleLoad) \
_(cuGetErrorString) \
_(cuModuleGetFunction) \
_(cuOccupancyMaxActiveBlocksPerMultiprocessor) \
_(nvrtcGetErrorString) \

View File: build_variables.bzl

@@ -636,6 +636,12 @@ libtorch_nativert_sources = [
"torch/nativert/graph/passes/pass_manager/GraphPasses.cpp",
"torch/nativert/graph/passes/pass_manager/PassManager.cpp",
"torch/nativert/kernels/KernelHandlerRegistry.cpp",
"torch/nativert/kernels/TritonKernel.cpp",
"torch/nativert/executor/triton/CpuTritonKernelManager.cpp",
]
libtorch_nativert_cuda_sources = [
"torch/nativert/executor/triton/CudaTritonKernelManager.cpp",
]
torch_mobile_tracer_sources = [
@@ -771,7 +777,7 @@ libtorch_cuda_distributed_sources = libtorch_cuda_distributed_base_sources + lib
libtorch_cuda_sources = libtorch_cuda_core_sources + libtorch_cuda_distributed_sources + [
"torch/csrc/cuda/nccl.cpp",
-]
+] + libtorch_nativert_cuda_sources
torch_cpp_srcs = [
"torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA

View File: test/cpp/nativert/CMakeLists.txt

@@ -40,8 +40,16 @@ set(NATIVERT_TEST_SRCS
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
)
if(USE_CUDA)
list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
endif()
add_executable(test_nativert
${TORCH_ROOT}/test/cpp/common/main.cpp
${NATIVERT_TEST_SRCS}

View File

@@ -0,0 +1,14 @@
#include <gtest/gtest.h>
#include <torch/nativert/kernels/TritonKernel.h>
using namespace ::testing;
using namespace torch::nativert;
TEST(TritonKernelManagerRegistrationTests, TestRegister) {
#ifndef USE_CUDA
EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr);
#else
EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr);
#endif // USE_CUDA
}

View File: torch/nativert/executor/DelegateExecutor.cpp

@@ -28,6 +28,7 @@ char* _mkdtemp(char* outputDir) {
std::string extractToTemporaryFolder(
caffe2::serialize::PyTorchStreamReader& packageReader,
const std::string& targetPath) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
char outputDir[] = "/tmp/delegate_model_XXXXXX";
char* tempdir = _mkdtemp(outputDir);
TORCH_CHECK(

View File

@@ -11,6 +11,7 @@ enum class OpKernelKind : uint8_t {
// static dispatch kernels that don't reuse
// the output TensorImpl
kNativeStaticDispatchKernel,
kTritonKernel,
};
} // namespace torch::nativert

View File: torch/nativert/executor/triton/CpuTritonKernelManager.cpp

@@ -0,0 +1,91 @@
#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
#include <c10/util/Logging.h>
#ifndef _WIN32
#include <dlfcn.h>
#endif // _WIN32
namespace torch::nativert {
namespace {
void* _dlopen(const char* filename) {
#if defined(_WIN32)
return nullptr;
#else
return dlopen(filename, RTLD_NOW | RTLD_LOCAL);
#endif
}
void* _dlsym(void* handle, const char* name) {
#if defined(_WIN32)
return nullptr;
#else
return dlsym(handle, name);
#endif
}
char* _dlerror() {
#if defined(_WIN32)
throw std::runtime_error("dlerror not supported on Windows");
#else
return dlerror();
#endif
}
} // namespace
CpuTritonKernelManager::CpuTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path)
: TritonKernelManager(std::move(kernel_name), std::move(kernel_bin_path)),
kernel_launcher_bin_path_(std::move(kernel_launcher_bin_path)) {}
void CpuTritonKernelManager::load() {
if (C10_LIKELY(kernel_fn_ != nullptr)) {
return;
}
kernel_handle_.reset(_dlopen(kernel_bin_path_.c_str()));
TORCH_CHECK(
kernel_handle_ != nullptr,
"could not dlopen ",
kernel_bin_path_,
": ",
_dlerror());
launcher_handle_.reset(_dlopen(kernel_launcher_bin_path_.c_str()));
TORCH_CHECK(
launcher_handle_ != nullptr,
"could not dlopen ",
kernel_launcher_bin_path_,
": ",
_dlerror());
kernel_fn_ = _dlsym(kernel_handle_.get(), kernel_name_.c_str());
TORCH_CHECK(
kernel_fn_ != nullptr,
"could not dlsym ",
kernel_name_,
": ",
_dlerror());
launcher_fn_ =
reinterpret_cast<launcher_ptr_t>(_dlsym(launcher_handle_.get(), "run"));
TORCH_CHECK(launcher_fn_ != nullptr, "could not dlsym run: ", _dlerror());
}
void CpuTritonKernelManager::launch(
const LaunchParams& launch_params,
void** args /* { ...inputs, output }*/) {
load();
launcher_fn_(
launch_params.grid_dims.x,
launch_params.grid_dims.y,
launch_params.grid_dims.z,
args,
kernel_fn_);
}
} // namespace torch::nativert
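
For orientation, the launcher .so is expected to export a C symbol named run matching launcher_ptr_t (declared in the header below). A minimal sketch of such a launcher follows; the inner kernel signature is an assumption for illustration only, since the real launcher is emitted by Triton's CPU backend:

// Hypothetical CPU launcher sketch. "run" and its
// (gridX, gridY, gridZ, args, kernel) shape match launcher_ptr_t;
// the compiled kernel's own signature is assumed for illustration.
#include <cstdint>

using kernel_fn_t = void (*)(void** args, uint32_t x, uint32_t y, uint32_t z);

extern "C" void run(
    uint32_t grid_x,
    uint32_t grid_y,
    uint32_t grid_z,
    void** args, // { ...inputs, output }, as packed by KernelInputs
    void* kernel) {
  // Walk the launch grid on the host, invoking the kernel once per block.
  for (uint32_t z = 0; z < grid_z; ++z) {
    for (uint32_t y = 0; y < grid_y; ++y) {
      for (uint32_t x = 0; x < grid_x; ++x) {
        reinterpret_cast<kernel_fn_t>(kernel)(args, x, y, z);
      }
    }
  }
}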

View File: torch/nativert/executor/triton/CpuTritonKernelManager.h

@@ -0,0 +1,51 @@
#pragma once
#include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <c10/core/Device.h>
#include <c10/util/FbcodeMaps.h>
#ifndef _WIN32
#include <dlfcn.h>
#endif
typedef void* kernel_ptr_t;
typedef void (
*launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);
namespace torch::nativert {
struct DlcloseDeleter {
void operator()(void* p) const {
if (p) {
#if defined(_WIN32)
TORCH_CHECK(false, "Windows is not supported");
#else
dlclose(p);
#endif
}
}
};
class CpuTritonKernelManager final : public TritonKernelManager {
public:
CpuTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path);
~CpuTritonKernelManager() final = default;
void launch(const LaunchParams& launch_params, void** args) final;
private:
void load();
kernel_ptr_t kernel_fn_{nullptr};
launcher_ptr_t launcher_fn_{nullptr};
std::unique_ptr<void, DlcloseDeleter> kernel_handle_{nullptr};
std::unique_ptr<void, DlcloseDeleter> launcher_handle_{nullptr};
std::string kernel_launcher_bin_path_;
};
} // namespace torch::nativert
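
A usage sketch for this class, assuming kernel and launcher artifacts already compiled to shared objects (the kernel name and /tmp paths are illustrative placeholders):

#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>

using namespace torch::nativert;

// Hypothetical usage; "add_kernel" and the paths are placeholders.
void launch_example(void** packed_args) {
  CpuTritonKernelManager mgr(
      /*kernel_name=*/"add_kernel",
      /*kernel_bin_path=*/"/tmp/add_kernel.so",
      /*kernel_launcher_bin_path=*/"/tmp/add_kernel.launcher.so");
  LaunchParams params;
  params.grid_dims = GridDims(/*x=*/8); // 8x1x1 grid
  mgr.launch(params, packed_args);      // dlopen happens lazily here
}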

View File: torch/nativert/executor/triton/CudaTritonKernelManager.cpp

@@ -0,0 +1,155 @@
#include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <c10/util/FbcodeMaps.h>
#include <c10/util/Logging.h>
namespace {
const at::cuda::NVRTC& get_nvrtc() {
return at::globalContext().getNVRTC();
}
} // namespace
#define CU_LOG_ERROR(fn, result, ...) \
{ \
LOG(ERROR) << #fn << " returned error: " << result; \
const char* errMsg = nullptr; \
get_nvrtc().cuGetErrorString(result, &errMsg); \
LOG(ERROR) << "cuGetErrorString: " << errMsg; \
}
namespace torch::nativert {
// CUDA kernel arguments require an extra level of indirection:
// cuLaunchKernel takes an array of pointers to each argument's
// value, so we store the raw pointers and pass their addresses.
class CudaKernelInputs final : public KernelInputs {
public:
CudaKernelInputs(size_t num_args, size_t num_attrs)
: KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {}
~CudaKernelInputs() final = default;
void add_arg(void* arg) override {
TORCH_CHECK(arg_idx_ < num_args_, "Too many args");
arg_ptrs_[arg_idx_] = arg;
inputs_[arg_idx_] = reinterpret_cast<void*>(&arg_ptrs_[arg_idx_]);
arg_idx_++;
}
private:
std::vector<void*> arg_ptrs_;
};
class CudaTritonKernelManager final : public TritonKernelManager {
public:
CudaTritonKernelManager(std::string kernel_name, std::string kernel_bin_path);
~CudaTritonKernelManager() final;
CudaTritonKernelManager(const CudaTritonKernelManager& other);
CudaTritonKernelManager& operator=(const CudaTritonKernelManager& other);
CudaTritonKernelManager(CudaTritonKernelManager&& other) noexcept;
CudaTritonKernelManager& operator=(CudaTritonKernelManager&& other) noexcept;
void launch(const LaunchParams& launch_params, void** args) final;
std::unique_ptr<KernelInputs> create_inputs(size_t num_args, size_t num_attrs)
const final {
return std::unique_ptr<KernelInputs>(
new CudaKernelInputs(num_args, num_attrs));
}
private:
CUfunction load();
c10::FastMap<c10::DeviceIndex, CUfunction> cache_;
std::vector<CUmodule> loaded_modules_;
};
CudaTritonKernelManager::CudaTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path)
: TritonKernelManager(std::move(kernel_name), std::move(kernel_bin_path)) {
TORCH_CHECK(
at::globalContext().hasCUDA() || at::globalContext().hasHIP(),
"cuda or hip required");
}
CudaTritonKernelManager::~CudaTritonKernelManager() {
const auto& nvrtc = get_nvrtc();
for (auto& mod : loaded_modules_) {
if (CUresult err = nvrtc.cuModuleUnload(mod); err != 0) {
CU_LOG_ERROR(nvrtc.cuModuleUnload, err);
}
}
}
CUfunction CudaTritonKernelManager::load() {
const auto idx = c10::cuda::current_device();
if (const auto res = cache_.find(idx); res != cache_.end()) {
return res->second;
}
const auto& nvrtc = get_nvrtc();
CUmodule mod_ptr = nullptr;
if (CUresult err = nvrtc.cuModuleLoad(&mod_ptr, kernel_bin_path_.c_str());
err != 0) {
CU_LOG_ERROR(nvrtc.cuModuleLoad, err);
return nullptr;
}
CUfunction func = nullptr;
if (CUresult err =
nvrtc.cuModuleGetFunction(&func, mod_ptr, kernel_name_.c_str());
err != 0) {
CU_LOG_ERROR(nvrtc.cuModuleGetFunction, err);
return nullptr;
}
loaded_modules_.emplace_back(mod_ptr);
return cache_.emplace(idx, func).first->second;
}
void CudaTritonKernelManager::launch(
const LaunchParams& launch_params,
void** args /* { ...inputs, output }*/) {
constexpr int kThreadsPerWarp = 32;
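// Triton's num_warps is in units of warps; each warp contributes
// kThreadsPerWarp threads, so blockDimX below is kThreadsPerWarp * num_warps.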
auto kernel_fn = load();
TORCH_CHECK(
kernel_fn != nullptr, "failed to load triton kernel: ", kernel_name_);
cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
AT_CUDA_DRIVER_CHECK(get_nvrtc().cuLaunchKernel(
kernel_fn,
launch_params.grid_dims.x,
launch_params.grid_dims.y,
launch_params.grid_dims.z,
/* blockDimX = */ kThreadsPerWarp * launch_params.num_warps,
/* blockDimY = */ 1,
/* blockDimZ = */ 1,
/* sharedMemBytes = */ launch_params.shared_memory_bytes,
stream,
args,
nullptr));
}
static std::unique_ptr<TritonKernelManager> _create_cuda_triton_kernel_manager(
std::string kernel_name,
std::string kernel_bin_path) {
return std::make_unique<CudaTritonKernelManager>(
std::move(kernel_name), std::move(kernel_bin_path));
}
} // namespace torch::nativert
namespace {
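// Self-registration: when this CUDA translation unit is linked into the
// binary, a static initializer populates the create_cuda_triton_kernel_manager
// hook declared in TritonKernelManager.h; CPU-only builds leave it nullptr.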
static bool _initialized_cuda_triton_kernel_manager = []() {
torch::nativert::create_cuda_triton_kernel_manager =
&torch::nativert::_create_cuda_triton_kernel_manager;
return true;
}();
} // namespace

View File: torch/nativert/executor/triton/TritonKernelManager.h

@@ -0,0 +1,75 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
#include <c10/util/Exception.h>
namespace torch::nativert {
struct GridDims {
public:
GridDims(int x = 1, int y = 1, int z = 1) : x(x), y(y), z(z) {}
int x;
int y;
int z;
};
struct LaunchParams {
int num_warps = 4;
int shared_memory_bytes = 0;
GridDims grid_dims;
};
class KernelInputs {
public:
KernelInputs(size_t num_args, size_t num_attrs)
: num_args_(num_args),
inputs_(num_args + num_attrs),
num_attrs_(num_attrs) {}
virtual ~KernelInputs() = default;
virtual void add_arg(void* arg) {
TORCH_CHECK(arg_idx_ < num_args_, "Too many args");
inputs_[arg_idx_++] = arg;
}
void add_attribute(void* attr) {
TORCH_CHECK(attr_idx_ < num_attrs_, "Too many attributes");
inputs_[num_args_ + attr_idx_++] = attr;
}
void** as_void() {
return inputs_.data();
}
protected:
size_t num_args_;
size_t arg_idx_ = 0;
std::vector<void*> inputs_;
private:
size_t num_attrs_;
size_t attr_idx_ = 0;
};
class TritonKernelManager {
public:
TritonKernelManager(std::string kernel_name, std::string kernel_bin_path)
: kernel_name_(std::move(kernel_name)),
kernel_bin_path_(std::move(kernel_bin_path)) {}
virtual ~TritonKernelManager() = default;
virtual std::unique_ptr<KernelInputs> create_inputs(
size_t num_args,
size_t num_attrs) const {
return std::make_unique<KernelInputs>(num_args, num_attrs);
}
virtual void launch(const LaunchParams& launch_params, void** args) = 0;
protected:
std::string kernel_name_, kernel_bin_path_;
};
inline std::unique_ptr<TritonKernelManager> (
*create_cuda_triton_kernel_manager)(std::string, std::string) = nullptr;
} // namespace torch::nativert
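
To illustrate the packing contract, a small sketch (the buffer pointers are hypothetical): as_void() yields [arg_0, ..., arg_{n-1}, attr_0, ..., attr_{m-1}], which is exactly the array launch() forwards to the kernel.

#include <torch/nativert/executor/triton/TritonKernelManager.h>

using namespace torch::nativert;

// Hypothetical packing example: two tensor args followed by one scalar attr.
void pack_example(void* in_ptr, void* out_ptr, int64_t* numel) {
  KernelInputs inputs(/*num_args=*/2, /*num_attrs=*/1);
  inputs.add_arg(in_ptr);        // args fill slots [0, num_args)
  inputs.add_arg(out_ptr);
  inputs.add_attribute(numel);   // attrs fill slots [num_args, num_args + num_attrs)
  void** raw = inputs.as_void(); // passed to TritonKernelManager::launch
  (void)raw;
}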

View File: torch/nativert/kernels/KernelFactory.cpp

@@ -14,6 +14,7 @@
#include <torch/nativert/kernels/HigherOrderKernel.h>
#include <torch/nativert/kernels/KernelFactory.h>
#include <torch/nativert/kernels/PrimKernelRegistry.h>
#include <torch/nativert/kernels/TritonKernel.h>
namespace torch::nativert {
@@ -130,6 +131,11 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
} else if (c10::starts_with(
node.target(), "torch.ops.higher_order.call_torchbind")) {
nodeKernels.push_back(std::make_unique<CallTorchBindKernel>(&node));
} else if (c10::starts_with(
node.target(),
"torch.ops.higher_order.triton_kernel_wrapper_functional")) {
nodeKernels.push_back(
std::make_unique<TritonKernel>(&node, pytorchStreamReader.get()));
} else if (
c10::starts_with(
node.target(),

View File: torch/nativert/kernels/TritonKernel.cpp

@@ -0,0 +1,137 @@
#include <torch/nativert/kernels/TritonKernel.h>
#include <fmt/ostream.h>
#include <c10/util/Enumerate.h>
#include <c10/util/Exception.h>
#include <ATen/Tensor.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/nativert/executor/DelegateExecutor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif
#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
namespace torch::nativert {
TritonKernel::TritonKernel(
const Node* node,
caffe2::serialize::PyTorchStreamReader* reader)
: OpKernel(node, OpKernelKind::kTritonKernel) {
TORCH_CHECK(reader != nullptr, "reader is null");
std::string kernel_name{};
bool found_grid = false;
for (const auto& attr : node_->attributes()) {
if (attr.name.empty()) {
attr_ptrs_.emplace_back(std::visit(
[](auto&& arg) -> void* {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, None>) {
return nullptr;
}
return static_cast<void*>(const_cast<T*>(&arg));
},
attr.value));
} else if (attr.name == "name") {
kernel_name = std::get<std::string>(attr.value);
} else if (attr.name == "grid") {
found_grid = true;
auto grid = std::get<std::vector<int64_t>>(attr.value);
TORCH_CHECK(grid.size() == 3, "grid must be a 3D vector");
launch_params_.grid_dims = GridDims(
static_cast<int>(grid[0]),
static_cast<int>(grid[1]),
static_cast<int>(grid[2]));
} else if (attr.name == "num_warps") {
if (const int num_warps = static_cast<int>(std::get<int64_t>(attr.value));
num_warps > 0) {
launch_params_.num_warps = num_warps;
}
} else if (attr.name == "shared_memory_bytes") {
if (const int shared_memory_bytes =
static_cast<int>(std::get<int64_t>(attr.value));
shared_memory_bytes > 0) {
launch_params_.shared_memory_bytes = shared_memory_bytes;
}
} else if (attr.name == "output_indices") {
output_indices_ = std::get<std::vector<int64_t>>(attr.value);
}
}
TORCH_CHECK(!kernel_name.empty(), "kernel name not found");
TORCH_CHECK(found_grid, "grid attribute not found");
TORCH_CHECK(!output_indices_.empty(), "output_indices attribute not found");
auto kernel_prefix = std::string("data/triton") + "/" + kernel_name;
auto tmp_dir = extractToTemporaryFolder(*reader, kernel_prefix) + "/";
if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) {
TORCH_CHECK(
create_cuda_triton_kernel_manager != nullptr,
"couldn't find cuda loader -- is this a gpu build?");
loader_ = create_cuda_triton_kernel_manager(
kernel_name, tmp_dir + kernel_name + ".cubin");
}
if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) {
TORCH_CHECK(
create_cuda_triton_kernel_manager != nullptr,
"couldn't find cuda loader -- is this a gpu build?");
loader_ = create_cuda_triton_kernel_manager(
kernel_name, tmp_dir + kernel_name + ".hsaco");
}
if (loader_ == nullptr) {
loader_ = std::unique_ptr<TritonKernelManager>(new CpuTritonKernelManager(
kernel_name,
tmp_dir + kernel_name + ".so",
tmp_dir + kernel_name + ".launcher.so"));
}
}
TritonKernel::~TritonKernel() = default;
void TritonKernel::computeInternal(ExecutionFrame& executionFrame) const {
const auto num_inputs = node_->inputs().size();
const auto num_attrs = attr_ptrs_.size();
auto* loader = const_cast<TritonKernelManager*>(loader_.get());
auto inputs = loader->create_inputs(num_inputs, num_attrs);
for (const auto i : c10::irange(num_inputs)) {
inputs->add_arg(input(i, executionFrame).toTensor().data_ptr());
}
for (const auto i : c10::irange(num_attrs)) {
inputs->add_attribute(attr_ptrs_[i]);
}
loader->launch(launch_params_, inputs->as_void());
auto& out = output(0, executionFrame);
if (out.isNone()) {
auto list = c10::List<at::Tensor>();
for (const auto& i : output_indices_) {
list.emplace_back(input(i, executionFrame).toTensor());
}
out = c10::IValue(std::move(list));
return;
}
// todo: check if this is redundant
auto out_t = out.toTensorList();
for (const auto& i : output_indices_) {
out_t[i] = input(i, executionFrame).toTensor();
}
}
} // namespace torch::nativert
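
Taken together, the constructor above implies the following package layout (reconstructed from the hasRecord checks; add_kernel is an illustrative kernel name):

data/triton/add_kernel/add_kernel.cubin        (CUDA)
data/triton/add_kernel/add_kernel.hsaco        (ROCm)
data/triton/add_kernel/add_kernel.so           (CPU kernel)
data/triton/add_kernel/add_kernel.launcher.so  (CPU launcher)

The matching binaries are extracted to a temporary directory; a CUDA/HIP manager handles a GPU binary when one is present, otherwise the CPU manager dlopens the .so pair.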

View File: torch/nativert/kernels/TritonKernel.h

@@ -0,0 +1,31 @@
#pragma once
#include <c10/core/Device.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/executor/OpKernel.h>
#include <torch/nativert/executor/triton/TritonKernelManager.h>
#include <torch/nativert/graph/Graph.h>
namespace torch::nativert {
class TritonKernel : public OpKernel {
public:
TritonKernel() = delete;
TritonKernel(
const Node* node,
caffe2::serialize::PyTorchStreamReader* reader);
~TritonKernel() override;
void computeInternal(ExecutionFrame& executionFrame) const override;
private:
std::unique_ptr<TritonKernelManager> loader_;
// unnamed node attributes will be passed as arguments to the kernel
std::vector<void*> attr_ptrs_;
std::vector<int64_t> output_indices_;
LaunchParams launch_params_;
};
} // namespace torch::nativert