Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 21:49:24 +08:00)
[aotinductor] Add AOTIModelRunner as a utility class (#110891)
Summary: Introduce a utility class, AOTIModelRunner, that takes care of running an AOTInductor-compiled model: it dlopens the model shared library, initializes the model container, sets up inputs and outputs, and destroys the container when done.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110891
Approved by: https://github.com/chenyang78
ghstack dependencies: #110652
Committed by: PyTorch MergeBot
Parent: b17c247eb1
Commit: 3058700f7f
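For context, a minimal usage sketch of the new runner (not part of this commit): it assumes a model has already been compiled into a shared library with torch._export.aot_compile, and the library path and input shapes below are purely illustrative.

#include <iostream>
#include <vector>

#include <torch/csrc/inductor/aoti_model_runner_cuda.h>
#include <torch/torch.h>

int main() {
  // dlopen the compiled model and create its container (num_models defaults to 1).
  torch::inductor::AOTIModelRunnerCuda runner("/tmp/model.so"); // illustrative path

  // Inputs must match what the model was compiled for; these shapes are made up.
  std::vector<at::Tensor> inputs = {
      torch::randn({32, 64}, at::kCUDA), torch::randn({32, 64}, at::kCUDA)};

  // run() uses the current CUDA stream by default and returns output tensors.
  std::vector<at::Tensor> outputs = runner.run(inputs);
  std::cout << outputs[0].sizes() << std::endl;
  return 0;
}

The container is created in the runner's constructor and destroyed in its destructor, so a single runner instance can be reused across calls to run().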
@@ -501,6 +501,7 @@ lazy_tensor_core_python_sources = [
 ]
 
 inductor_core_resources = [
+    "torch/csrc/inductor/aoti_model_runner.cpp",
     "torch/csrc/inductor/aoti_torch/shim_common.cpp",
     "torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
     "torch/csrc/inductor/inductor_ops.cpp",
@@ -686,6 +687,7 @@ libtorch_cuda_core_sources = [
     "torch/csrc/CudaIPCTypes.cpp",
     "torch/csrc/cuda/comm.cpp",
     "torch/csrc/cuda/memory_snapshot.cpp",
+    "torch/csrc/inductor/aoti_model_runner_cuda.cpp",
     "torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
     "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
     "torch/csrc/profiler/stubs/cuda.cpp",
@@ -9,6 +9,7 @@ set(INDUCTOR_TEST_SRCS
 add_executable(test_aot_inductor
   ${TORCH_ROOT}/test/cpp/common/main.cpp
   ${INDUCTOR_TEST_SRCS}
+  data.pt
 )
 
 # TODO temporary until we can delete the old gtest polyfills.
@@ -16,22 +17,18 @@ target_compile_definitions(test_aot_inductor PRIVATE USE_GTEST)
 
 # Define a custom command to generate the library
 add_custom_command(
-  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/libmodel.so
+  OUTPUT data.pt
   COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/test.py
   DEPENDS ${AOT_INDUCTOR_TEST_ROOT}/test.py
 )
-add_custom_target(model_so ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libmodel.so)
-add_dependencies(test_aot_inductor model_so)
 
 target_link_libraries(test_aot_inductor PRIVATE
   torch
   gtest
-  ${CMAKE_CURRENT_BINARY_DIR}/libmodel.so
 )
 
 if(USE_CUDA)
   target_include_directories(test_aot_inductor PRIVATE ${ATen_CUDA_INCLUDE})
 
   target_compile_definitions(test_aot_inductor PRIVATE USE_CUDA CMAKE_CURRENT_BINARY_DIR=${CMAKE_CURRENT_BINARY_DIR})
 endif()
 
@@ -5,9 +5,8 @@
 #include <string>
 #include <vector>
 
-#include <c10/cuda/CUDAStream.h>
-#include <torch/csrc/inductor/aoti_runtime/interface.h>
-#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
+#include <torch/csrc/inductor/aoti_model_runner.h>
+#include <torch/csrc/inductor/aoti_model_runner_cuda.h>
 #include <torch/script.h>
 #include <torch/torch.h>
 
@@ -15,60 +14,24 @@
 #define STRINGIZE(x) STR_VALUE(x)
 
 namespace torch {
-namespace aot_inductor {
+namespace inductor {
 
 TEST(AotInductorTest, BasicTest) {
   torch::NoGradGuard no_grad;
 
-  std::string io_tensors_path =
-      (std::filesystem::path(
-           STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "io_tensors.pt")
+  std::string data_path =
+      (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt")
           .string();
-  torch::jit::script::Module tensor_loader = torch::jit::load(io_tensors_path);
-  auto input_tensors = tensor_loader.attr("inputs").toTensorList().vec();
-  auto ref_output_tensors = tensor_loader.attr("outputs").toTensorList().vec();
+  torch::jit::script::Module data_loader = torch::jit::load(data_path);
+  const auto& model_so_path = data_loader.attr("model_so_path").toStringRef();
+  const auto& input_tensors = data_loader.attr("inputs").toTensorList().vec();
+  const auto& ref_output_tensors =
+      data_loader.attr("outputs").toTensorList().vec();
 
-  AOTInductorModelContainerHandle container_handle;
-  AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerCreate(
-      &container_handle,
-      1 /*num_models*/,
-      false /*is_cpu*/,
-      nullptr /*cubin_dir*/));
-
-  auto input_handles =
-      torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors);
-
-  // For outputs, we only allocate a vector to hold returned tensor handles,
-  // not allocating the actual output tensor storage here
-  size_t num_outputs;
-  AOTI_RUNTIME_ERROR_CODE_CHECK(
-      AOTInductorModelContainerGetNumOutputs(container_handle, &num_outputs));
-  std::vector<AtenTensorHandle> output_handles(num_outputs);
-
-  const auto& cuda_stream = at::cuda::getCurrentCUDAStream(0 /*device_index*/);
-  const auto stream_id = cuda_stream.stream();
-  AOTInductorStreamHandle stream_handle =
-      reinterpret_cast<AOTInductorStreamHandle>(stream_id);
-
-  AOTIProxyExecutorHandle proxy_executor_handle = nullptr;
-
-  AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerRun(
-      container_handle,
-      input_handles.data(),
-      input_tensors.size(),
-      output_handles.data(),
-      output_handles.size(),
-      stream_handle,
-      proxy_executor_handle));
-
-  auto output_tensors =
-      torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
-          output_handles.data(), output_handles.size());
-
-  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], output_tensors[0]));
-  AOTI_RUNTIME_ERROR_CODE_CHECK(
-      AOTInductorModelContainerDelete(container_handle));
+  AOTIModelRunnerCuda runner(model_so_path.c_str());
+  auto actual_output_tensors = runner.run(input_tensors);
+  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
 }
 
-} // namespace aot_inductor
+} // namespace inductor
 } // namespace torch
@@ -1,4 +1,3 @@
-import shutil
 
 import torch
 from torch._export import aot_compile, dynamic_dim
@@ -26,21 +25,19 @@ with torch.no_grad():
         dynamic_dim(x, 0) <= 1024,
         dynamic_dim(x, 0) == dynamic_dim(y, 0),
     ]
-    lib_path, module = aot_compile(model, (x, y), constraints=constraints)
-
-shutil.copy(lib_path, "libmodel.so")
+    model_so_path, _ = aot_compile(model, (x, y), constraints=constraints)
 
 # Use this to communicate tensors to the cpp code
 class Serializer(torch.nn.Module):
-    def __init__(self, tensors):
+    def __init__(self, data):
         super().__init__()
-        for key in tensors:
-            setattr(self, key, tensors[key])
+        for key in data:
+            setattr(self, key, data[key])
 
-io_tensors = {
+data = {
+    "model_so_path": model_so_path,
     "inputs": [x, y],
     "outputs": [ref_output],
 }
 
-torch.jit.script(Serializer(io_tensors)).save("io_tensors.pt")
+torch.jit.script(Serializer(data)).save("data.pt")
torch/csrc/inductor/aoti_model_runner.cpp (new file, 63 lines)
@@ -0,0 +1,63 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <ATen/DynamicLibrary.h>

#include <torch/csrc/inductor/aoti_model_runner.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>

namespace torch::inductor {

AOTIModelRunner::AOTIModelRunner(
    const char* model_path,
    size_t num_models,
    bool is_cpu,
    const char* cubin_dir) {
  model_so_ = std::make_unique<at::DynamicLibrary>(model_path);
  TORCH_CHECK(model_so_, "Failed to load model: ", model_path);
  create_func_ = reinterpret_cast<decltype(create_func_)>(
      model_so_->sym("AOTInductorModelContainerCreate"));
  delete_func_ = reinterpret_cast<decltype(delete_func_)>(
      model_so_->sym("AOTInductorModelContainerDelete"));
  get_num_outputs_func_ = reinterpret_cast<decltype(get_num_outputs_func_)>(
      model_so_->sym("AOTInductorModelContainerGetNumOutputs"));
  run_func_ = reinterpret_cast<decltype(run_func_)>(
      model_so_->sym("AOTInductorModelContainerRun"));

  AOTI_RUNTIME_ERROR_CODE_CHECK(
      create_func_(&container_handle_, num_models, is_cpu, cubin_dir));
}

AOTIModelRunner::~AOTIModelRunner() {
  AOTIRuntimeError result = delete_func_(container_handle_);
  TORCH_CHECK(
      result == AOTI_RUNTIME_SUCCESS, "AOTInductorModelContainerDelete failed");
}

std::vector<at::Tensor> AOTIModelRunner::run(
    std::vector<at::Tensor> inputs,
    AOTInductorStreamHandle cuda_stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  auto input_handles =
      torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);

  // For outputs, we only allocate a vector to hold returned tensor handles,
  // not allocating the actual output tensor storage here
  size_t num_outputs = 0;
  AOTI_RUNTIME_ERROR_CODE_CHECK(
      get_num_outputs_func_(container_handle_, &num_outputs));
  std::vector<AtenTensorHandle> output_handles(num_outputs);

  AOTI_RUNTIME_ERROR_CODE_CHECK(run_func_(
      container_handle_,
      input_handles.data(),
      input_handles.size(),
      output_handles.data(),
      output_handles.size(),
      cuda_stream_handle,
      proxy_executor_handle));

  return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
      output_handles.data(), output_handles.size());
}

} // namespace torch::inductor
#endif
torch/csrc/inductor/aoti_model_runner.h (new file, 58 lines)
@@ -0,0 +1,58 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once

#include <ATen/Tensor.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>

// Forward declare DynamicLibrary
namespace at {
struct DynamicLibrary;
}

namespace torch::inductor {

class TORCH_API AOTIModelRunner {
 public:
  AOTIModelRunner() = delete;
  AOTIModelRunner(const AOTIModelRunner& other) = delete;
  AOTIModelRunner(AOTIModelRunner&& other) = delete;
  AOTIModelRunner& operator=(const AOTIModelRunner& other) = delete;
  AOTIModelRunner& operator=(AOTIModelRunner&& other) = delete;

 protected:
  std::vector<at::Tensor> run(
      std::vector<at::Tensor> inputs,
      AOTInductorStreamHandle cuda_stream_handle,
      AOTIProxyExecutorHandle proxy_executor_handle);

  AOTIModelRunner(
      const char* model_path,
      size_t num_models,
      bool is_cpu,
      const char* cubin_dir);

  ~AOTIModelRunner();

  std::unique_ptr<at::DynamicLibrary> model_so_;
  decltype(&AOTInductorModelContainerCreate) create_func_{nullptr};
  decltype(&AOTInductorModelContainerDelete) delete_func_{nullptr};
  decltype(&AOTInductorModelContainerGetNumOutputs) get_num_outputs_func_{
      nullptr};
  decltype(&AOTInductorModelContainerRun) run_func_{nullptr};
  AOTInductorModelContainerHandle container_handle_ = nullptr;
};

class TORCH_API AOTIModelRunnerCpu : public AOTIModelRunner {
 public:
  AOTIModelRunnerCpu(const char* model_path, size_t num_models = 1)
      : AOTIModelRunner(model_path, num_models, true, nullptr) {}

  std::vector<at::Tensor> run(
      std::vector<at::Tensor> inputs,
      AOTIProxyExecutorHandle proxy_executor_handle = nullptr) {
    return AOTIModelRunner::run(inputs, nullptr, proxy_executor_handle);
  }
};

} // namespace torch::inductor
#endif
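As an aside (not from the commit), the CPU variant declared above could be driven like this; the helper name and the model path are hypothetical:

#include <string>
#include <vector>

#include <torch/csrc/inductor/aoti_model_runner.h>
#include <torch/torch.h>

// Hypothetical helper: load a CPU-compiled model.so and run it once.
std::vector<at::Tensor> run_cpu_model(
    const std::string& model_so_path,
    std::vector<at::Tensor> inputs) {
  // AOTIModelRunnerCpu constructs the container with is_cpu=true, and its
  // public run() forwards a null stream handle to the protected base run().
  torch::inductor::AOTIModelRunnerCpu runner(model_so_path.c_str());
  return runner.run(std::move(inputs));
}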
torch/csrc/inductor/aoti_model_runner_cuda.cpp (new file, 18 lines)
@@ -0,0 +1,18 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_model_runner_cuda.h>

namespace torch::inductor {

std::vector<at::Tensor> AOTIModelRunnerCuda::run(
    std::vector<at::Tensor> inputs,
    AOTInductorStreamHandle cuda_stream_handle,
    AOTIProxyExecutorHandle proxy_executor_handle) {
  if (cuda_stream_handle == nullptr) {
    cudaStream_t stream_id = c10::cuda::getCurrentCUDAStream().stream();
    cuda_stream_handle = reinterpret_cast<AOTInductorStreamHandle>(stream_id);
  }
  return AOTIModelRunner::run(
      inputs, cuda_stream_handle, proxy_executor_handle);
}

} // namespace torch::inductor
torch/csrc/inductor/aoti_model_runner_cuda.h (new file, 21 lines)
@@ -0,0 +1,21 @@
#pragma once

#include <torch/csrc/inductor/aoti_model_runner.h>

namespace torch::inductor {

class TORCH_API AOTIModelRunnerCuda : public AOTIModelRunner {
 public:
  AOTIModelRunnerCuda(
      const char* model_path,
      size_t num_models = 1,
      const char* cubin_dir = nullptr)
      : AOTIModelRunner(model_path, num_models, false, cubin_dir) {}

  std::vector<at::Tensor> run(
      std::vector<at::Tensor> inputs,
      AOTInductorStreamHandle cuda_stream_handle = nullptr,
      AOTIProxyExecutorHandle proxy_executor_handle = nullptr);
};

} // namespace torch::inductor
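One closing illustration (not from the commit): because AOTIModelRunnerCuda::run falls back to the current CUDA stream only when the stream handle is null, a caller can schedule the model on a stream of its own choosing; run_on_stream below is a hypothetical helper.

#include <vector>

#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_model_runner_cuda.h>
#include <torch/torch.h>

// Hypothetical helper: run an already-constructed runner on a caller-chosen
// CUDA stream instead of the current one.
std::vector<at::Tensor> run_on_stream(
    torch::inductor::AOTIModelRunnerCuda& runner,
    std::vector<at::Tensor> inputs,
    c10::cuda::CUDAStream stream) {
  // Reinterpret the raw cudaStream_t as the opaque AOTInductor stream handle,
  // mirroring what AOTIModelRunnerCuda::run does for the default stream.
  auto stream_handle =
      reinterpret_cast<AOTInductorStreamHandle>(stream.stream());
  return runner.run(std::move(inputs), stream_handle);
}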