Add NNC AOT Compiler executable (#63994)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63994

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D30582149

Pulled By: priyaramani

fbshipit-source-id: 3bbf085428824c3cb308e006c18bb0a57f50fef6
Authored by Priya Ramani on 2021-09-15 19:12:47 -07:00; committed by Facebook GitHub Bot
parent e0ecd09011
commit 206646d6ed
6 changed files with 315 additions and 0 deletions


@@ -108,3 +108,6 @@ caffe2_binary_target("tutorial_blob.cc")
caffe2_binary_target("dump_operator_names.cc")
caffe2_binary_target("optimize_for_mobile.cc")
caffe2_binary_target(aot_model_compiler "aot_model_compiler.cc")
target_link_libraries(aot_model_compiler aot_compiler)


@@ -0,0 +1,170 @@
#include <fstream>
#include <sstream>
#include <string>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/script.h>
C10_DEFINE_string(model, "", "The torch script model to optimize.");
C10_DEFINE_string(model_name, "", "The name of the model.");
C10_DEFINE_string(model_version, "", "The version of the model.");
C10_DEFINE_string(
input_dims,
"",
"For input float TensorCPUs, specify the dimension using comma "
"separated numbers. If multiple inputs needed, use semicolon "
"to separate the dimension of different tensors.");
C10_DEFINE_string(
output_llvm,
"",
"Name of the output llvm assembly to be saved.");
C10_DEFINE_string(output_model, "", "Name of the output model to be saved.");
namespace {
std::vector<std::string> split(
char separator,
const std::string& string,
bool ignore_empty = true) {
std::vector<std::string> pieces;
std::stringstream ss(string);
std::string item;
while (getline(ss, item, separator)) {
if (!ignore_empty || !item.empty()) {
pieces.push_back(std::move(item));
}
}
return pieces;
}
std::vector<std::vector<int64_t>> parseInputShapes() {
  CAFFE_ENFORCE(!FLAGS_input_dims.empty(), "Input dims must be specified.");
std::vector<std::string> input_dims_list = split(';', FLAGS_input_dims);
std::vector<std::vector<int64_t>> inputs;
for (const auto& input_dims_item : input_dims_list) {
auto input_dims_str = split(',', input_dims_item);
std::vector<int64_t> input_dims;
input_dims.reserve(input_dims_str.size());
for (const auto& s : input_dims_str) {
input_dims.push_back(c10::stoi(s));
}
inputs.push_back(input_dims);
}
return inputs;
}
c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
c10::Dict<c10::IValue, c10::IValue> compile_spec(
c10::StringType::get(), c10::AnyType::get());
c10::Dict<c10::IValue, c10::IValue> method_spec(
c10::StringType::get(), c10::AnyType::get());
auto input_shapes = parseInputShapes();
TORCH_CHECK(
input_shapes.size() == 1,
"Wrong # of input shapes: ",
input_shapes.size());
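  // e.g. --input_dims='1,3,224,224' produces {"forward": {"sizes": [1, 3, 224, 224]}}.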
method_spec.insert("sizes", input_shapes[0]); // TODO: support multiple inputs
compile_spec.insert("forward", method_spec);
return compile_spec;
}
std::vector<int64_t> getInputSizesForMethod(
const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
const std::string& method_name) {
return method_compile_spec.at(method_name)
.toGenericDict()
.at("sizes")
.toIntVector();
}
std::string getNncKernelId(const std::string& method_name) {
// TODO: calculate the version_token.
const std::string version_token = "VERTOKEN";
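  // Resulting id has the form "<model_name>:<model_version>:<method_name>:VERTOKEN".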
return FLAGS_model_name + ":" + FLAGS_model_version + ":" + method_name +
":" + version_token;
}
void writeOutputLlvmAssembly(const std::string& asm_code) {
std::string output_llvm_file_name = FLAGS_output_llvm;
if (output_llvm_file_name.empty()) {
output_llvm_file_name =
FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
}
std::ofstream output(output_llvm_file_name);
output << asm_code;
}
c10::IValue preprocess(
const torch::jit::Module& mod,
const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
const std::string& method_name = "forward";
auto method = mod.get_method(method_name);
auto graph = method.function().graph()->copy();
auto sizes = getInputSizesForMethod(method_compile_spec, method_name);
std::string llvm_asm_code;
auto func =
torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes, &llvm_asm_code);
writeOutputLlvmAssembly(llvm_asm_code);
func->set_nnc_kernel_id(getNncKernelId(method_name));
torch::jit::mobile::nnc::CompilationUnit cu;
cu.register_function(std::move(func));
return cu.serialize();
}
static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);
} // namespace
int main(int argc, char** argv) {
c10::SetUsageMessage(
"Run NNC AOT compiler for pytorch model. Example usage:\n"
"build/bin/aot_model_compiler"
" --model=<model file>"
" --model_name=<model name>"
" --model_version=<model version>"
" --input_dims='1,3,224,224'"
" [--output_llvm=<llvm assembly output file path>]"
" [--output_model=<output model file path>]");
if (!c10::ParseCommandLineFlags(&argc, &argv)) {
std::cerr << "Failed to parse command line flags!" << std::endl;
std::cout << c10::UsageMessage() << std::endl;
return 1;
}
CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
std::string output_model_name = FLAGS_output_model;
if (output_model_name.empty()) {
output_model_name =
FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.pt";
}
auto m = torch::jit::load(FLAGS_model);
m.eval();
auto frozen_m = torch::jit::freeze_module(m.clone());
auto graph = frozen_m.get_method("forward").graph();
torch::jit::OptimizeFrozenGraph(graph, true);
auto compile_spec = createCompileSpec();
auto any_dict_ty =
c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
auto compiled_module = torch::jit::detail::codegen_backend_module(
"nnc", frozen_m, compile_spec, any_dict_ty);
compiled_module._save_for_mobile(output_model_name);
std::cout << "The compiled model was saved to " << output_model_name
<< std::endl;
return 0;
}
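
A rough consumer sketch (illustration only, not part of this commit): loading the
saved .compiled.pt through the lite interpreter might look roughly like the code
below. The file name and input shape are placeholders, and actually running the
model assumes the generated LLVM assembly has been compiled and linked into the
runtime so the registered NNC kernel id can be resolved.

#include <iostream>
#include <ATen/ATen.h>
#include <torch/csrc/jit/mobile/import.h>

int main() {
  // Load the lite-interpreter module written by aot_model_compiler.
  auto module = torch::jit::_load_for_mobile("model.compiled.pt");
  // Feed an input matching the --input_dims used at compile time.
  std::vector<c10::IValue> inputs{at::rand({1, 3, 224, 224})};
  auto output = module.forward(inputs);
  // Assuming the model returns a single tensor.
  std::cout << output.toTensor().sizes() << std::endl;
  return 0;
}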


@@ -183,6 +183,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/ir/subgraph_matcher.cpp",
"torch/csrc/jit/jit_log.cpp",
"torch/csrc/jit/jit_opt_limit.cpp",
"torch/csrc/jit/mobile/nnc/aot_compiler.cpp",
"torch/csrc/jit/mobile/nnc/backend.cpp",
"torch/csrc/jit/mobile/nnc/context.cpp",
"torch/csrc/jit/mobile/nnc/registry.cpp",


@@ -423,3 +423,9 @@ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
# Pybind11 requires explicit linking of the torch_python library
target_link_libraries(nnapi_backend torch torch_python)
endif()
if(BUILD_BINARY)
add_library(aot_compiler SHARED
${TORCH_SRC_DIR}/csrc/jit/mobile/nnc/aot_compiler.cpp
)
endif()


@@ -0,0 +1,112 @@
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
using namespace torch::jit;
using namespace torch::jit::tensorexpr;
namespace torch {
namespace jit {
namespace mobile {
namespace nnc {
std::vector<int64_t> getConstSizes(const BufPtr b) {
std::vector<int64_t> r;
for (const auto& dim : b->dims()) {
LongImmPtr imm_dim = to<LongImm>(dim);
// TODO: assert it's actually immediate
int64_t s = imm_dim->value();
r.push_back(s);
}
return r;
}
void getCompiledFunction(
std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
Function* func) {
std::vector<at::Tensor> parameters;
auto const_descriptors = kernel->getConstantDescriptors();
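  // Copy each constant buffer into an owned tensor so it can be stored as a function parameter.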
for (const auto& cd : const_descriptors) {
auto sizes = getConstSizes(cd.buf);
at::Tensor const_tensor = at::from_blob(cd.ptr, sizes).clone();
parameters.push_back(const_tensor);
}
func->set_parameters(c10::impl::toList(c10::List<at::Tensor>(parameters)));
MemoryPlan plan;
plan.buffer_sizes_ = {}; // temp_sizes_;
// TODO: implement prealloc optimization and fill in temp_sizes
func->set_memory_plan(plan);
int64_t n_inputs = kernel->graph()->inputs().size();
int64_t n_outputs = kernel->graph()->outputs().size();
std::vector<OutputSpec> out_spec;
for (int64_t idx = n_inputs; idx < n_inputs + n_outputs; idx++) {
const auto& ba = kernel->getBufferArgs()[idx];
OutputSpec output;
output.sizes_ = getConstSizes(ba.buf());
// TODO: assert the output is a buffer and not a scalar
// TODO: use actual dtype
output.dtype_ = c10::ScalarType::Float;
out_spec.push_back(output);
}
func->set_output_specs(out_spec);
}
std::unique_ptr<Function> aotCompile(
const std::string& method_name,
std::shared_ptr<Graph>& g,
const std::vector<int64_t>& sizes,
std::string* compiled_assembly) {
auto g2 = g->copy();
GRAPH_DEBUG("Input sizes ", sizes);
RemoveTensorMutation(g);
EliminateDeadCode(g->block());
g = tensorexpr::removeUnusedSelfArgument(g);
GRAPH_DUMP("graph before shape propagation ", g);
std::vector<c10::optional<at::Tensor>> example_inputs = {at::rand(sizes)};
tensorexpr::annotateInputShapes(g, example_inputs);
PropagateShapesOnGraph(g);
PeepholeOptimize(g, false);
ConstantPropagation(g);
PropagateShapesOnGraph(g);
GRAPH_DUMP("graph after shape propagation ", g);
std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
std::make_shared<tensorexpr::TensorExprKernel>(g);
*compiled_assembly = kernel->getCodeText();
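  // Hand the caller back the untouched copy made above, since the lowering passes mutate the working graph.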
g = g2;
auto func = std::make_unique<Function>();
func->set_name(method_name);
InputSpec input;
input.sizes_ = sizes;
input.dtype_ = c10::ScalarType::Float;
func->set_input_specs({input});
getCompiledFunction(kernel, func.get());
return func;
}
} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch


@@ -0,0 +1,23 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/mobile/nnc/context.h>
namespace torch {
namespace jit {
namespace mobile {
namespace nnc {
// Performs ahead-of-time compilation of the given method of a model,
// returning the compiled function and its LLVM assembly code.
TORCH_API std::unique_ptr<Function> aotCompile(
const std::string& method_name,
std::shared_ptr<Graph>& subgraph,
const std::vector<int64_t>& sizes,
std::string* compiled_assembly);
} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch
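
For illustration only (not part of this commit), driving the API declared above
directly might look roughly like the sketch below. It mirrors what the "nnc"
backend preprocess step in aot_model_compiler.cc does, but omits the module
freezing and frozen-graph optimization passes the binary runs first; the file
name, sizes, and kernel id are placeholders.

#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/mobile/nnc/context.h>
#include <torch/csrc/jit/serialization/import.h>

void compileForward() {
  auto m = torch::jit::load("model.pt");
  m.eval();
  auto graph = m.get_method("forward").function().graph()->copy();

  // Compile the method for a fixed input shape and collect the LLVM assembly.
  std::string llvm_asm;
  std::vector<int64_t> sizes = {1, 3, 224, 224};
  auto func = torch::jit::mobile::nnc::aotCompile("forward", graph, sizes, &llvm_asm);

  // Package the compiled function the same way the backend preprocess step does.
  func->set_nnc_kernel_id("model:v1:forward:VERTOKEN");
  torch::jit::mobile::nnc::CompilationUnit cu;
  cu.register_function(std::move(func));
  c10::IValue payload = cu.serialize();  // consumed by the "nnc" backend at runtime
  (void)payload;
}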