mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: att Test Plan: ci Rollback Plan: Reviewed By: minjang Differential Revision: D80828148 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161798 Approved by: https://github.com/minjang, https://github.com/SherlockNoMad
92 lines
2.1 KiB
C++
92 lines
2.1 KiB
C++
#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
|
|
|
|
#include <c10/util/Logging.h>
|
|
|
|
#ifndef _WIN32
|
|
#include <dlfcn.h>
|
|
#endif // _WIN32
|
|
|
|
namespace torch::nativert {
|
|
|
|
namespace {
|
|
// Opens the shared object at `filename` and returns its handle, or nullptr
// if it cannot be opened (the reason is then available from _dlerror()).
// Dynamic loading is not implemented on Windows: there this is a stub that
// always returns nullptr.
void* _dlopen([[maybe_unused]] const char* filename) {
#ifndef _WIN32
  // RTLD_NOW: resolve all undefined symbols before dlopen returns.
  // RTLD_LOCAL: do not make this library's symbols available for resolving
  // libraries loaded later.
  constexpr int kOpenFlags = RTLD_NOW | RTLD_LOCAL;
  return dlopen(filename, kOpenFlags);
#else
  return nullptr;
#endif
}
|
|
|
|
// Looks up the symbol `name` in the library referenced by `handle` and
// returns its address, or nullptr if the lookup fails (consult _dlerror()).
// On Windows this is a stub that always returns nullptr.
void* _dlsym([[maybe_unused]] void* handle, [[maybe_unused]] const char* name) {
#ifndef _WIN32
  void* const symbol = dlsym(handle, name);
  return symbol;
#else
  return nullptr;
#endif
}
|
|
|
|
// Returns a human-readable description of the most recent dynamic-loader
// failure, for splicing into error messages.
//
// dlerror() returns NULL when no error has been recorded since the previous
// call, e.g. when dlsym() successfully resolves a symbol whose value is itself
// NULL. The callers in this file pass the result directly as a TORCH_CHECK
// message argument, and inserting a null `char*` into a stream is undefined
// behavior. A fixed fallback string is therefore returned instead, so the
// result is never null.
//
// On Windows dynamic loading is not implemented and this throws
// std::runtime_error.
const char* _dlerror() {
#if defined(_WIN32)
  throw std::runtime_error("dlerror not supported on Windows");
#else
  const char* const message = dlerror();
  return message != nullptr
      ? message
      : "unknown dynamic loader error (dlerror() returned null)";
#endif
}
|
|
|
|
} // namespace
|
|
|
|
CpuTritonKernelManager::CpuTritonKernelManager(
|
|
std::string kernel_name,
|
|
std::string kernel_bin_path,
|
|
std::string kernel_launcher_bin_path)
|
|
: TritonKernelManager(std::move(kernel_name), std::move(kernel_bin_path)),
|
|
kernel_launcher_bin_path_(std::move(kernel_launcher_bin_path)) {}
|
|
|
|
void CpuTritonKernelManager::load() {
|
|
if (C10_LIKELY(kernel_fn_ != nullptr)) {
|
|
return;
|
|
}
|
|
|
|
kernel_handle_.reset(_dlopen(kernel_bin_path_.c_str()));
|
|
TORCH_CHECK(
|
|
kernel_handle_ != nullptr,
|
|
"could not dlopen ",
|
|
kernel_bin_path_,
|
|
": ",
|
|
_dlerror());
|
|
|
|
launcher_handle_.reset(_dlopen(kernel_launcher_bin_path_.c_str()));
|
|
TORCH_CHECK(
|
|
launcher_handle_ != nullptr,
|
|
"could not dlopen ",
|
|
kernel_launcher_bin_path_,
|
|
": ",
|
|
_dlerror());
|
|
|
|
kernel_fn_ = _dlsym(kernel_handle_.get(), kernel_name_.c_str());
|
|
TORCH_CHECK(
|
|
kernel_fn_ != nullptr,
|
|
"could not dlsym ",
|
|
kernel_name_,
|
|
": ",
|
|
_dlerror());
|
|
|
|
launcher_fn_ =
|
|
reinterpret_cast<launcher_ptr_t>(_dlsym(launcher_handle_.get(), "run"));
|
|
TORCH_CHECK(launcher_fn_ != nullptr, "could not dlsym run: ", _dlerror());
|
|
}
|
|
|
|
void CpuTritonKernelManager::launch(
|
|
const LaunchParams& launch_params,
|
|
void** args /* { ...inputs, output }*/) {
|
|
load();
|
|
launcher_fn_(
|
|
launch_params.grid_dims.x,
|
|
launch_params.grid_dims.y,
|
|
launch_params.grid_dims.z,
|
|
args,
|
|
kernel_fn_);
|
|
}
|
|
|
|
} // namespace torch::nativert
|