Summary: In AOTInductor-generated CPU model code, there can be direct references to some aten/c10 utility functions and data structures, e.g. at::vec and c10::Half. These are performance critical, so it doesn't make sense to create C shims for them. Instead, we make sure they are implemented in a header-only way, and use this set of tests to guard future changes. There are more header files to be updated, but we will do that in follow-up PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123848
Approved by: https://github.com/jansel
ghstack dependencies: #123847
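As a rough illustration of the point above, here is a minimal standalone sketch (not part of the file below; the values and shapes are made up) of the kind of header-only c10::Half and at::vec usage that AOTInductor-generated CPU code references directly:

#include <ATen/cpu/vec/vec.h>
#include <c10/util/Half.h>

int main() {
  // c10::Half converts to/from float through inline, header-only code.
  c10::Half h = 1.5f;
  float f = static_cast<float>(h) * 2.0f;

  // at::vec::Vectorized<float> broadcasts a scalar across one SIMD register's
  // worth of floats; its arithmetic operators are likewise header-only.
  at::vec::Vectorized<float> v(f);
  at::vec::Vectorized<float> doubled = v + v;
  (void)doubled;
  return 0;
}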
49 lines
1.5 KiB
C++
#include <stdexcept>

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif

#include "aoti_custom_class.h"

namespace torch::aot_inductor {

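// Registers MyAOTIClass as a TorchBind custom class; once this library is
// loaded, it is reachable from TorchScript/Python as
// torch.classes.aoti.MyAOTIClass. The two def_pickle lambdas act as
// __getstate__/__setstate__, round-tripping an instance through its
// (lib_path, device) strings.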
static auto registerMyAOTIClass =
    torch::class_<MyAOTIClass>("aoti", "MyAOTIClass")
        .def(torch::init<std::string, std::string>())
        .def("forward", &MyAOTIClass::forward)
        .def_pickle(
            [](const c10::intrusive_ptr<MyAOTIClass>& self)
                -> std::vector<std::string> {
              std::vector<std::string> v;
              v.push_back(self->lib_path());
              v.push_back(self->device());
              return v;
            },
            [](std::vector<std::string> params) {
              return c10::make_intrusive<MyAOTIClass>(params[0], params[1]);
            });

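// The constructor stores the library path and device string, then creates the
// matching AOTIModelContainerRunner. The "cuda" branch assumes a USE_CUDA
// build; otherwise only the CPU runner is compiled in.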
MyAOTIClass::MyAOTIClass(
    const std::string& model_path,
    const std::string& device)
    : lib_path_(model_path), device_(device) {
  if (device_ == "cuda") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_path.c_str());
  } else if (device_ == "cpu") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_path.c_str());
  } else {
    throw std::runtime_error("invalid device: " + device);
  }
}

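// forward() hands the inputs to the AOTInductor model container runner and
// returns the model's outputs.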
std::vector<torch::Tensor> MyAOTIClass::forward(
    std::vector<torch::Tensor> inputs) {
  return runner_->run(inputs);
}

} // namespace torch::aot_inductor
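For context, a minimal usage sketch of MyAOTIClass from C++. Assumptions: "model.so" stands in for an AOTInductor-compiled model that takes a single 2x3 float input (the path and shape are made up), and MyAOTIClass derives from torch::CustomClassHolder as torch::class_ registration requires, so c10::make_intrusive applies:

#include <iostream>
#include <vector>

#include <torch/torch.h>

#include "aoti_custom_class.h"

int main() {
  // Load the compiled model on CPU through the custom class.
  auto model = c10::make_intrusive<torch::aot_inductor::MyAOTIClass>(
      "model.so", "cpu");

  // Run one forward pass; the input shape is a placeholder.
  std::vector<torch::Tensor> inputs{torch::randn({2, 3})};
  std::vector<torch::Tensor> outputs = model->forward(inputs);

  std::cout << outputs[0] << std::endl;
  return 0;
}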