[nativert][triton] improve hardware registration (#162499)
Summary: att

Test Plan: ci

Rollback Plan:

Differential Revision: D82031814

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162499
Approved by: https://github.com/angelayi
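At a high level, this change replaces nativert's ad-hoc CUDA hook (an inline create_cuda_triton_kernel_manager function pointer) with a c10 typed registry keyed by device type, so the CPU, CUDA, and HIP Triton kernel managers are all registered and created the same way. The standalone sketch below illustrates that registry pattern using the names from this diff; it is a trimmed, illustrative example rather than the actual PyTorch sources, and it assumes only the c10 registry headers:

#include <memory>
#include <string>
#include <c10/core/DeviceType.h>
#include <c10/util/Registry.h>

// Stand-in for torch::nativert::TritonKernelManager (trimmed for the sketch).
class TritonKernelManager {
 public:
  virtual ~TritonKernelManager() = default;
};

class CpuTritonKernelManager final : public TritonKernelManager {
 public:
  CpuTritonKernelManager(std::string, std::string, std::string) {}
};

// One registry, keyed by device type, producing unique_ptrs from three
// string arguments (kernel name, kernel binary path, launcher binary path).
C10_DECLARE_TYPED_REGISTRY(
    TritonKernelManagerRegistry,
    c10::DeviceType,
    TritonKernelManager,
    std::unique_ptr,
    std::string,
    std::string,
    std::string);
C10_DEFINE_TYPED_REGISTRY(
    TritonKernelManagerRegistry,
    c10::DeviceType,
    TritonKernelManager,
    std::unique_ptr,
    std::string,
    std::string,
    std::string)

// Each backend registers a creator for its device type; a CUDA/HIP build
// would register its creator the same way from its own translation unit.
std::unique_ptr<TritonKernelManager> create_cpu_triton_kernel_manager(
    std::string name, std::string bin, std::string launcher) {
  return std::make_unique<CpuTritonKernelManager>(
      std::move(name), std::move(bin), std::move(launcher));
}
C10_REGISTER_TYPED_CREATOR(
    TritonKernelManagerRegistry, c10::kCPU, create_cpu_triton_kernel_manager)

// Callers probe and create by device type instead of checking a function
// pointer that only exists in GPU builds.
std::unique_ptr<TritonKernelManager> make_loader(const std::string& name) {
  if (TritonKernelManagerRegistry()->Has(c10::kCUDA)) {
    return TritonKernelManagerRegistry()->Create(
        c10::kCUDA, name, name + ".cubin", /*launcher_bin_path=*/"");
  }
  return TritonKernelManagerRegistry()->Create(
      c10::kCPU, name, name + ".so", name + ".launcher.so");
}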
@@ -552,6 +552,11 @@ if(USE_CUDA OR USE_ROCM)
  append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS)
endif()

if(USE_CUDA)
  # eventually do rocm
  append_filelist("libtorch_nativert_cuda_sources" Caffe2_GPU_SRCS)
endif()

if(USE_CUDA)
  list(APPEND Caffe2_GPU_CU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS})
  add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})

@@ -40,21 +40,24 @@ set(NATIVERT_TEST_SRCS
  ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/GraphPasses.cpp
  ${TORCH_ROOT}/torch/nativert/graph/passes/pass_manager/PassManager.cpp
  ${TORCH_ROOT}/torch/nativert/kernels/KernelHandlerRegistry.cpp
  ${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
  ${TORCH_ROOT}/torch/nativert/executor/triton/CpuTritonKernelManager.cpp
  ${TORCH_ROOT}/torch/nativert/kernels/TritonKernel.cpp
  ${TORCH_ROOT}/torch/nativert/executor/DelegateExecutor.cpp
)

if(USE_CUDA)
  list(APPEND NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/triton/CudaTritonKernelManager.cpp)
endif(MSVC)

endif()

add_executable(test_nativert
  ${TORCH_ROOT}/test/cpp/common/main.cpp
  ${NATIVERT_TEST_SRCS}
)

if(MSVC)
  target_compile_definitions(test_nativert PRIVATE NATIVERT_MSVC_TEST)
endif()

# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_nativert PRIVATE USE_GTEST)

@@ -6,9 +6,20 @@ using namespace ::testing;
using namespace torch::nativert;

TEST(TritonKernelManagerRegistrationTests, TestRegister) {
#ifndef USE_CUDA
  EXPECT_TRUE(create_cuda_triton_kernel_manager == nullptr);
  EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCPU));

#ifdef USE_CUDA
#ifdef USE_ROCM
  EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kHIP));
  EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));

#else
  EXPECT_FALSE(create_cuda_triton_kernel_manager == nullptr);
  EXPECT_TRUE(TritonKernelManagerRegistry()->Has(at::kCUDA));
  EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));

#endif // USE_ROCM
#else
  EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kCUDA));
  EXPECT_FALSE(TritonKernelManagerRegistry()->Has(at::kHIP));
#endif // USE_CUDA
}

@@ -1,5 +1,6 @@
#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>
#include <torch/nativert/executor/triton/TritonKernelManager.h>

#include <c10/util/FbcodeMaps.h>
#include <c10/util/Logging.h>

#ifndef _WIN32
@@ -35,6 +36,43 @@ char* _dlerror() {

} // namespace

typedef void* kernel_ptr_t;
typedef void (
    *launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);

struct DlcloseDeleter {
  void operator()(void* p) const {
    if (p) {
#if defined(_WIN32)
      TORCH_CHECK(false, "Windows is not supported");
#else
      dlclose(p);
#endif
    }
  }
};

class CpuTritonKernelManager final : public TritonKernelManager {
 public:
  CpuTritonKernelManager(
      std::string kernel_name,
      std::string kernel_bin_path,
      std::string kernel_launcher_bin_path);
  ~CpuTritonKernelManager() final = default;
  void launch(const LaunchParams& launch_params, void** args) final;

 private:
  void load();

  kernel_ptr_t kernel_fn_{nullptr};
  launcher_ptr_t launcher_fn_{nullptr};

  std::unique_ptr<void, DlcloseDeleter> kernel_handle_{nullptr};
  std::unique_ptr<void, DlcloseDeleter> launcher_handle_{nullptr};

  std::string kernel_launcher_bin_path_;
};

CpuTritonKernelManager::CpuTritonKernelManager(
    std::string kernel_name,
    std::string kernel_bin_path,
@@ -88,4 +126,21 @@ void CpuTritonKernelManager::launch(
      kernel_fn_);
}

namespace {
std::unique_ptr<TritonKernelManager> create_cpu_triton_kernel_manager(
    std::string kernel_name,
    std::string kernel_bin_path,
    std::string kernel_launcher_bin_path) {
  return std::make_unique<CpuTritonKernelManager>(
      std::move(kernel_name),
      std::move(kernel_bin_path),
      std::move(kernel_launcher_bin_path));
}
} // namespace

C10_REGISTER_TYPED_CREATOR(
    TritonKernelManagerRegistry,
    at::kCPU,
    create_cpu_triton_kernel_manager)

} // namespace torch::nativert
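For background on what the CPU manager above is doing: the Triton CPU path produces a kernel shared library plus a small launcher shared library, and CpuTritonKernelManager dlopen()s both, resolves the kernel symbol as an opaque pointer, and hands it to the launcher together with the grid dimensions and packed argument array (the launcher_ptr_t typedef above). A minimal sketch of that mechanism follows, under the assumption of placeholder file and symbol names ("add_kernel.so", "run") rather than the ones nativert actually emits:

#include <dlfcn.h>

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>

using kernel_ptr_t = void*;
using launcher_ptr_t =
    void (*)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);

// Close a dlopen() handle when the owning unique_ptr goes out of scope.
struct DlcloseDeleter {
  void operator()(void* p) const {
    if (p != nullptr) {
      dlclose(p);
    }
  }
};

int main() {
  std::unique_ptr<void, DlcloseDeleter> kernel_handle(
      dlopen("./add_kernel.so", RTLD_NOW | RTLD_LOCAL));
  std::unique_ptr<void, DlcloseDeleter> launcher_handle(
      dlopen("./add_kernel.launcher.so", RTLD_NOW | RTLD_LOCAL));
  if (kernel_handle == nullptr || launcher_handle == nullptr) {
    throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
  }

  // The kernel symbol stays an opaque pointer; only the launcher knows its
  // real signature and how to unpack the argument array for it.
  kernel_ptr_t kernel_fn = dlsym(kernel_handle.get(), "add_kernel");
  auto launcher_fn =
      reinterpret_cast<launcher_ptr_t>(dlsym(launcher_handle.get(), "run"));
  if (kernel_fn == nullptr || launcher_fn == nullptr) {
    throw std::runtime_error(std::string("dlsym failed: ") + dlerror());
  }

  // args holds pointers to the kernel's tensor data and scalars, in the
  // order the launcher expects; grid_x/y/z mirror the Triton launch grid.
  float a[8] = {}, b[8] = {}, out[8] = {};
  int64_t n = 8;
  void* args[] = {a, b, out, &n};
  launcher_fn(/*grid_x=*/1, /*grid_y=*/1, /*grid_z=*/1, args, kernel_fn);
  return 0;
}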
@@ -1,51 +0,0 @@
#pragma once

#include <torch/nativert/executor/triton/TritonKernelManager.h>

#include <c10/core/Device.h>
#include <c10/util/FbcodeMaps.h>

#ifndef _WIN32
#include <dlfcn.h>
#endif

typedef void* kernel_ptr_t;
typedef void (
    *launcher_ptr_t)(uint32_t, uint32_t, uint32_t, void**, kernel_ptr_t);

namespace torch::nativert {

struct DlcloseDeleter {
  void operator()(void* p) const {
    if (p) {
#if defined(_WIN32)
      TORCH_CHECK(false, "Windows is not supported");
#else
      dlclose(p);
#endif
    }
  }
};

class CpuTritonKernelManager final : public TritonKernelManager {
 public:
  CpuTritonKernelManager(
      std::string kernel_name,
      std::string kernel_bin_path,
      std::string kernel_launcher_bin_path);
  ~CpuTritonKernelManager() final = default;
  void launch(const LaunchParams& launch_params, void** args) final;

 private:
  void load();

  kernel_ptr_t kernel_fn_{nullptr};
  launcher_ptr_t launcher_fn_{nullptr};

  std::unique_ptr<void, DlcloseDeleter> kernel_handle_{nullptr};
  std::unique_ptr<void, DlcloseDeleter> launcher_handle_{nullptr};

  std::string kernel_launcher_bin_path_;
};

} // namespace torch::nativert

@@ -29,7 +29,7 @@ namespace torch::nativert {
class CudaKernelInputs final : public KernelInputs {
 public:
  CudaKernelInputs(size_t num_args, size_t num_attrs)
      : KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {};
      : KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {}
  ~CudaKernelInputs() final = default;

  void add_arg(void* arg) override {
@@ -73,7 +73,7 @@ CudaTritonKernelManager::CudaTritonKernelManager(
  TORCH_CHECK(
      at::globalContext().hasCUDA() || at::globalContext().hasHIP(),
      "cuda or hip required");
};
}

CudaTritonKernelManager::~CudaTritonKernelManager() {
  const auto& nvrtc = get_nvrtc();
@@ -137,19 +137,31 @@ void CudaTritonKernelManager::launch(
      nullptr));
}

static std::unique_ptr<TritonKernelManager> _create_cuda_triton_kernel_manager(
namespace {
std::unique_ptr<TritonKernelManager> create_cuda_triton_kernel_manager(
    std::string kernel_name,
    std::string kernel_bin_path) {
    std::string kernel_bin_path,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    [[maybe_unused]] std::string kernel_launcher_bin_path) {
  return std::make_unique<CudaTritonKernelManager>(
      std::move(kernel_name), std::move(kernel_bin_path));
}
} // namespace

#ifdef USE_ROCM

C10_REGISTER_TYPED_CREATOR(
    TritonKernelManagerRegistry,
    at::kHIP,
    create_cuda_triton_kernel_manager)

#else

C10_REGISTER_TYPED_CREATOR(
    TritonKernelManagerRegistry,
    at::kCUDA,
    create_cuda_triton_kernel_manager)

#endif // USE_ROCM

} // namespace torch::nativert

namespace {
static bool _initialized_cuda_triton_kernel_manager = []() {
  torch::nativert::create_cuda_triton_kernel_manager =
      &torch::nativert::_create_cuda_triton_kernel_manager;
  return true;
}();
} // namespace

@@ -2,7 +2,9 @@

#include <string>

#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

namespace torch::nativert {

@@ -69,7 +71,13 @@ class TritonKernelManager {
  std::string kernel_name_, kernel_bin_path_;
};

inline std::unique_ptr<TritonKernelManager> (
    *create_cuda_triton_kernel_manager)(std::string, std::string) = nullptr;
C10_DECLARE_TYPED_REGISTRY(
    TritonKernelManagerRegistry,
    c10::DeviceType,
    TritonKernelManager,
    std::unique_ptr,
    std::string /* kernel_name */,
    std::string /* kernel_bin_path */,
    std::string /* kernel_launcher_bin_path */);

} // namespace torch::nativert
@@ -16,10 +16,20 @@
#include <ATen/ops/empty.h>
#endif

#include <torch/nativert/executor/triton/CpuTritonKernelManager.h>

namespace torch::nativert {

// in this case, we want to use the symbol from torch_cpu.dll
#ifndef NATIVERT_MSVC_TEST
C10_DEFINE_TYPED_REGISTRY(
    TritonKernelManagerRegistry,
    c10::DeviceType,
    TritonKernelManager,
    std::unique_ptr,
    std::string /* kernel_name */,
    std::string /* kernel_bin_path */,
    std::string /* kernel_launcher_bin_path */)
#endif

TritonKernel::TritonKernel(
    const Node* node,
    caffe2::serialize::PyTorchStreamReader* reader)
@@ -74,27 +84,28 @@ TritonKernel::TritonKernel(
  auto tmp_dir = extractToTemporaryFolder(*reader, kernel_prefix) + "/";

  if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) {
    loader_ = TritonKernelManagerRegistry()->Create(
        at::kCUDA, kernel_name, tmp_dir + kernel_name + ".cubin", "");
    TORCH_CHECK(
        create_cuda_triton_kernel_manager != nullptr,
        loader_ != nullptr,
        "couldn't find cuda loader -- is this a gpu build?");
    loader_ = create_cuda_triton_kernel_manager(
        kernel_name, tmp_dir + kernel_name + ".cubin");
  }

  if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) {
  } else if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) {
    loader_ = TritonKernelManagerRegistry()->Create(
        at::kHIP, kernel_name, tmp_dir + kernel_name + ".hsaco", "");
    TORCH_CHECK(
        create_cuda_triton_kernel_manager != nullptr,
        loader_ != nullptr,
        "couldn't find cuda loader -- is this a gpu build?");
    loader_ = create_cuda_triton_kernel_manager(
        kernel_name, tmp_dir + kernel_name + ".hsaco");
  }

  if (loader_ == nullptr) {
    loader_ = std::unique_ptr<TritonKernelManager>(new CpuTritonKernelManager(
  } else {
    loader_ = TritonKernelManagerRegistry()->Create(
        at::kCPU,
        kernel_name,
        tmp_dir + kernel_name + ".so",
        tmp_dir + kernel_name + ".launcher.so"));
        tmp_dir + kernel_name + ".launcher.so");
  }

  TORCH_CHECK(
      loader_ != nullptr,
      "couldn't find triton kernel loader -- are you trying to run gpu kernels on a cpu build?");
}

TritonKernel::~TritonKernel() = default;