Revert "[nativert] triton runtime implementation (#161798)"
This reverts commit 3dde5d7f9bf80dd6623a712bc429e9e4302464b5. Reverted https://github.com/pytorch/pytorch/pull/161798 on behalf of https://github.com/jeanschmidt due to introducing linting failures ([comment](https://github.com/pytorch/pytorch/pull/161798#issuecomment-3255412085))
@@ -117,8 +117,6 @@ namespace at::cuda {
   _(nvrtcGetPTXSize)                              \
   _(nvrtcGetPTX)                                  \
   _(cuModuleLoadData)                             \
-  _(cuModuleLoad)                                 \
-  _(cuGetErrorString)                             \
   _(cuModuleGetFunction)                          \
   _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR) \
   _(nvrtcGetErrorString)                          \
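The `_(...)` rows above belong to the X-macro list in ATenNVRTC.h that enumerates the lazily resolved driver/NVRTC entry points; the revert drops the two entries that only the reverted Triton runtime used. For readers unfamiliar with the pattern, a minimal self-contained sketch of how such a list expands (the entry names below are stand-ins, not the real stub):

// X-macro sketch (hypothetical names): one list of entry points,
// expanded once per use site.
int fakeModuleLoad();
int fakeGetErrorString();

#define FORALL_FAKE_API(_) \
  _(fakeModuleLoad)        \
  _(fakeGetErrorString)

// Expansion: declare one member function pointer per listed entry.
struct FakeStub {
#define DECLARE_MEMBER(name) decltype(&name) name;
  FORALL_FAKE_API(DECLARE_MEMBER)
#undef DECLARE_MEMBER
};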
@@ -635,12 +635,6 @@ libtorch_nativert_sources = [
     "torch/nativert/graph/passes/pass_manager/GraphPasses.cpp",
     "torch/nativert/graph/passes/pass_manager/PassManager.cpp",
     "torch/nativert/kernels/KernelHandlerRegistry.cpp",
-    "torch/nativert/kernels/TritonKernel.cpp",
-    "torch/nativert/executor/triton/CpuTritonKernelManager.cpp",
 ]
 
-libtorch_nativert_cuda_sources = [
-    "torch/nativert/executor/triton/CudaTritonKernelManager.cpp",
-]
-
 torch_mobile_tracer_sources = [
@@ -776,7 +770,7 @@ libtorch_cuda_distributed_sources = libtorch_cuda_distributed_base_sources + lib
 
 libtorch_cuda_sources = libtorch_cuda_core_sources + libtorch_cuda_distributed_sources + [
     "torch/csrc/cuda/nccl.cpp",
-] + libtorch_nativert_cuda_sources
+]
 
 torch_cpp_srcs = [
     "torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA
@@ -40,16 +40,8 @@ set(NATIVERT_TEST_SRCS
   ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
   ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
-  ${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
-  ${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
-  ${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
   ${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
 )
 
-if(USE_CUDA)
-  list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
-endif(MSVC)
-
-
 add_executable(test_nativert
   ${TORCH_ROOT}/test/cpp/common/main.cpp
   ${NATIVERT_TEST_SRCS}
@@ -1,14 +0,0 @@
-#include <gtest/gtest.h>
-
-#include <torch/nativert/kernels/TritonKernel.h>
-
-using namespace ::testing;
-using namespace torch::nativert;
-
-TEST(TritonKernelManagerRegistrationTests, TestRegister) {
-#ifndef USE_CUDA
-  EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr);
-#else
-  EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr);
-#endif // USE_CUDA
-}
@@ -11,7 +11,6 @@ enum class OpKernelKind : uint8_t {
   // static dispatch kernels that don't reuse
   // out TensorImpl
   kNativeStaticDispatchKernel,
-  kTritonKernel,
 };
 
 } // namespace torch::nativert
@@ -1,91 +0,0 @@
-#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
-
-#include <c10/util/Logging.h>
-
-#ifndef _WIN32
-#include <dlfcn.h>
-#endif // _WIN32
-
-namespace torch::nativert {
-
-namespace {
-void* _dlopen(const char* filename) {
-#if defined(_WIN32)
-  return nullptr;
-#else
-  return dlopen(filename, RTLD_NOW | RTLD_LOCAL);
-#endif
-}
-
-void* _dlsym(void* handle, const char* name) {
-#if defined(_WIN32)
-  return nullptr;
-#else
-  return dlsym(handle, name);
-#endif
-}
-
-char* _dlerror() {
-#if defined(_WIN32)
-  throw std::runtime_error("dlerror not supported on Windows");
-#else
-  return dlerror();
-#endif
-}
-
-} // namespace
-
-CpuTritonKernelManager::CpuTritonKernelManager(
-    std::string kernel_name,
-    std::string kernel_bin_path,
-    std::string kernel_launcher_bin_path)
-    : TritonKernelManager(std::move(kernel_name), std::move(kernel_bin_path)),
-      kernel_launcher_bin_path_(std::move(kernel_launcher_bin_path)) {}
-
-void CpuTritonKernelManager::load() {
-  if (C10_LIKELY(kernel_fn_ != nullptr)) {
-    return;
-  }
-
-  kernel_handle_.reset(_dlopen(kernel_bin_path_.c_str()));
-  TORCH_CHECK(
-      kernel_handle_ != nullptr,
-      "could not dlopen ",
-      kernel_bin_path_,
-      ": ",
-      _dlerror());
-
-  launcher_handle_.reset(_dlopen(kernel_launcher_bin_path_.c_str()));
-  TORCH_CHECK(
-      launcher_handle_ != nullptr,
-      "could not dlopen ",
-      kernel_launcher_bin_path_,
-      ": ",
-      _dlerror());
-
-  kernel_fn_ = _dlsym(kernel_handle_.get(), kernel_name_.c_str());
-  TORCH_CHECK(
-      kernel_fn_ != nullptr,
-      "could not dlsym ",
-      kernel_name_,
-      ": ",
-      _dlerror());
-
-  launcher_fn_ =
-      reinterpret_cast<launcher_ptr_t>(_dlsym(launcher_handle_.get(), "run"));
-  TORCH_CHECK(launcher_fn_ != nullptr, "could not dlsym run: ", _dlerror());
-}
-
-void CpuTritonKernelManager::launch(
-    const LaunchParams& launch_params,
-    void** args /* { ...inputs, output }*/) {
-  load();
-  launcher_fn_(
-      launch_params.grid_dims.x,
-      launch_params.grid_dims.y,
-      launch_params.grid_dims.z,
-      args,
-      kernel_fn_);
-}
-
-} // namespace torch::nativert
@@ -1,51 +0,0 @@
-#pragma once
-
-#include <torch/nativert/executor/triton/TritonKernelManager.h>
-
-#include <c10/core/Device.h>
-#include <c10/util/FbcodeMaps.h>
-
-#ifndef _WIN32
-#include <dlfcn.h>
-#endif
-
-typedef void* kernel_ptr_t;
-typedef void (
    *launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);
-
-namespace torch::nativert {
-
-struct DlcloseDeleter {
-  void operator()(void* p) const {
-    if (p) {
-#if defined(_WIN32)
-      TORCH_CHECK(false, "Windows is not supported");
-#else
-      dlclose(p);
-#endif
-    }
-  }
-};
-
-class CpuTritonKernelManager final : public TritonKernelManager {
- public:
-  CpuTritonKernelManager(
-      std::string kernel_name,
-      std::string kernel_bin_path,
-      std::string kernel_launcher_bin_path);
-  ~CpuTritonKernelManager() final {}
-  void launch(const LaunchParams& launch_params, void** args) final;
-
- private:
-  void load();
-
-  kernel_ptr_t kernel_fn_{nullptr};
-  launcher_ptr_t launcher_fn_{nullptr};
-
-  std::unique_ptr<void, DlcloseDeleter> kernel_handle_{nullptr};
-  std::unique_ptr<void, DlcloseDeleter> launcher_handle_{nullptr};
-
-  std::string kernel_launcher_bin_path_;
-};
-
-} // namespace torch::nativert
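The `run` symbol that `CpuTritonKernelManager::load()` resolves from `<kernel>.launcher.so` must match `launcher_ptr_t` above. A minimal sketch of what such a launcher shim could look like, assuming a hypothetical per-kernel C entry point (the real launcher is generated alongside the compiled Triton CPU kernel, so the kernel signature below is illustrative only):

#include <cstdint>

// Assumed C ABI for the compiled kernel; generated per-kernel in practice.
using kernel_fn_t =
    void (*)(void* in, void* out, uint32_t x, uint32_t y, uint32_t z);

extern "C" void run(
    uint32_t gridX,
    uint32_t gridY,
    uint32_t gridZ,
    void** args, // { ...inputs, output }, as packed by KernelInputs
    void* kernel_fn) {
  auto fn = reinterpret_cast<kernel_fn_t>(kernel_fn);
  // A CPU "launch" walks the grid serially, one program instance per cell.
  for (uint32_t x = 0; x < gridX; ++x) {
    for (uint32_t y = 0; y < gridY; ++y) {
      for (uint32_t z = 0; z < gridZ; ++z) {
        fn(args[0], args[1], x, y, z);
      }
    }
  }
}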
@@ -1,155 +0,0 @@
-#include <torch/nativert/executor/triton/TritonKernelManager.h>
-
-#include <ATen/cuda/Exceptions.h>
-#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
-#include <c10/cuda/CUDAStream.h>
-#include <cuda_runtime.h>
-
-#include <c10/util/FbcodeMaps.h>
-#include <c10/util/Logging.h>
-
-namespace {
-const at::cuda::NVRTC& get_nvrtc() {
-  return at::globalContext().getNVRTC();
-}
-} // namespace
-
-#define CU_LOG_ERROR(fn, result, ...)                   \
-  {                                                     \
-    LOG(ERROR) << #fn << " returned error: " << result; \
-    const char* errMsg = nullptr;                       \
-    get_nvrtc().cuGetErrorString(result, &errMsg);      \
-    LOG(ERROR) << "cuGetErrorString: " << errMsg;       \
-  }
-
-namespace torch::nativert {
-
-// cuda kernels require an extra level of indirection
-// for who knows what reason.
-class CudaKernelInputs final : public KernelInputs {
- public:
-  CudaKernelInputs(size_t num_args, size_t num_attrs)
-      : KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {};
-  ~CudaKernelInputs() final = default;
-
-  void add_arg(void* arg) override {
-    TORCH_CHECK(arg_idx_ < num_args_, "Too many args");
-    arg_ptrs_[arg_idx_] = arg;
-    inputs_[arg_idx_] = reinterpret_cast<void*>(&arg_ptrs_[arg_idx_]);
-    arg_idx_++;
-  }
-
- private:
-  std::vector<void*> arg_ptrs_;
-};
-
-class CudaTritonKernelManager final : public TritonKernelManager {
- public:
-  CudaTritonKernelManager(std::string kernel_name, std::string kernel_bin_path);
-  ~CudaTritonKernelManager() final;
-
-  CudaTritonKernelManager(const CudaTritonKernelManager& other);
-  CudaTritonKernelManager& operator=(const CudaTritonKernelManager& other);
-  CudaTritonKernelManager(CudaTritonKernelManager&& other) noexcept;
-  CudaTritonKernelManager& operator=(CudaTritonKernelManager&& other) noexcept;
-
-  void launch(const LaunchParams& launch_params, void** args) final;
-  std::unique_ptr<KernelInputs> create_inputs(size_t num_args, size_t num_attrs)
-      const final {
-    return std::unique_ptr<KernelInputs>(
-        new CudaKernelInputs(num_args, num_attrs));
-  }
-
- private:
-  CUfunction load();
-  c10::FastMap<c10::DeviceIndex, CUfunction> cache_;
-  std::vector<CUmodule> loaded_modules_;
-};
-
-CudaTritonKernelManager::CudaTritonKernelManager(
-    std::string kernel_name,
-    std::string kernel_bin_path)
-    : TritonKernelManager(std::move(kernel_name), std::move(kernel_bin_path)) {
-  TORCH_CHECK(
-      at::globalContext().hasCUDA() || at::globalContext().hasHIP(),
-      "cuda or hip required");
-};
-
-CudaTritonKernelManager::~CudaTritonKernelManager() {
-  const auto& nvrtc = get_nvrtc();
-  for (auto& mod : loaded_modules_) {
-    if (CUresult err = nvrtc.cuModuleUnload(mod); err != 0) {
-      CU_LOG_ERROR(nvrtc.cuModuleUnload, err);
-    }
-  }
-}
-
-CUfunction CudaTritonKernelManager::load() {
-  const auto idx = c10::cuda::current_device();
-  if (const auto res = cache_.find(idx); res != cache_.end()) {
-    return res->second;
-  }
-
-  const auto& nvrtc = get_nvrtc();
-
-  CUmodule mod_ptr = nullptr;
-
-  if (CUresult err = nvrtc.cuModuleLoad(&mod_ptr, kernel_bin_path_.c_str());
-      err != 0) {
-    CU_LOG_ERROR(nvrtc.cuModuleLoad, err);
-    return nullptr;
-  }
-
-  CUfunction func = nullptr;
-
-  if (CUresult err =
-          nvrtc.cuModuleGetFunction(&func, mod_ptr, kernel_name_.c_str());
-      err != 0) {
-    CU_LOG_ERROR(nvrtc.cuModuleGetFunction, err);
-    return nullptr;
-  }
-
-  loaded_modules_.emplace_back(mod_ptr);
-  return cache_.emplace(idx, func).first->second;
-}
-
-void CudaTritonKernelManager::launch(
-    const LaunchParams& launch_params,
-    void** args /* { ...inputs, output }*/) {
-  const constexpr int kThreadsPerWarp = 2 << 4;
-
-  auto kernel_fn = load();
-  TORCH_CHECK(
-      kernel_fn != nullptr, "failed to load triton kernel: ", kernel_name_);
-  cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
-
-  AT_CUDA_DRIVER_CHECK(get_nvrtc().cuLaunchKernel(
-      kernel_fn,
-      launch_params.grid_dims.x,
-      launch_params.grid_dims.y,
-      launch_params.grid_dims.z,
-      /* blockDimX = */ kThreadsPerWarp * launch_params.num_warps,
-      /* blockDimY = */ 1,
-      /* blockDimZ = */ 1,
-      /* sharedMemBytes = */ launch_params.shared_memory_bytes,
-      stream,
-      args,
-      nullptr));
-}
-
-static std::unique_ptr<TritonKernelManager> _create_cuda_triton_kernel_manager(
-    std::string kernel_name,
-    std::string kernel_bin_path) {
-  return std::unique_ptr<TritonKernelManager>(new CudaTritonKernelManager(
-      std::move(kernel_name), std::move(kernel_bin_path)));
-}
-
-} // namespace torch::nativert
-
-namespace {
-static bool _initialized_cuda_triton_kernel_manager = []() {
-  torch::nativert::create_cuda_triton_kernel_manager =
-      &torch::nativert::_create_cuda_triton_kernel_manager;
-  return true;
-}();
-} // namespace
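Two details here are easy to miss. First, `CudaKernelInputs` adds a level of indirection because `cuLaunchKernel` takes an array of pointers to each argument value, so a device pointer must be passed as the address of that pointer; the base `KernelInputs` stores the pointer itself, which is what the CPU launcher wants. Second, `kThreadsPerWarp` is `2 << 4`, i.e. 32, so `blockDimX` becomes `32 * num_warps`. Finally, the file registers itself through the `create_cuda_triton_kernel_manager` function pointer declared in TritonKernelManager.h: a static initializer installs the factory only when this translation unit is linked in, so CPU-only builds carry no CUDA dependency. A stripped-down sketch of that hook pattern, with hypothetical names:

#include <memory>
#include <string>

struct Backend {
  virtual ~Backend() = default;
};

// Hook lives in common code and defaults to nullptr.
inline std::unique_ptr<Backend> (*create_gpu_backend)(std::string) = nullptr;

// --- everything below would live in a GPU-only translation unit ---
struct GpuBackend final : Backend {};

static std::unique_ptr<Backend> make_gpu_backend(std::string /*name*/) {
  return std::make_unique<GpuBackend>();
}

// Runs at program load; CPU-only builds never link this file,
// leaving the hook nullptr (which the registration test above checks).
static bool registered = [] {
  create_gpu_backend = &make_gpu_backend;
  return true;
}();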
@@ -1,75 +0,0 @@
-#pragma once
-
-#include <string>
-
-#include <c10/util/Exception.h>
-
-namespace torch::nativert {
-
-struct GridDims {
- public:
-  GridDims(int x = 1, int y = 1, int z = 1) : x(x), y(y), z(z) {}
-  int x;
-  int y;
-  int z;
-};
-
-struct LaunchParams {
-  int num_warps = 4;
-  int shared_memory_bytes = 0;
-  GridDims grid_dims;
-};
-
-class KernelInputs {
- public:
-  KernelInputs(size_t num_args, size_t num_attrs)
-      : num_args_(num_args),
-        inputs_(num_args + num_attrs),
-        num_attrs_(num_attrs) {}
-  virtual ~KernelInputs() = default;
-
-  virtual void add_arg(void* arg) {
-    TORCH_CHECK(arg_idx_ < num_args_, "Too many args");
-    inputs_[arg_idx_++] = arg;
-  }
-
-  void add_attribute(void* attr) {
-    TORCH_CHECK(attr_idx_ < num_attrs_, "Too many attributes");
-    inputs_[num_args_ + attr_idx_++] = attr;
-  }
-
-  void** as_void() {
-    return inputs_.data();
-  }
-
- protected:
-  size_t num_args_;
-  size_t arg_idx_ = 0;
-  std::vector<void*> inputs_;
-
- private:
-  size_t num_attrs_;
-  size_t attr_idx_ = 0;
-};
-
-class TritonKernelManager {
- public:
-  TritonKernelManager(std::string kernel_name, std::string kernel_bin_path)
-      : kernel_name_(std::move(kernel_name)),
-        kernel_bin_path_(std::move(kernel_bin_path)) {}
-  virtual ~TritonKernelManager() = default;
-  virtual std::unique_ptr<KernelInputs> create_inputs(
-      size_t num_args,
-      size_t num_attrs) const {
-    return std::make_unique<KernelInputs>(num_args, num_attrs);
-  }
-  virtual void launch(const LaunchParams& launch_params, void** args) = 0;
-
- protected:
-  std::string kernel_name_, kernel_bin_path_;
-};
-
-inline std::unique_ptr<TritonKernelManager> (
-    *create_cuda_triton_kernel_manager)(std::string, std::string) = nullptr;
-
-} // namespace torch::nativert
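Taken together, every manager is driven the same way: pack pointers with `KernelInputs` (or whatever subclass `create_inputs` returns), then call `launch` with the resulting `{ ...inputs, output }` array. A hedged usage sketch against the interface above (grid size and pointers are placeholders):

// Hypothetical driver; `manager` is any concrete TritonKernelManager.
void run_once(TritonKernelManager& manager, void* in_ptr, void* out_ptr) {
  LaunchParams params;
  params.grid_dims = GridDims(/*x=*/128, /*y=*/1, /*z=*/1);
  // num_warps (default 4) and shared_memory_bytes (default 0) come from
  // the serialized node attributes in the real flow.

  // Two argument slots, no extra attributes. CudaTritonKernelManager
  // returns CudaKernelInputs here, adding the pointer indirection that
  // cuLaunchKernel expects.
  auto inputs = manager.create_inputs(/*num_args=*/2, /*num_attrs=*/0);
  inputs->add_arg(in_ptr);   // inputs first ...
  inputs->add_arg(out_ptr);  // ... output last

  manager.launch(params, inputs->as_void());
}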
@@ -14,7 +14,6 @@
 #include <torch/nativert/kernels/HigherOrderKernel.h>
 #include <torch/nativert/kernels/KernelFactory.h>
 #include <torch/nativert/kernels/PrimKernelRegistry.h>
-#include <torch/nativert/kernels/TritonKernel.h>
 
 namespace torch::nativert {
 
@@ -131,11 +130,6 @@ ExecutionKernels KernelFactory::initializeNodeKernels(
     } else if (c10::starts_with(
                    node.target(), "torch.ops.higher_order.call_torchbind")) {
       nodeKernels.push_back(std::make_unique<CallTorchBindKernel>(&node));
-    } else if (c10::starts_with(
-                   node.target(),
-                   "torch.ops.higher_order.triton_kernel_wrapper_functional")) {
-      nodeKernels.push_back(
-          std::make_unique<TritonKernel>(&node, pytorchStreamReader.get()));
     } else if (
         c10::starts_with(
             node.target(),
@@ -1,137 +0,0 @@
-#include <torch/nativert/kernels/TritonKernel.h>
-
-#include <fmt/ostream.h>
-
-#include <c10/util/Enumerate.h>
-#include <c10/util/Exception.h>
-
-#include <ATen/Tensor.h>
-#include <ATen/core/op_registration/op_registration.h>
-
-#include <torch/nativert/executor/DelegateExecutor.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#else
-#include <ATen/ops/empty.h>
-#endif
-
-#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
-
-namespace torch::nativert {
-
-TritonKernel::TritonKernel(
-    const Node* node,
-    caffe2::serialize::PyTorchStreamReader* reader)
-    : OpKernel(node, OpKernelKind::kTritonKernel) {
-  TORCH_CHECK(reader != nullptr, "reader is null");
-
-  std::string kernel_name{};
-  bool found_grid = false;
-  for (const auto& attr : node_->attributes()) {
-    if (attr.name.empty()) {
-      attr_ptrs_.emplace_back(std::visit(
-          [](auto&& arg) -> void* {
-            using T = std::decay_t<decltype(arg)>;
-            if constexpr (std::is_same_v<T, None>) {
-              return nullptr;
-            }
-            return static_cast<void*>(const_cast<T*>(&arg));
-          },
-          attr.value));
-    } else if (attr.name == "name") {
-      kernel_name = std::get<std::string>(attr.value);
-    } else if (attr.name == "grid") {
-      found_grid = true;
-      auto grid = std::get<std::vector<int64_t>>(attr.value);
-      TORCH_CHECK(grid.size() == 3, "grid must be a 3D vector");
-      launch_params_.grid_dims = GridDims(
-          static_cast<int>(grid[0]),
-          static_cast<int>(grid[1]),
-          static_cast<int>(grid[2]));
-    } else if (attr.name == "num_warps") {
-      if (const int num_warps = static_cast<int>(std::get<int64_t>(attr.value));
-          num_warps > 0) {
-        launch_params_.num_warps = num_warps;
-      }
-    } else if (attr.name == "shared_memory_bytes") {
-      if (const int shared_memory_bytes =
-              static_cast<int>(std::get<int64_t>(attr.value));
-          shared_memory_bytes > 0) {
-        launch_params_.shared_memory_bytes = shared_memory_bytes;
-      }
-    } else if (attr.name == "output_indices") {
-      output_indices_ = std::get<std::vector<int64_t>>(attr.value);
-    }
-  }
-
-  TORCH_CHECK(!kernel_name.empty(), "kernel name not found");
-  TORCH_CHECK(found_grid, "grid attribute not found");
-  TORCH_CHECK(!output_indices_.empty(), "output_indices attribute not found");
-
-  auto kernel_prefix = std::string("data/triton") + "/" + kernel_name;
-
-  auto tmp_dir = extractToTemporaryFolder(*reader, kernel_prefix) + "/";
-
-  if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) {
-    TORCH_CHECK(
-        create_cuda_triton_kernel_manager != nullptr,
-        "couldn't find cuda loader -- is this a gpu build?");
-    loader_ = create_cuda_triton_kernel_manager(
-        kernel_name, tmp_dir + kernel_name + ".cubin");
-  }
-
-  if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) {
-    TORCH_CHECK(
-        create_cuda_triton_kernel_manager != nullptr,
-        "couldn't find cuda loader -- is this a gpu build?");
-    loader_ = create_cuda_triton_kernel_manager(
-        kernel_name, tmp_dir + kernel_name + ".hsaco");
-  }
-
-  if (loader_ == nullptr) {
-    loader_ = std::unique_ptr<TritonKernelManager>(new CpuTritonKernelManager(
-        kernel_name,
-        tmp_dir + kernel_name + ".so",
-        tmp_dir + kernel_name + ".launcher.so"));
-  }
-}
-
-TritonKernel::~TritonKernel() = default;
-
-void TritonKernel::computeInternal(ExecutionFrame& executionFrame) const {
-  const auto num_inputs = node_->inputs().size();
-  const auto num_attrs = attr_ptrs_.size();
-
-  auto* loader = const_cast<TritonKernelManager*>(loader_.get());
-
-  auto inputs = loader->create_inputs(num_inputs, num_attrs);
-
-  for (const auto i : c10::irange(num_inputs)) {
-    inputs->add_arg(input(i, executionFrame).toTensor().data_ptr());
-  }
-
-  for (const auto i : c10::irange(num_attrs)) {
-    inputs->add_attribute(attr_ptrs_[i]);
-  }
-
-  loader->launch(launch_params_, inputs->as_void());
-
-  auto& out = output(0, executionFrame);
-  if (out.isNone()) {
-    auto list = c10::List<at::Tensor>();
-    for (const auto& i : output_indices_) {
-      list.emplace_back(input(i, executionFrame).toTensor());
-    }
-    out = c10::IValue(std::move(list));
-    return;
-  }
-
-  // todo: check if this is redundant
-  auto out_t = out.toTensorList();
-  for (const auto& i : output_indices_) {
-    out_t[i] = input(i, executionFrame).toTensor();
-  }
-}
-
-} // namespace torch::nativert
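The constructor above locates kernel binaries by probing fixed record names under `data/triton/<kernel_name>/`, extracting them to a temporary folder, and preferring a GPU artifact over the CPU pair. For an illustrative kernel named `my_kernel` (the name is hypothetical), the package layout it expects would be:

data/triton/my_kernel/my_kernel.cubin        (CUDA code object, loaded via cuModuleLoad)
data/triton/my_kernel/my_kernel.hsaco        (ROCm code object, same code path)
data/triton/my_kernel/my_kernel.so           (CPU kernel, dlopen'd; symbol == kernel name)
data/triton/my_kernel/my_kernel.launcher.so  (CPU launcher, dlopen'd; symbol == "run")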
@@ -1,31 +0,0 @@
-#pragma once
-
-#include <c10/core/Device.h>
-
-#include <torch/nativert/executor/ExecutionFrame.h>
-#include <torch/nativert/executor/OpKernel.h>
-#include <torch/nativert/executor/triton/TritonKernelManager.h>
-#include <torch/nativert/graph/Graph.h>
-
-namespace torch::nativert {
-
-class TritonKernel : public OpKernel {
- public:
-  TritonKernel() = delete;
-  TritonKernel(
-      const Node* node,
-      caffe2::serialize::PyTorchStreamReader* reader);
-  ~TritonKernel() override;
-
-  void computeInternal(ExecutionFrame& executionFrame) const override;
-
- private:
-  std::unique_ptr<TritonKernelManager> loader_;
-
-  // unnamed node attributes will be passed as arguments to the kernel
-  std::vector<void*> attr_ptrs_;
-  std::vector<int64_t> output_indices_;
-  LaunchParams launch_params_;
-};
-
-} // namespace torch::nativert