[TensorExpr] Move AOT compilation logic from aot_compiler.cpp to NNC's to_backend (#70375)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70375

Differential Revision: D33303645

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin, priyaramani

Pulled By: ZolotukhinM

fbshipit-source-id: 01ab9fab9bb0d63f89b06a146d3c5fb6ed7fe52d
(cherry picked from commit aac8e0ed900d1b760606b0b50eb064e6b00f8b7a)
Committed by: PyTorch MergeBot
Parent: 64668e61b8
Commit: a60e2ae037
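For context, this change makes the NNC backend's preprocess step self-contained: the caller only builds a method compile spec (input sizes/types as strings plus model metadata) and lowers the module to the "nnc" backend, instead of running the compilation passes itself. Below is a minimal sketch of that flow, loosely following the test updated in this diff; the function name, flag-style literal values, and file names in it are illustrative, not part of the commit.

    #include <torch/csrc/jit/backends/backend_detail.h>
    #include <torch/csrc/jit/passes/freeze_module.h>
    #include <torch/script.h>

    // Sketch: build a compile spec the way createCompileSpec() in this diff does,
    // then lower a frozen module to the "nnc" backend. The backend's preprocess()
    // (registered via backend_preprocess_register in this diff) runs inside
    // codegen_backend_module. Literal values below are placeholders.
    torch::jit::Module compile_with_nnc(const torch::jit::Module& m) {
      c10::Dict<c10::IValue, c10::IValue> method_spec(
          c10::StringType::get(), c10::AnyType::get());
      method_spec.insert("sizes", "2,2,2");        // FLAGS_input_dims-style string
      method_spec.insert("types", "float");        // FLAGS_input_types-style string
      method_spec.insert("asmfile", "model.compiled.ll");
      method_spec.insert("model_name", "aot_test_model");
      method_spec.insert("model_version", "v1");

      c10::Dict<c10::IValue, c10::IValue> compile_spec(
          c10::StringType::get(), c10::AnyType::get());
      compile_spec.insert("forward", method_spec);

      auto any_dict_ty =
          c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
      auto frozen_m = torch::jit::freeze_module(m.clone());
      return torch::jit::detail::codegen_backend_module(
          "nnc", frozen_m, compile_spec, any_dict_ty);
    }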
@@ -7,14 +7,7 @@
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
@@ -61,125 +54,20 @@ std::vector<std::string> split(
  return pieces;
}

std::vector<std::vector<int64_t>> parseInputShapes() {
  CAFFE_ENFORCE_GE(FLAGS_input_dims.size(), 0, "Input dims must be specified.");
  std::vector<std::string> input_dims_list = split(';', FLAGS_input_dims);
  std::vector<std::vector<int64_t>> inputs;
  for (const auto& input_dims_item : input_dims_list) {
    auto input_dims_str = split(',', input_dims_item);
    std::vector<int64_t> input_dims;
    input_dims.reserve(input_dims_str.size());
    for (const auto& s : input_dims_str) {
      input_dims.push_back(c10::stoi(s));
    }
    inputs.push_back(input_dims);
  }
  return inputs;
}

std::vector<at::ScalarType> parseInputTypes() {
  std::vector<std::string> inputTypes = split(';', FLAGS_input_types);
  std::vector<at::ScalarType> scalarTypes;
  for (const auto& inputType : inputTypes) {
    at::ScalarType scalarType;
    if (inputType == "float") {
      scalarType = at::ScalarType::Float;
    } else if (inputType == "uint8") {
      scalarType = at::ScalarType::Byte;
    } else if (inputType == "int64") {
      scalarType = at::ScalarType::Long;
    } else {
      CAFFE_THROW("Unsupported input type: ", inputType);
    }
    scalarTypes.push_back(scalarType);
  }
  return scalarTypes;
}

c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  auto inputShapes = parseInputShapes();
  auto inputTypes = parseInputTypes();
  method_spec.insert("sizes", inputShapes);
  method_spec.insert("types", inputTypes);
  method_spec.insert("sizes", FLAGS_input_dims);
  method_spec.insert("types", FLAGS_input_types);
  method_spec.insert("asmfile", FLAGS_output_llvm);
  method_spec.insert("model_name", FLAGS_model_name);
  method_spec.insert("model_version", FLAGS_model_version);
  compile_spec.insert(FLAGS_method_name, method_spec);
  return compile_spec;
}

std::vector<std::vector<int64_t>> getInputSizes(
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
  auto input_shapes = compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList();
  std::vector<std::vector<int64_t>> inputSizes;
  for (const auto& input_shape : input_shapes) {
    auto sizes = ((c10::IValue) input_shape).toIntVector();
    inputSizes.emplace_back(sizes);
  }
  return inputSizes;
}

std::vector<at::ScalarType> getInputTypes(
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
  auto inputTypesList = compile_spec.at(FLAGS_method_name).toGenericDict().at("types").toList();
  std::vector<at::ScalarType> inputTypes;
  for (const auto& inputType : inputTypesList) {
    auto type = ((c10::IValue) inputType).toScalarType();
    inputTypes.emplace_back(type);
  }
  return inputTypes;
}

std::string getNncKernelId() {
  // TODO: calculate the version_token.
  const std::string version_token = "VERTOKEN";
  return FLAGS_model_name + ":" + FLAGS_model_version + ":" + FLAGS_method_name +
      ":" + version_token;
}

std::string getNncKernelFuncName(const std::string& method_name) {
  return "nnc_" + FLAGS_model_name + "_" + FLAGS_model_version + "_" + method_name;
}

void writeOutputLlvmAssembly(const std::string& asm_code) {
  std::string output_llvm_file_name = FLAGS_output_llvm;
  if (output_llvm_file_name.empty()) {
    output_llvm_file_name =
        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
  }

  std::ofstream output(output_llvm_file_name);
  output << asm_code;
  std::cout << "The compiled llvm assembly code was saved to " << output_llvm_file_name
            << std::endl;
}

c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {

  auto method = mod.get_method(FLAGS_method_name);
  auto graph = toGraphFunction(method.function()).graph()->copy();
  auto sizes = getInputSizes(compile_spec);
  auto types = getInputTypes(compile_spec);
  auto kernel_func_name = getNncKernelFuncName(FLAGS_method_name);

  auto compiled = torch::jit::mobile::nnc::aotCompile(
      FLAGS_method_name, graph, sizes, types, kernel_func_name);
  writeOutputLlvmAssembly(compiled.second);

  auto func = std::move(compiled.first);
  func->set_nnc_kernel_id(getNncKernelId());

  torch::jit::mobile::nnc::CompilationUnit cu;
  cu.register_function(std::move(func));
  return cu.serialize();
}

static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);

} // namespace

int main(int argc, char** argv) {
@@ -205,7 +93,9 @@ int main(int argc, char** argv) {
  CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(split(';', FLAGS_input_dims).size() == split(';', FLAGS_input_types).size(),
  CAFFE_ENFORCE(
      split(';', FLAGS_input_dims).size() ==
          split(';', FLAGS_input_types).size(),
      "Number of input_dims and input_types should be the same");

  std::string output_model_name = FLAGS_output_model;
@@ -217,27 +107,6 @@ int main(int argc, char** argv) {
  auto m = torch::jit::load(FLAGS_model);
  m.eval();
  auto frozen_m = torch::jit::freeze_module(m.clone());
  auto graph = frozen_m.get_method(FLAGS_method_name).graph();
  auto inputShapes = parseInputShapes();
  auto inputTypes = parseInputTypes();
  std::vector<c10::optional<at::Tensor>> example_inputs;
  example_inputs.reserve(inputShapes.size());
  for (int i = 0; i < inputShapes.size(); ++i) {
    example_inputs.emplace_back(at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
  }

  torch::jit::RemoveTensorMutation(graph);
  torch::jit::EliminateDeadCode(graph->block());
  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);

  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
  torch::jit::OptimizeFrozenGraph(graph, true);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);

  auto compile_spec = createCompileSpec();
  auto any_dict_ty =
@@ -15,7 +15,7 @@ test_aot_model_compiler() {
  python "$CURRENT_DIR"/aot_test_model.py
  mv "$MODEL" "$TMP_DIR"/
  pushd "$TMP_DIR"
  "$TORCH_BIN_DIR"/aot_model_compiler_test --model "$MODEL" --model_name=aot_test_model --model_version=v1 --input_dims="2,2,2"
  "$TORCH_BIN_DIR"/aot_model_compiler_test --model "$MODEL" --model_name=aot_test_model --output_llvm=$COMPILED_CODE --model_version=v1 --input_dims="2,2,2"
  if [ ! -f "$COMPILED_MODEL" ] || [ ! -f "$COMPILED_CODE" ]; then
    echo "AOT model compiler failed to generate $COMPILED_MODEL and $COMPILED_CODE"
    exit 1
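For reference, the identifiers that tie the registered kernels in the test below to the compiler output follow the naming convention of the helpers added to aot_compiler.cpp in this commit. A small sketch of that convention (the standalone helper name here is illustrative):

    // Assumed convention from this diff:
    //   kernel function name: "nnc_" + model_name + "_" + model_version + "_" + method_name
    //   kernel id (registry key): model_name + ":" + model_version + ":" + method_name + ":VERTOKEN"
    std::string nnc_kernel_id_for(
        const std::string& model_name,
        const std::string& model_version,
        const std::string& method_name) {
      return model_name + ":" + model_version + ":" + method_name + ":VERTOKEN";
    }
    // e.g. nnc_kernel_id_for("_add_kernel_nnc_fake_model", "v1", "forward")
    //   == "_add_kernel_nnc_fake_model:v1:forward:VERTOKEN"
    //   which matches the REGISTER_NNC_KERNEL key used in the test below.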
@@ -1,3 +1,4 @@
#include <ATen/Functions.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
@@ -7,9 +8,9 @@
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/nnc/context.h>
#include <torch/csrc/jit/mobile/nnc/registry.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>
#include <ATen/Functions.h>

namespace torch {
namespace jit {
@@ -20,20 +21,18 @@ namespace {

c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
    const std::string& method_name,
    const std::string& nnc_kernel_id,
    const std::vector<std::vector<int64_t>>& input_shapes,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const c10::impl::GenericList& parameters,
    const std::vector<int64_t>& buffer_sizes) {
    const std::string& model_name,
    const std::string& input_shapes,
    const std::string& input_types) {
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  method_spec.insert("nnc_kernel_id", nnc_kernel_id);
  method_spec.insert("input_sizes", input_shapes);
  method_spec.insert("output_sizes", output_shapes);

  // For testing purpose we don't call the real NNC so pass in these directly.
  method_spec.insert("parameters", parameters);
  method_spec.insert("buffer_sizes", buffer_sizes);
  method_spec.insert("sizes", input_shapes);
  method_spec.insert("types", input_types);
  method_spec.insert("model_name", model_name);
  method_spec.insert("model_version", "v1");
  method_spec.insert("asmfile", "fake_nnc_model.s");
  method_spec.insert("arch", "x86-64");

  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
@@ -41,85 +40,6 @@ c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
  return compile_spec;
}

std::vector<mobile::nnc::InputSpec> get_input_specs(
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
  auto input_shapes = method_compile_spec.at("input_sizes").toList();

  std::vector<mobile::nnc::InputSpec> specs;
  for (const auto& input_shape : input_shapes) {
    mobile::nnc::InputSpec spec;
    spec.sizes_ = ((c10::IValue) input_shape).toIntVector();
    spec.dtype_ = c10::ScalarType::Float;
    specs.emplace_back(std::move(spec));
  }
  return specs;
}

std::vector<mobile::nnc::OutputSpec> get_output_specs(
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
  auto output_shapes = method_compile_spec.at("output_sizes").toList();

  std::vector<mobile::nnc::OutputSpec> specs;
  for (const auto& output_shape : output_shapes) {
    mobile::nnc::OutputSpec spec;
    spec.sizes_ = ((c10::IValue) output_shape).toIntVector();
    spec.dtype_ = c10::ScalarType::Float;
    specs.emplace_back(std::move(spec));
  }
  return specs;
}

// A fake NNC preprocess method, which only produces the compiled model but
// does not produce the assembly with the NNC compiler.
c10::IValue preprocess(
    const torch::jit::Module& /* mod */,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
    const torch::jit::BackendDebugHandleGenerator&) {
  torch::jit::mobile::nnc::CompilationUnit cu;
  for (const auto& entry : method_compile_spec) {
    const std::string& method_name = entry.key().toStringRef();
    auto compile_spec = entry.value().toGenericDict();

    auto func = std::make_unique<mobile::nnc::Function>();
    func->set_name(method_name);
    func->set_nnc_kernel_id(compile_spec.at("nnc_kernel_id").toStringRef());
    func->set_input_specs(get_input_specs(compile_spec));
    func->set_output_specs(get_output_specs(compile_spec));

    func->set_parameters(compile_spec.at("parameters").toList());

    mobile::nnc::MemoryPlan plan;
    plan.buffer_sizes_ = compile_spec.at("buffer_sizes").toIntVector();
    func->set_memory_plan(plan);

    cu.register_function(std::move(func));
  }
  return cu.serialize();
}

static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);

struct FakeTensor : torch::CustomClassHolder {
  explicit FakeTensor(std::vector<int64_t> data) : data_(std::move(data)) {}
  int64_t get() {
    return data_[0];
  }
  std::vector<int64_t> data_;
};

TORCH_LIBRARY(_TorchScriptTesting, m) {
  m.class_<FakeTensor>("_MobileNNCFakeTensor")
      .def(torch::init<std::vector<int64_t>>())
      .def("get", &FakeTensor::get)
      .def_pickle(
          [](c10::intrusive_ptr<FakeTensor> self) { // __getstate__
            return self->data_;
          },
          [](std::vector<int64_t> state) { // __setstate__
            return c10::make_intrusive<FakeTensor>(std::move(state));
          });
}

} // namespace

extern "C" {
@@ -135,19 +55,11 @@ int add_kernel(void** args) {
  return 0;
}

int fake_tensor_add_kernel(void** args) {
  // out = input + param.get()
  at::Tensor input = at::from_blob(args[0], {4, 4}, at::kFloat);
  at::Tensor out = at::from_blob(args[1], {4, 4}, at::kFloat);
  FakeTensor* param = reinterpret_cast<FakeTensor*>(args[2]);
  out.copy_(at::add(input, param->get()));
  return 0;
}

} // extern "C"

REGISTER_NNC_KERNEL("_add_kernel", add_kernel)
REGISTER_NNC_KERNEL("_fake_tensor_add_kernel", fake_tensor_add_kernel)
REGISTER_NNC_KERNEL(
    "_add_kernel_nnc_fake_model:v1:forward:VERTOKEN",
    add_kernel)

TEST(NNCBackendTest, AOTCompileThenExecute) {
  torch::jit::Module m("m");
@@ -165,16 +77,12 @@ TEST(NNCBackendTest, AOTCompileThenExecute) {

  // Compile the model with NNC.
  auto compile_spec = create_compile_spec(
      "forward",
      "_add_kernel",
      {{4, 4}},
      {{4, 4}},
      c10::impl::toList(c10::List<at::Tensor>({param})),
      {});
      "forward", "_add_kernel_nnc_fake_model", "4,4", "float");
  auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  auto frozen_m = torch::jit::freeze_module(m.clone());
  auto compiled_module = torch::jit::detail::codegen_backend_module(
      "nnc", m, compile_spec, any_dict_ty);
      "nnc", frozen_m, compile_spec, any_dict_ty);

  // Save the compiled model.
  std::stringstream ss;
@@ -185,49 +93,7 @@ TEST(NNCBackendTest, AOTCompileThenExecute) {
  auto result = loaded_module.forward(inputs);
  EXPECT_TRUE(result.toTensor().equal(3.0 * torch::ones({4, 4})));
  EXPECT_TRUE(result.toTensor().equal(reference.toTensor()));
}

TEST(NNCBackendTest, FakeTensor) {
  script::Module m("m");
  auto param_cls = getCustomClass(
      "__torch__.torch.classes._TorchScriptTesting._MobileNNCFakeTensor");
  auto param_value = c10::make_intrusive<FakeTensor>(std::vector<int64_t>({3}));
  m.register_attribute("param", param_cls, param_value, false);
  m.define(
      R"(
    def forward(self, input):
        return input + self.param.get()
  )");

  // Run the TorchScript module to get reference result.
  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({4, 4}));
  auto reference = m.forward(inputs);

  // Compile the model with NNC.
  auto params = c10::impl::GenericList(c10::AnyType::get());
  params.emplace_back(param_value);
  auto compile_spec = create_compile_spec(
      "forward",
      "_fake_tensor_add_kernel",
      {{4, 4}},
      {{4, 4}},
      params,
      {});
  auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  auto compiled_module = torch::jit::detail::codegen_backend_module(
      "nnc", m, compile_spec, any_dict_ty);

  // Save the compiled model.
  std::stringstream ss;
  compiled_module._save_for_mobile(ss);

  // Load and run the saved model.
  auto loaded_module = _load_for_mobile(ss);
  auto result = loaded_module.forward(inputs);
  EXPECT_TRUE(result.toTensor().equal(5.0 * torch::ones({4, 4})));
  EXPECT_TRUE(result.toTensor().equal(reference.toTensor()));
  EXPECT_EQ(remove("fake_nnc_model.s"), 0);
}

} // namespace nnc
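For context, the test above also shows the shape a hand-written NNC kernel takes: a plain C function receiving a void** array of raw buffers, registered under its kernel id so the mobile runtime can look it up at execution time. A minimal sketch mirroring add_kernel from the test (the kernel name, registry key, and shapes here are illustrative):

    #include <ATen/Functions.h>
    #include <torch/csrc/jit/mobile/nnc/registry.h>

    extern "C" {
    // args[] carries raw buffers for inputs, outputs and (optionally) parameters.
    int my_add_one_kernel(void** args) {
      at::Tensor input = at::from_blob(args[0], {4, 4}, at::kFloat);
      at::Tensor out = at::from_blob(args[1], {4, 4}, at::kFloat);
      out.copy_(at::add(input, 1.0));  // out = input + 1
      return 0;
    }
    } // extern "C"

    // Registered under a "<model>:<version>:<method>:VERTOKEN" id, matching the
    // kernel-id convention used by the compiler in this diff.
    REGISTER_NNC_KERNEL("my_model:v1:forward:VERTOKEN", my_add_one_kernel)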
@@ -2,10 +2,14 @@

#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
@@ -16,6 +20,7 @@
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <fstream>

using namespace torch::jit;
using namespace torch::jit::tensorexpr;
@@ -178,30 +183,6 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
  GRAPH_DEBUG("Method name ", method_name);
  GRAPH_DEBUG("Kernel func name ", kernel_func_name);

  CAFFE_ENFORCE(
      sizes.size() == types.size(),
      "Number of input sizes and input types should be the same");

  std::vector<at::IValue> example_values;
  std::vector<c10::optional<at::Tensor>> example_inputs;
  for (int i = 0; i < sizes.size(); ++i) {
    auto example_input = at::rand(sizes[i]).to(at::dtype(types[i]));
    example_values.emplace_back(example_input);
    example_inputs.emplace_back(example_input);
  }

  GRAPH_DUMP("graph before compiler passes ", g);
  tensorexpr::removeUnusedSelfArgument(g);
  g = TraceGraph(g, example_values);
  // TODO: Remove annotateInputShapes pass when TraceGraph can also capture
  // input shapes
  tensorexpr::annotateInputShapes(g, example_inputs);
  RemoveListMutation(g);
  RemoveTensorMutation(g);
  EliminateDeadCode(g);
  LowerAllTuples(g);
  GRAPH_DUMP("graph after compiler passes ", g);

  std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
      std::make_shared<tensorexpr::TensorExprKernel>(
          TensorExprKernel(g, kernel_func_name));
@@ -212,6 +193,174 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
  return std::make_pair(std::move(func), compiled_assembly);
}

void writeOutputLlvmAssembly(
    const std::string& asm_code,
    const std::string& output_llvm_file_name) {
  std::ofstream output(output_llvm_file_name);
  output << asm_code;
  GRAPH_DEBUG(
      "The compiled llvm assembly code was saved to ", output_llvm_file_name);
}

std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> pieces;
  std::stringstream ss(string);
  std::string item;
  while (getline(ss, item, separator)) {
    if (!ignore_empty || !item.empty()) {
      pieces.push_back(std::move(item));
    }
  }
  return pieces;
}

std::vector<std::vector<int64_t>> parseInputShapes(
    const std::string& input_dims_s) {
  std::vector<std::string> input_dims_list = split(';', input_dims_s);
  std::vector<std::vector<int64_t>> inputs;
  for (const auto& input_dims_item : input_dims_list) {
    auto input_dims_str = split(',', input_dims_item);
    std::vector<int64_t> input_dims;
    input_dims.reserve(input_dims_str.size());
    for (const auto& s : input_dims_str) {
      input_dims.push_back(c10::stoi(s));
    }
    inputs.push_back(input_dims);
  }
  return inputs;
}

std::vector<at::ScalarType> parseInputTypes(
    const std::string& input_types_str) {
  std::vector<std::string> inputTypes = split(';', input_types_str);
  std::vector<at::ScalarType> scalarTypes;
  for (const auto& inputType : inputTypes) {
    at::ScalarType scalarType;
    if (inputType == "float") {
      scalarType = at::ScalarType::Float;
    } else if (inputType == "uint8") {
      scalarType = at::ScalarType::Byte;
    } else if (inputType == "int64") {
      scalarType = at::ScalarType::Long;
    } else {
      CAFFE_THROW("Unsupported input type: ", inputType);
    }
    scalarTypes.push_back(scalarType);
  }
  return scalarTypes;
}

std::string getNncKernelId(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  // TODO: calculate the version_token.
  const std::string version_token = "VERTOKEN";
  return model_name + ":" + model_version + ":" + method_name + ":" +
      version_token;
}

std::string getNncKernelFuncName(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  return "nnc_" + model_name + "_" + model_version + "_" + method_name;
}

std::shared_ptr<Graph> preprocessGraphPasses(
    std::shared_ptr<Graph>& graph,
    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
  GRAPH_DEBUG("Before preprocessing graph passes: ", *graph);
  torch::jit::RemoveTensorMutation(graph);
  torch::jit::EliminateDeadCode(graph->block());
  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);

  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
  torch::jit::OptimizeFrozenGraph(graph, true);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);
  torch::jit::PropagateShapesOnGraph(graph);
  torch::jit::PeepholeOptimize(graph, false);
  torch::jit::ConstantPropagation(graph);

  tensorexpr::removeUnusedSelfArgument(graph);

  std::vector<at::IValue> example_values;
  example_values.reserve(example_inputs.size());
  for (auto example_input : example_inputs) {
    example_values.emplace_back(*example_input);
  }
  graph = TraceGraph(graph, example_values);
  // TODO: Remove annotateInputShapes pass when TraceGraph can also capture
  // input shapes
  tensorexpr::annotateInputShapes(graph, example_inputs);

  RemoveListMutation(graph);
  RemoveTensorMutation(graph);
  EliminateDeadCode(graph);
  LowerAllTuples(graph);
  GRAPH_DEBUG("After preprocessing graph passes: ", *graph);
  return graph;
}

std::vector<c10::optional<at::Tensor>> generateExampleInputs(
    const std::vector<std::vector<int64_t>>& inputShapes,
    const std::vector<at::ScalarType>& inputTypes) {
  std::vector<c10::optional<at::Tensor>> example_inputs;
  example_inputs.reserve(inputShapes.size());
  for (int i = 0; i < inputShapes.size(); ++i) {
    example_inputs.emplace_back(
        at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
  }
  return example_inputs;
}

c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
  torch::jit::mobile::nnc::CompilationUnit cu;
  for (const auto& kv : compile_spec) {
    GRAPH_DEBUG("Key: ", kv.key());
    GRAPH_DEBUG("Value: ", kv.value());
    std::string method_name = *(kv.key().toString());
    GRAPH_DEBUG("Method name: ", method_name);
    auto method_spec = kv.value().toGenericDict();
    std::string model_name = *method_spec.at("model_name").toString();
    std::string model_version = *method_spec.at("model_version").toString();
    std::string asmfile_name = *method_spec.at("asmfile").toString();
    GRAPH_DEBUG("Model name: ", model_name);
    GRAPH_DEBUG("Model version: ", model_version);
    GRAPH_DEBUG("Asm file name: ", asmfile_name);

    auto method = mod.get_method(method_name);
    auto graph = toGraphFunction(method.function()).graph()->copy();

    auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
    auto types = parseInputTypes(*method_spec.at("types").toString());

    auto example_inputs = generateExampleInputs(sizes, types);
    graph = preprocessGraphPasses(graph, example_inputs);

    auto kernel_func_name =
        getNncKernelFuncName(model_name, model_version, method_name);
    auto compiled = torch::jit::mobile::nnc::aotCompile(
        method_name, graph, sizes, types, kernel_func_name);
    writeOutputLlvmAssembly(compiled.second, asmfile_name);
    auto func = std::move(compiled.first);
    func->set_nnc_kernel_id(
        getNncKernelId(model_name, model_version, method_name));
    cu.register_function(std::move(func));
  }
  return cu.serialize();
}

static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);

} // namespace nnc
} // namespace mobile
} // namespace jit