[aotinductor] Add AOTIModelRunner as a utility class (#110891)

Summary: Introduce a utility class, AOTIModelRunner, that takes care of running an AOTInductor-compiled model. It handles dlopening the model shared library, initializing the model container, setting up inputs and outputs, and destroying the model container.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110891
Approved by: https://github.com/chenyang78
ghstack dependencies: #110652
Author: Bin Bao
Date: 2023-10-10 18:40:36 -07:00
Committed by: PyTorch MergeBot
Parent: b17c247eb1
Commit: 3058700f7f
8 changed files with 185 additions and 66 deletions
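For context, here is a minimal usage sketch of the new runner, mirroring the updated C++ test below. The model path and tensor shapes are placeholders; a real model .so would come from torch._export.aot_compile as in test.py.

#include <torch/csrc/inductor/aoti_model_runner_cuda.h>
#include <torch/torch.h>

int main() {
  torch::NoGradGuard no_grad;
  // "/tmp/model.so" is a placeholder for an AOTInductor-compiled model.
  // The constructor dlopens the library, resolves the container entry
  // points, and creates the model container.
  torch::inductor::AOTIModelRunnerCuda runner("/tmp/model.so");
  // Shapes are illustrative; they must match what the model was compiled for.
  std::vector<at::Tensor> inputs = {
      torch::randn({32, 64}, at::kCUDA), torch::randn({32, 64}, at::kCUDA)};
  // run() converts the tensors to handles, invokes the container, and
  // returns the outputs as at::Tensor objects.
  std::vector<at::Tensor> outputs = runner.run(inputs);
  // The model container is destroyed when the runner goes out of scope.
  return 0;
}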


@@ -501,6 +501,7 @@ lazy_tensor_core_python_sources = [
]
inductor_core_resources = [
"torch/csrc/inductor/aoti_model_runner.cpp",
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
@@ -686,6 +687,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/CudaIPCTypes.cpp",
"torch/csrc/cuda/comm.cpp",
"torch/csrc/cuda/memory_snapshot.cpp",
"torch/csrc/inductor/aoti_model_runner_cuda.cpp",
"torch/csrc/inductor/aoti_torch/shim_cuda.cpp",
"torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp",
"torch/csrc/profiler/stubs/cuda.cpp",


@@ -9,6 +9,7 @@ set(INDUCTOR_TEST_SRCS
add_executable(test_aot_inductor
${TORCH_ROOT}/test/cpp/common/main.cpp
${INDUCTOR_TEST_SRCS}
data.pt
)
# TODO temporary until we can delete the old gtest polyfills.
@@ -16,22 +17,18 @@ target_compile_definitions(test_aot_inductor PRIVATE USE_GTEST)
# Define a custom command to generate the library
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/libmodel.so
OUTPUT data.pt
COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/test.py
DEPENDS ${AOT_INDUCTOR_TEST_ROOT}/test.py
)
add_custom_target(model_so ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libmodel.so)
add_dependencies(test_aot_inductor model_so)
target_link_libraries(test_aot_inductor PRIVATE
torch
gtest
${CMAKE_CURRENT_BINARY_DIR}/libmodel.so
)
if(USE_CUDA)
target_include_directories(test_aot_inductor PRIVATE ${ATen_CUDA_INCLUDE})
target_compile_definitions(test_aot_inductor PRIVATE USE_CUDA CMAKE_CURRENT_BINARY_DIR=${CMAKE_CURRENT_BINARY_DIR})
endif()


@@ -5,9 +5,8 @@
#include <string>
#include <vector>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_model_runner.h>
#include <torch/csrc/inductor/aoti_model_runner_cuda.h>
#include <torch/script.h>
#include <torch/torch.h>
@@ -15,60 +14,24 @@
#define STRINGIZE(x) STR_VALUE(x)
namespace torch {
namespace aot_inductor {
namespace inductor {
TEST(AotInductorTest, BasicTest) {
torch::NoGradGuard no_grad;
std::string io_tensors_path =
(std::filesystem::path(
STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "io_tensors.pt")
std::string data_path =
(std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt")
.string();
torch::jit::script::Module tensor_loader = torch::jit::load(io_tensors_path);
auto input_tensors = tensor_loader.attr("inputs").toTensorList().vec();
auto ref_output_tensors = tensor_loader.attr("outputs").toTensorList().vec();
torch::jit::script::Module data_loader = torch::jit::load(data_path);
const auto& model_so_path = data_loader.attr("model_so_path").toStringRef();
const auto& input_tensors = data_loader.attr("inputs").toTensorList().vec();
const auto& ref_output_tensors =
data_loader.attr("outputs").toTensorList().vec();
AOTInductorModelContainerHandle container_handle;
AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerCreate(
&container_handle,
1 /*num_models*/,
false /*is_cpu*/,
nullptr /*cubin_dir*/));
auto input_handles =
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors);
// For outputs, we only allocate a vector to hold returned tensor handles,
// not allocating the actual output tensor storage here
size_t num_outputs;
AOTI_RUNTIME_ERROR_CODE_CHECK(
AOTInductorModelContainerGetNumOutputs(container_handle, &num_outputs));
std::vector<AtenTensorHandle> output_handles(num_outputs);
const auto& cuda_stream = at::cuda::getCurrentCUDAStream(0 /*device_index*/);
const auto stream_id = cuda_stream.stream();
AOTInductorStreamHandle stream_handle =
reinterpret_cast<AOTInductorStreamHandle>(stream_id);
AOTIProxyExecutorHandle proxy_executor_handle = nullptr;
AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerRun(
container_handle,
input_handles.data(),
input_tensors.size(),
output_handles.data(),
output_handles.size(),
stream_handle,
proxy_executor_handle));
auto output_tensors =
torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
output_handles.data(), output_handles.size());
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], output_tensors[0]));
AOTI_RUNTIME_ERROR_CODE_CHECK(
AOTInductorModelContainerDelete(container_handle));
AOTIModelRunnerCuda runner(model_so_path.c_str());
auto actual_output_tensors = runner.run(input_tensors);
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
}
} // namespace aot_inductor
} // namespace inductor
} // namespace torch


@@ -1,4 +1,3 @@
import shutil
import torch
from torch._export import aot_compile, dynamic_dim
@@ -26,21 +25,19 @@ with torch.no_grad():
dynamic_dim(x, 0) <= 1024,
dynamic_dim(x, 0) == dynamic_dim(y, 0),
]
lib_path, module = aot_compile(model, (x, y), constraints=constraints)
shutil.copy(lib_path, "libmodel.so")
model_so_path, _ = aot_compile(model, (x, y), constraints=constraints)
# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
def __init__(self, tensors):
def __init__(self, data):
super().__init__()
for key in tensors:
setattr(self, key, tensors[key])
for key in data:
setattr(self, key, data[key])
io_tensors = {
data = {
"model_so_path": model_so_path,
"inputs": [x, y],
"outputs": [ref_output],
}
torch.jit.script(Serializer(io_tensors)).save("io_tensors.pt")
torch.jit.script(Serializer(data)).save("data.pt")


New file: torch/csrc/inductor/aoti_model_runner.cpp
@@ -0,0 +1,63 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <ATen/DynamicLibrary.h>
#include <torch/csrc/inductor/aoti_model_runner.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
namespace torch::inductor {
AOTIModelRunner::AOTIModelRunner(
const char* model_path,
size_t num_models,
bool is_cpu,
const char* cubin_dir) {
model_so_ = std::make_unique<at::DynamicLibrary>(model_path);
TORCH_CHECK(model_so_, "Failed to load model: ", model_path);
create_func_ = reinterpret_cast<decltype(create_func_)>(
model_so_->sym("AOTInductorModelContainerCreate"));
delete_func_ = reinterpret_cast<decltype(delete_func_)>(
model_so_->sym("AOTInductorModelContainerDelete"));
get_num_outputs_func_ = reinterpret_cast<decltype(get_num_outputs_func_)>(
model_so_->sym("AOTInductorModelContainerGetNumOutputs"));
run_func_ = reinterpret_cast<decltype(run_func_)>(
model_so_->sym("AOTInductorModelContainerRun"));
AOTI_RUNTIME_ERROR_CODE_CHECK(
create_func_(&container_handle_, num_models, is_cpu, cubin_dir));
}
AOTIModelRunner::~AOTIModelRunner() {
AOTIRuntimeError result = delete_func_(container_handle_);
TORCH_CHECK(
result == AOTI_RUNTIME_SUCCESS, "AOTInductorModelContainerDelete failed");
}
std::vector<at::Tensor> AOTIModelRunner::run(
std::vector<at::Tensor> inputs,
AOTInductorStreamHandle cuda_stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle) {
auto input_handles =
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);
// For outputs, we only allocate a vector to hold returned tensor handles,
// not allocating the actual output tensor storage here
size_t num_outputs = 0;
AOTI_RUNTIME_ERROR_CODE_CHECK(
get_num_outputs_func_(container_handle_, &num_outputs));
std::vector<AtenTensorHandle> output_handles(num_outputs);
AOTI_RUNTIME_ERROR_CODE_CHECK(run_func_(
container_handle_,
input_handles.data(),
input_handles.size(),
output_handles.data(),
output_handles.size(),
cuda_stream_handle,
proxy_executor_handle));
return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
output_handles.data(), output_handles.size());
}
} // namespace torch::inductor
#endif


New file: torch/csrc/inductor/aoti_model_runner.h
@@ -0,0 +1,58 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once
#include <ATen/Tensor.h>
#include <torch/csrc/inductor/aoti_runtime/interface.h>
// Forward declare DynamicLibrary
namespace at {
struct DynamicLibrary;
}
namespace torch::inductor {
class TORCH_API AOTIModelRunner {
public:
AOTIModelRunner() = delete;
AOTIModelRunner(const AOTIModelRunner& other) = delete;
AOTIModelRunner(AOTIModelRunner&& other) = delete;
AOTIModelRunner& operator=(const AOTIModelRunner& other) = delete;
AOTIModelRunner& operator=(AOTIModelRunner&& other) = delete;
protected:
std::vector<at::Tensor> run(
std::vector<at::Tensor> inputs,
AOTInductorStreamHandle cuda_stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle);
AOTIModelRunner(
const char* model_path,
size_t num_models,
bool is_cpu,
const char* cubin_dir);
~AOTIModelRunner();
std::unique_ptr<at::DynamicLibrary> model_so_;
decltype(&AOTInductorModelContainerCreate) create_func_{nullptr};
decltype(&AOTInductorModelContainerDelete) delete_func_{nullptr};
decltype(&AOTInductorModelContainerGetNumOutputs) get_num_outputs_func_{
nullptr};
decltype(&AOTInductorModelContainerRun) run_func_{nullptr};
AOTInductorModelContainerHandle container_handle_ = nullptr;
};
class TORCH_API AOTIModelRunnerCpu : public AOTIModelRunner {
public:
AOTIModelRunnerCpu(const char* model_path, size_t num_models = 1)
: AOTIModelRunner(model_path, num_models, true, nullptr) {}
std::vector<at::Tensor> run(
std::vector<at::Tensor> inputs,
AOTIProxyExecutorHandle proxy_executor_handle = nullptr) {
return AOTIModelRunner::run(inputs, nullptr, proxy_executor_handle);
}
};
} // namespace torch::inductor
#endif
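
The CPU wrapper declared above works the same way; a minimal sketch, assuming a model compiled for CPU (the path and shape are placeholders):

#include <torch/csrc/inductor/aoti_model_runner.h>
#include <torch/torch.h>

void run_cpu_model() {
  // AOTIModelRunnerCpu passes is_cpu=true and a null cubin_dir to the
  // base class; "/tmp/model_cpu.so" is just a placeholder path.
  torch::inductor::AOTIModelRunnerCpu runner("/tmp/model_cpu.so");
  std::vector<at::Tensor> inputs = {torch::randn({8, 16})};
  std::vector<at::Tensor> outputs = runner.run(inputs);
}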


New file: torch/csrc/inductor/aoti_model_runner_cuda.cpp
@@ -0,0 +1,18 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_model_runner_cuda.h>
namespace torch::inductor {
std::vector<at::Tensor> AOTIModelRunnerCuda::run(
std::vector<at::Tensor> inputs,
AOTInductorStreamHandle cuda_stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle) {
if (cuda_stream_handle == nullptr) {
cudaStream_t stream_id = c10::cuda::getCurrentCUDAStream().stream();
cuda_stream_handle = reinterpret_cast<AOTInductorStreamHandle>(stream_id);
}
return AOTIModelRunner::run(
inputs, cuda_stream_handle, proxy_executor_handle);
}
} // namespace torch::inductor


New file: torch/csrc/inductor/aoti_model_runner_cuda.h
@@ -0,0 +1,21 @@
#pragma once
#include <torch/csrc/inductor/aoti_model_runner.h>
namespace torch::inductor {
class TORCH_API AOTIModelRunnerCuda : public AOTIModelRunner {
public:
AOTIModelRunnerCuda(
const char* model_path,
size_t num_models = 1,
const char* cubin_dir = nullptr)
: AOTIModelRunner(model_path, num_models, false, cubin_dir) {}
std::vector<at::Tensor> run(
std::vector<at::Tensor> inputs,
AOTInductorStreamHandle cuda_stream_handle = nullptr,
AOTIProxyExecutorHandle proxy_executor_handle = nullptr);
};
} // namespace torch::inductor
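
For completeness, a hedged sketch of passing an explicit CUDA stream to the CUDA runner instead of relying on the current-stream fallback implemented in aoti_model_runner_cuda.cpp above; the pooled stream here is illustrative:

#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aoti_model_runner_cuda.h>
#include <torch/torch.h>

void run_on_side_stream(
    torch::inductor::AOTIModelRunnerCuda& runner,
    std::vector<at::Tensor> inputs) {
  // Pick a non-default stream from the pool and hand its cudaStream_t to
  // run(); when the handle is nullptr, run() uses the current CUDA stream.
  c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool();
  auto stream_handle =
      reinterpret_cast<AOTInductorStreamHandle>(stream.stream());
  std::vector<at::Tensor> outputs = runner.run(inputs, stream_handle);
}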