Files
pytorch/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
dolpm 3dde5d7f9b [nativert] triton runtime implementation (#161798)
Summary:
att
Test Plan:
ci
Rollback Plan:

Reviewed By: minjang

Differential Revision: D80828148

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161798
Approved by: https://github.com/minjang, https://github.com/SherlockNoMad
2025-09-04 19:00:15 +00:00

92 lines
2.1 KiB
C++

#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
#include <c10/util/Logging.h>
#ifndef _WIN32
#include <dlfcn.h>
#endif // _WIN32
namespace torch::nativert {
namespace {
// Platform shim around dlopen().
//
// On non-Windows platforms this opens `filename` with RTLD_NOW | RTLD_LOCAL
// and returns whatever handle dlopen() yields (nullptr on failure). On Windows
// dynamic loading is not implemented here and nullptr is always returned.
void* _dlopen(const char* filename) {
#ifndef _WIN32
  // RTLD_NOW: resolve all symbols at load time; RTLD_LOCAL: keep them private
  // to this handle.
  constexpr int kOpenFlags = RTLD_NOW | RTLD_LOCAL;
  return dlopen(filename, kOpenFlags);
#else
  return nullptr;
#endif
}
// Platform shim around dlsym().
//
// On non-Windows platforms this looks up `name` in the library referred to by
// `handle` and returns the address dlsym() reports (nullptr when the lookup
// fails). On Windows no lookup is performed and nullptr is always returned.
void* _dlsym(void* handle, const char* name) {
#ifndef _WIN32
  void* const symbol_address = dlsym(handle, name);
  return symbol_address;
#else
  return nullptr;
#endif
}
// Returns a human-readable description of the most recent dynamic-loading
// failure, intended to be appended to error messages.
//
// POSIX dlerror() returns NULL when no error is pending, e.g. when dlsym()
// successfully resolved a symbol whose value is NULL, or when the error was
// already consumed by an earlier dlerror() call. Streaming a null char* into
// an error message is undefined behavior, so a placeholder string is returned
// in that case; on non-Windows platforms the result is therefore never null.
//
// On Windows there is no dlerror(); this throws std::runtime_error.
char* _dlerror() {
#if defined(_WIN32)
  throw std::runtime_error("dlerror not supported on Windows");
#else
  // Mutable storage so the existing `char*` return type can be preserved.
  static char no_error_reported[] = "no error reported by dlerror()";
  char* const message = dlerror();
  return message != nullptr ? message : no_error_reported;
#endif
}
} // namespace
CpuTritonKernelManager::CpuTritonKernelManager(
std::string kernel_name,
std::string kernel_bin_path,
std::string kernel_launcher_bin_path)
: TritonKernelManager(std::move(kernel_name), std::move(kernel_bin_path)),
kernel_launcher_bin_path_(std::move(kernel_launcher_bin_path)) {}
void CpuTritonKernelManager::load() {
if (C10_LIKELY(kernel_fn_ != nullptr)) {
return;
}
kernel_handle_.reset(_dlopen(kernel_bin_path_.c_str()));
TORCH_CHECK(
kernel_handle_ != nullptr,
"could not dlopen ",
kernel_bin_path_,
": ",
_dlerror());
launcher_handle_.reset(_dlopen(kernel_launcher_bin_path_.c_str()));
TORCH_CHECK(
launcher_handle_ != nullptr,
"could not dlopen ",
kernel_launcher_bin_path_,
": ",
_dlerror());
kernel_fn_ = _dlsym(kernel_handle_.get(), kernel_name_.c_str());
TORCH_CHECK(
kernel_fn_ != nullptr,
"could not dlsym ",
kernel_name_,
": ",
_dlerror());
launcher_fn_ =
reinterpret_cast<launcher_ptr_t>(_dlsym(launcher_handle_.get(), "run"));
TORCH_CHECK(launcher_fn_ != nullptr, "could not dlsym run: ", _dlerror());
}
// Runs the kernel through the launcher entry point.
//
// Calls load() first, then invokes launcher_fn_ with the x/y/z grid
// dimensions from `launch_params`, the caller-supplied `args` array, and the
// resolved kernel_fn_.
void CpuTritonKernelManager::launch(
    const LaunchParams& launch_params,
    void** args /* { ...inputs, output }*/) {
  // Ensure the shared objects are opened and the symbols resolved.
  load();

  const auto& grid = launch_params.grid_dims;
  launcher_fn_(grid.x, grid.y, grid.z, args, kernel_fn_);
}
} // namespace torch::nativert