Files
pytorch/test/cpp/aoti_inference/aoti_custom_class.cpp
Bin Bao 4946638f06 [AOTI] Add ABI-compatibility tests (#123848)
Summary: In AOTInductor-generated CPU model code, there can be direct references to some aten/c10 utility functions and data structures, e.g. at::vec and c10::Half. These are performance-critical, so it doesn't make sense to create C shims for them. Instead, we make sure they are implemented in a header-only way, and use this set of tests to guard future changes.

There are more header files to be updated, but we will do it in other followup PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123848
Approved by: https://github.com/jansel
ghstack dependencies: #123847
2024-04-19 00:51:24 +00:00

49 lines
1.5 KiB
C++

#include <stdexcept>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include "aoti_custom_class.h"
namespace torch::aot_inductor {

// Registers MyAOTIClass as a TorchScript custom class in the "aoti" namespace
// under the name "MyAOTIClass", exposing:
//   - a constructor taking two strings (model_path, device),
//   - forward(inputs) -> outputs,
//   - pickling support that serializes the object as [lib_path, device].
static auto registerMyAOTIClass =
    torch::class_<MyAOTIClass>("aoti", "MyAOTIClass")
        .def(torch::init<std::string, std::string>())
        .def("forward", &MyAOTIClass::forward)
        .def_pickle(
            // __getstate__: the state is exactly the two constructor
            // arguments, in constructor order.
            [](const c10::intrusive_ptr<MyAOTIClass>& self)
                -> std::vector<std::string> {
              return {self->lib_path(), self->device()};
            },
            // __setstate__: rebuilds the object from the state produced
            // above. The state is read back from serialized data, so its size
            // is validated rather than indexed blindly (out-of-range
            // operator[] on std::vector is undefined behavior).
            [](std::vector<std::string> params) {
              if (params.size() != 2) {
                throw std::runtime_error(
                    "MyAOTIClass: expected 2 pickled fields (lib_path, device), got " +
                    std::to_string(params.size()));
              }
              return c10::make_intrusive<MyAOTIClass>(params[0], params[1]);
            });

// Creates an AOTInductor model-container runner for the compiled model at
// `model_path` on `device`.
//
// @param model_path Path to the compiled model; also stored as lib_path_.
// @param device     "cpu", or "cuda" when the file is built with USE_CUDA.
// @throws std::runtime_error for any other device string (including "cuda"
//         in builds without USE_CUDA).
MyAOTIClass::MyAOTIClass(
    const std::string& model_path,
    const std::string& device)
    : lib_path_(model_path), device_(device) {
  if (device_ == "cpu") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_path.c_str());
#ifdef USE_CUDA
  } else if (device_ == "cuda") {
    // The CUDA runner header is only included when USE_CUDA is defined, so
    // this branch must be guarded the same way for CPU-only builds to compile.
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_path.c_str());
#endif
  } else {
    throw std::runtime_error("invalid device: " + device);
  }
}

// Runs the compiled model on `inputs` and returns its outputs.
// The constructor above either sets runner_ or throws, so runner_ is non-null
// for any object built through it.
std::vector<torch::Tensor> MyAOTIClass::forward(
    std::vector<torch::Tensor> inputs) {
  return runner_->run(inputs);
}

} // namespace torch::aot_inductor