diff --git a/binaries/aot_model_compiler.cc b/binaries/aot_model_compiler.cc
index 2ff895c235be..b9d1d24c08ea 100644
--- a/binaries/aot_model_compiler.cc
+++ b/binaries/aot_model_compiler.cc
@@ -7,14 +7,7 @@
 #include
 #include
 #include
-#include
-#include
 #include
-#include
-#include
-#include
-#include
-#include
 #include
 #include
 #include
@@ -61,125 +54,20 @@ std::vector<std::string> split(
   return pieces;
 }
 
-std::vector<std::vector<int64_t>> parseInputShapes() {
-  CAFFE_ENFORCE_GE(FLAGS_input_dims.size(), 0, "Input dims must be specified.");
-  std::vector<std::string> input_dims_list = split(';', FLAGS_input_dims);
-  std::vector<std::vector<int64_t>> inputs;
-  for (const auto& input_dims_item : input_dims_list) {
-    auto input_dims_str = split(',', input_dims_item);
-    std::vector<int64_t> input_dims;
-    input_dims.reserve(input_dims_str.size());
-    for (const auto& s : input_dims_str) {
-      input_dims.push_back(c10::stoi(s));
-    }
-    inputs.push_back(input_dims);
-  }
-  return inputs;
-}
-
-std::vector<at::ScalarType> parseInputTypes() {
-  std::vector<std::string> inputTypes = split(';', FLAGS_input_types);
-  std::vector<at::ScalarType> scalarTypes;
-  for (const auto& inputType : inputTypes) {
-    at::ScalarType scalarType;
-    if (inputType == "float") {
-      scalarType = at::ScalarType::Float;
-    } else if (inputType == "uint8") {
-      scalarType = at::ScalarType::Byte;
-    } else if (inputType == "int64") {
-      scalarType = at::ScalarType::Long;
-    } else {
-      CAFFE_THROW("Unsupported input type: ", inputType);
-    }
-    scalarTypes.push_back(scalarType);
-  }
-  return scalarTypes;
-
 c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
   c10::Dict<c10::IValue, c10::IValue> compile_spec(
       c10::StringType::get(), c10::AnyType::get());
   c10::Dict<c10::IValue, c10::IValue> method_spec(
       c10::StringType::get(), c10::AnyType::get());
-  auto inputShapes = parseInputShapes();
-  auto inputTypes = parseInputTypes();
-  method_spec.insert("sizes", inputShapes);
-  method_spec.insert("types", inputTypes);
+  method_spec.insert("sizes", FLAGS_input_dims);
+  method_spec.insert("types", FLAGS_input_types);
+  method_spec.insert("asmfile", FLAGS_output_llvm);
+  method_spec.insert("model_name", FLAGS_model_name);
+  method_spec.insert("model_version", FLAGS_model_version);
   compile_spec.insert(FLAGS_method_name, method_spec);
   return compile_spec;
 }
 
-std::vector<std::vector<int64_t>> getInputSizes(
-    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
-  auto input_shapes = compile_spec.at(FLAGS_method_name).toGenericDict().at("sizes").toList();
-  std::vector<std::vector<int64_t>> inputSizes;
-  for (const auto& input_shape : input_shapes) {
-    auto sizes = ((c10::IValue) input_shape).toIntVector();
-    inputSizes.emplace_back(sizes);
-  }
-  return inputSizes;
-}
-
-std::vector<at::ScalarType> getInputTypes(
-    const c10::Dict<c10::IValue, c10::IValue>& compile_spec) {
-  auto inputTypesList = compile_spec.at(FLAGS_method_name).toGenericDict().at("types").toList();
-  std::vector<at::ScalarType> inputTypes;
-  for (const auto& inputType : inputTypesList) {
-    auto type = ((c10::IValue) inputType).toScalarType();
-    inputTypes.emplace_back(type);
-  }
-  return inputTypes;
-}
-
-std::string getNncKernelId() {
-  // TODO: calculate the version_token.
-  const std::string version_token = "VERTOKEN";
-  return FLAGS_model_name + ":" + FLAGS_model_version + ":" + FLAGS_method_name +
-      ":" + version_token;
-}
-
-std::string getNncKernelFuncName(const std::string& method_name) {
-  return "nnc_" + FLAGS_model_name + "_" + FLAGS_model_version + "_" + method_name;
-}
-
-void writeOutputLlvmAssembly(const std::string& asm_code) {
-  std::string output_llvm_file_name = FLAGS_output_llvm;
-  if (output_llvm_file_name.empty()) {
-    output_llvm_file_name =
-        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
-  }
-
-  std::ofstream output(output_llvm_file_name);
-  output << asm_code;
-  std::cout << "The compiled llvm assembly code was saved to " << output_llvm_file_name
-            << std::endl;
-}
-
-c10::IValue preprocess(
-    const torch::jit::Module& mod,
-    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
-    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
-
-  auto method = mod.get_method(FLAGS_method_name);
-  auto graph = toGraphFunction(method.function()).graph()->copy();
-  auto sizes = getInputSizes(compile_spec);
-  auto types = getInputTypes(compile_spec);
-  auto kernel_func_name = getNncKernelFuncName(FLAGS_method_name);
-
-  auto compiled = torch::jit::mobile::nnc::aotCompile(
-      FLAGS_method_name, graph, sizes, types, kernel_func_name);
-  writeOutputLlvmAssembly(compiled.second);
-
-  auto func = std::move(compiled.first);
-  func->set_nnc_kernel_id(getNncKernelId());
-
-  torch::jit::mobile::nnc::CompilationUnit cu;
-  cu.register_function(std::move(func));
-  return cu.serialize();
-}
-
-static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);
-
 } // namespace
 
 int main(int argc, char** argv) {
@@ -205,7 +93,9 @@ int main(int argc, char** argv) {
   CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
-  CAFFE_ENFORCE(split(';', FLAGS_input_dims).size() == split(';', FLAGS_input_types).size(),
+  CAFFE_ENFORCE(
+      split(';', FLAGS_input_dims).size() ==
+          split(';', FLAGS_input_types).size(),
       "Number of input_dims and input_types should be the same");
 
   std::string output_model_name = FLAGS_output_model;
@@ -217,27 +107,6 @@ int main(int argc, char** argv) {
   auto m = torch::jit::load(FLAGS_model);
   m.eval();
   auto frozen_m = torch::jit::freeze_module(m.clone());
-  auto graph = frozen_m.get_method(FLAGS_method_name).graph();
-  auto inputShapes = parseInputShapes();
-  auto inputTypes = parseInputTypes();
-  std::vector<c10::optional<at::Tensor>> example_inputs;
-  example_inputs.reserve(inputShapes.size());
-  for (int i = 0; i < inputShapes.size(); ++i) {
-    example_inputs.emplace_back(at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
-  }
-
-  torch::jit::RemoveTensorMutation(graph);
-  torch::jit::EliminateDeadCode(graph->block());
-  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);
-
-  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
-  torch::jit::OptimizeFrozenGraph(graph, true);
-  torch::jit::PropagateShapesOnGraph(graph);
-  torch::jit::PeepholeOptimize(graph, false);
-  torch::jit::ConstantPropagation(graph);
-  torch::jit::PropagateShapesOnGraph(graph);
-  torch::jit::PeepholeOptimize(graph, false);
-  torch::jit::ConstantPropagation(graph);
 
   auto compile_spec = createCompileSpec();
   auto any_dict_ty =
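With parsing and preprocessing moved into the `nnc` backend, the binary now forwards its flags as plain strings, and the backend's `parseInputShapes`/`parseInputTypes` (further down in this patch) interpret them: inputs are separated by `;`, dims within one input by `,`. A minimal standalone sketch of that string contract; the helper name `parse_dims` is illustrative and not part of the patch:

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Illustrative parser for the --input_dims flag format, e.g.
// --input_dims="2,2,2;1,3,224,224" --input_types="float;float".
std::vector<std::vector<int64_t>> parse_dims(const std::string& flag) {
  std::vector<std::vector<int64_t>> result;
  std::stringstream per_input(flag);
  std::string one_input;
  while (std::getline(per_input, one_input, ';')) {
    std::vector<int64_t> dims;
    std::stringstream per_dim(one_input);
    std::string one_dim;
    while (std::getline(per_dim, one_dim, ',')) {
      dims.push_back(std::stoll(one_dim));
    }
    result.push_back(std::move(dims));
  }
  return result;
}

int main() {
  // Two inputs: a 2x2x2 tensor and a 1x3x224x224 tensor.
  auto shapes = parse_dims("2,2,2;1,3,224,224");
  std::cout << shapes.size() << " inputs, first has " << shapes[0].size()
            << " dims\n"; // prints: 2 inputs, first has 3 dims
  return 0;
}
```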
diff --git a/test/mobile/nnc/test_aot_compile.sh b/test/mobile/nnc/test_aot_compile.sh
index 6ff0ef2e6812..f4387a83c441 100755
--- a/test/mobile/nnc/test_aot_compile.sh
+++ b/test/mobile/nnc/test_aot_compile.sh
@@ -15,7 +15,7 @@ test_aot_model_compiler() {
   python "$CURRENT_DIR"/aot_test_model.py
   mv "$MODEL" "$TMP_DIR"/
   pushd "$TMP_DIR"
-  "$TORCH_BIN_DIR"/aot_model_compiler_test --model "$MODEL" --model_name=aot_test_model --model_version=v1 --input_dims="2,2,2"
+  "$TORCH_BIN_DIR"/aot_model_compiler_test --model "$MODEL" --model_name=aot_test_model --output_llvm=$COMPILED_CODE --model_version=v1 --input_dims="2,2,2"
   if [ ! -f "$COMPILED_MODEL" ] || [ ! -f "$COMPILED_CODE" ]; then
     echo "AOT model compiler failed to generate $COMPILED_MODEL and $COMPILED_CODE"
     exit 1
diff --git a/test/mobile/nnc/test_nnc_backend.cpp b/test/mobile/nnc/test_nnc_backend.cpp
index 0e59aaa5547d..f7adcb62459f 100644
--- a/test/mobile/nnc/test_nnc_backend.cpp
+++ b/test/mobile/nnc/test_nnc_backend.cpp
@@ -1,3 +1,4 @@
+#include
 #include
 #include
 #include
@@ -7,9 +8,9 @@
 #include
 #include
 #include
+#include
 #include
 #include
-#include
 
 namespace torch {
 namespace jit {
@@ -20,20 +21,18 @@ namespace {
 
 c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
     const std::string& method_name,
-    const std::string& nnc_kernel_id,
-    const std::vector<std::vector<int64_t>>& input_shapes,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const c10::impl::GenericList& parameters,
-    const std::vector<int64_t>& buffer_sizes) {
+    const std::string& model_name,
+    const std::string& input_shapes,
+    const std::string& input_types) {
   c10::Dict<c10::IValue, c10::IValue> method_spec(
       c10::StringType::get(), c10::AnyType::get());
-  method_spec.insert("nnc_kernel_id", nnc_kernel_id);
-  method_spec.insert("input_sizes", input_shapes);
-  method_spec.insert("output_sizes", output_shapes);
-  // For testing purpose we don't call the real NNC so pass in these directly.
-  method_spec.insert("parameters", parameters);
-  method_spec.insert("buffer_sizes", buffer_sizes);
+  method_spec.insert("sizes", input_shapes);
+  method_spec.insert("types", input_types);
+  method_spec.insert("model_name", model_name);
+  method_spec.insert("model_version", "v1");
+  method_spec.insert("asmfile", "fake_nnc_model.s");
+  method_spec.insert("arch", "x86-64");
 
   c10::Dict<c10::IValue, c10::IValue> compile_spec(
       c10::StringType::get(), c10::AnyType::get());
@@ -41,85 +40,6 @@ c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
   return compile_spec;
 }
 
-std::vector<mobile::nnc::InputSpec> get_input_specs(
-    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
-  auto input_shapes = method_compile_spec.at("input_sizes").toList();
-
-  std::vector<mobile::nnc::InputSpec> specs;
-  for (const auto& input_shape : input_shapes) {
-    mobile::nnc::InputSpec spec;
-    spec.sizes_ = ((c10::IValue) input_shape).toIntVector();
-    spec.dtype_ = c10::ScalarType::Float;
-    specs.emplace_back(std::move(spec));
-  }
-  return specs;
-}
-
-std::vector<mobile::nnc::OutputSpec> get_output_specs(
-    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
-  auto output_shapes = method_compile_spec.at("output_sizes").toList();
-
-  std::vector<mobile::nnc::OutputSpec> specs;
-  for (const auto& output_shape : output_shapes) {
-    mobile::nnc::OutputSpec spec;
-    spec.sizes_ = ((c10::IValue) output_shape).toIntVector();
-    spec.dtype_ = c10::ScalarType::Float;
-    specs.emplace_back(std::move(spec));
-  }
-  return specs;
-}
-
-// A fake NNC preprocess method, which only produces the compiled model but
-// does not produce the assembly with the NNC compiler.
-c10::IValue preprocess(
-    const torch::jit::Module& /* mod */,
-    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
-    const torch::jit::BackendDebugHandleGenerator&) {
-  torch::jit::mobile::nnc::CompilationUnit cu;
-  for (const auto& entry : method_compile_spec) {
-    const std::string& method_name = entry.key().toStringRef();
-    auto compile_spec = entry.value().toGenericDict();
-
-    auto func = std::make_unique<mobile::nnc::Function>();
-    func->set_name(method_name);
-    func->set_nnc_kernel_id(compile_spec.at("nnc_kernel_id").toStringRef());
-    func->set_input_specs(get_input_specs(compile_spec));
-    func->set_output_specs(get_output_specs(compile_spec));
-
-    func->set_parameters(compile_spec.at("parameters").toList());
-
-    mobile::nnc::MemoryPlan plan;
-    plan.buffer_sizes_ = compile_spec.at("buffer_sizes").toIntVector();
-    func->set_memory_plan(plan);
-
-    cu.register_function(std::move(func));
-  }
-  return cu.serialize();
-}
-
-static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);
-
-struct FakeTensor : torch::CustomClassHolder {
-  explicit FakeTensor(std::vector<int64_t> data) : data_(std::move(data)) {}
-  int64_t get() {
-    return data_[0];
-  }
-  std::vector<int64_t> data_;
-};
-
-TORCH_LIBRARY(_TorchScriptTesting, m) {
-  m.class_<FakeTensor>("_MobileNNCFakeTensor")
-      .def(torch::init<std::vector<int64_t>>())
-      .def("get", &FakeTensor::get)
-      .def_pickle(
-          [](c10::intrusive_ptr<FakeTensor> self) { // __getstate__
-            return self->data_;
-          },
-          [](std::vector<int64_t> state) { // __setstate__
-            return c10::make_intrusive<FakeTensor>(std::move(state));
-          });
-}
-
 } // namespace
 
 extern "C" {
@@ -135,19 +55,11 @@ int add_kernel(void** args) {
   return 0;
 }
 
-int fake_tensor_add_kernel(void** args) {
-  // out = input + param.get()
-  at::Tensor input = at::from_blob(args[0], {4, 4}, at::kFloat);
-  at::Tensor out = at::from_blob(args[1], {4, 4}, at::kFloat);
-  FakeTensor* param = reinterpret_cast<FakeTensor*>(args[2]);
-  out.copy_(at::add(input, param->get()));
-  return 0;
-}
-
 } // extern "C"
 
-REGISTER_NNC_KERNEL("_add_kernel", add_kernel)
-REGISTER_NNC_KERNEL("_fake_tensor_add_kernel", fake_tensor_add_kernel)
+REGISTER_NNC_KERNEL(
+    "_add_kernel_nnc_fake_model:v1:forward:VERTOKEN",
+    add_kernel)
 
 TEST(NNCBackendTest, AOTCompileThenExecute) {
   torch::jit::Module m("m");
@@ -165,16 +77,12 @@ TEST(NNCBackendTest, AOTCompileThenExecute) {
 
   // Compile the model with NNC.
   auto compile_spec = create_compile_spec(
-      "forward",
-      "_add_kernel",
-      {{4, 4}},
-      {{4, 4}},
-      c10::impl::toList(c10::List<at::Tensor>({param})),
-      {});
+      "forward", "_add_kernel_nnc_fake_model", "4,4", "float");
   auto any_dict_ty =
       c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
+  auto frozen_m = torch::jit::freeze_module(m.clone());
   auto compiled_module = torch::jit::detail::codegen_backend_module(
-      "nnc", m, compile_spec, any_dict_ty);
+      "nnc", frozen_m, compile_spec, any_dict_ty);
 
   // Save the compiled model.
   std::stringstream ss;
@@ -185,49 +93,7 @@ TEST(NNCBackendTest, AOTCompileThenExecute) {
   auto result = loaded_module.forward(inputs);
   EXPECT_TRUE(result.toTensor().equal(3.0 * torch::ones({4, 4})));
   EXPECT_TRUE(result.toTensor().equal(reference.toTensor()));
-}
-
-TEST(NNCBackendTest, FakeTensor) {
-  script::Module m("m");
-  auto param_cls = getCustomClass(
-      "__torch__.torch.classes._TorchScriptTesting._MobileNNCFakeTensor");
-  auto param_value = c10::make_intrusive<FakeTensor>(std::vector<int64_t>({3}));
-  m.register_attribute("param", param_cls, param_value, false);
-  m.define(
-      R"(
-    def forward(self, input):
-        return input + self.param.get()
-  )");
-
-  // Run the TorchScript module to get reference result.
-  std::vector<c10::IValue> inputs;
-  inputs.emplace_back(2.0 * torch::ones({4, 4}));
-  auto reference = m.forward(inputs);
-
-  // Compile the model with NNC.
-  auto params = c10::impl::GenericList(c10::AnyType::get());
-  params.emplace_back(param_value);
-  auto compile_spec = create_compile_spec(
-      "forward",
-      "_fake_tensor_add_kernel",
-      {{4, 4}},
-      {{4, 4}},
-      params,
-      {});
-  auto any_dict_ty =
-      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
-  auto compiled_module = torch::jit::detail::codegen_backend_module(
-      "nnc", m, compile_spec, any_dict_ty);
-
-  // Save the compiled model.
-  std::stringstream ss;
-  compiled_module._save_for_mobile(ss);
-
-  // Load and run the saved model.
-  auto loaded_module = _load_for_mobile(ss);
-  auto result = loaded_module.forward(inputs);
-  EXPECT_TRUE(result.toTensor().equal(5.0 * torch::ones({4, 4})));
-  EXPECT_TRUE(result.toTensor().equal(reference.toTensor()));
+  EXPECT_EQ(remove("fake_nnc_model.s"), 0);
 }
 
 } // namespace nnc
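The string the updated test passes to `REGISTER_NNC_KERNEL` must now match the id the backend derives from the compile spec, `<model_name>:<model_version>:<method_name>:<version_token>`, where the version token is still the `VERTOKEN` placeholder (see `getNncKernelId` in the next file). A small illustrative helper, not part of the patch, that reproduces that format:

```cpp
#include <cassert>
#include <string>

// Illustrative only: mirrors the id format built by getNncKernelId() in
// torch/csrc/jit/mobile/nnc/aot_compiler.cpp.
std::string nnc_kernel_id_for(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  const std::string version_token = "VERTOKEN"; // still a TODO in the patch
  return model_name + ":" + model_version + ":" + method_name + ":" +
      version_token;
}

int main() {
  // Matches the string registered via REGISTER_NNC_KERNEL in the test above.
  assert(
      nnc_kernel_id_for("_add_kernel_nnc_fake_model", "v1", "forward") ==
      "_add_kernel_nnc_fake_model:v1:forward:VERTOKEN");
  return 0;
}
```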
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
index b792ae9d2227..1d140390a52c 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
@@ -2,10 +2,14 @@
 #include
 #include
+#include
+#include
+#include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -16,6 +20,7 @@
 #include
 #include
 #include
+#include
 
 using namespace torch::jit;
 using namespace torch::jit::tensorexpr;
@@ -178,30 +183,6 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
   GRAPH_DEBUG("Method name ", method_name);
   GRAPH_DEBUG("Kernel func name ", kernel_func_name);
 
-  CAFFE_ENFORCE(
-      sizes.size() == types.size(),
-      "Number of input sizes and input types should be the same");
-
-  std::vector<c10::IValue> example_values;
-  std::vector<c10::optional<at::Tensor>> example_inputs;
-  for (int i = 0; i < sizes.size(); ++i) {
-    auto example_input = at::rand(sizes[i]).to(at::dtype(types[i]));
-    example_values.emplace_back(example_input);
-    example_inputs.emplace_back(example_input);
-  }
-
-  GRAPH_DUMP("graph before compiler passes ", g);
-  tensorexpr::removeUnusedSelfArgument(g);
-  g = TraceGraph(g, example_values);
-  // TODO: Remove annotateInputShapes pass when TraceGraph can also capture
-  // input shapes
-  tensorexpr::annotateInputShapes(g, example_inputs);
-  RemoveListMutation(g);
-  RemoveTensorMutation(g);
-  EliminateDeadCode(g);
-  LowerAllTuples(g);
-  GRAPH_DUMP("graph after compiler passes ", g);
-
   std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
       std::make_shared<tensorexpr::TensorExprKernel>(
           TensorExprKernel(g, kernel_func_name));
@@ -212,6 +193,174 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
   return std::make_pair(std::move(func), compiled_assembly);
 }
 
+void writeOutputLlvmAssembly(
+    const std::string& asm_code,
+    const std::string& output_llvm_file_name) {
+  std::ofstream output(output_llvm_file_name);
+  output << asm_code;
+  GRAPH_DEBUG(
+      "The compiled llvm assembly code was saved to ", output_llvm_file_name);
+}
+
+std::vector<std::string> split(
+    char separator,
+    const std::string& string,
+    bool ignore_empty = true) {
+  std::vector<std::string> pieces;
+  std::stringstream ss(string);
+  std::string item;
+  while (getline(ss, item, separator)) {
+    if (!ignore_empty || !item.empty()) {
+      pieces.push_back(std::move(item));
+    }
+  }
+  return pieces;
+}
+
+std::vector<std::vector<int64_t>> parseInputShapes(
+    const std::string& input_dims_s) {
+  std::vector<std::string> input_dims_list = split(';', input_dims_s);
+  std::vector<std::vector<int64_t>> inputs;
+  for (const auto& input_dims_item : input_dims_list) {
+    auto input_dims_str = split(',', input_dims_item);
+    std::vector<int64_t> input_dims;
+    input_dims.reserve(input_dims_str.size());
+    for (const auto& s : input_dims_str) {
+      input_dims.push_back(c10::stoi(s));
+    }
+    inputs.push_back(input_dims);
+  }
+  return inputs;
+}
+
+std::vector<at::ScalarType> parseInputTypes(
+    const std::string& input_types_str) {
+  std::vector<std::string> inputTypes = split(';', input_types_str);
+  std::vector<at::ScalarType> scalarTypes;
+  for (const auto& inputType : inputTypes) {
+    at::ScalarType scalarType;
+    if (inputType == "float") {
+      scalarType = at::ScalarType::Float;
+    } else if (inputType == "uint8") {
+      scalarType = at::ScalarType::Byte;
+    } else if (inputType == "int64") {
+      scalarType = at::ScalarType::Long;
+    } else {
+      CAFFE_THROW("Unsupported input type: ", inputType);
+    }
+    scalarTypes.push_back(scalarType);
+  }
+  return scalarTypes;
+}
+
+std::string getNncKernelId(
+    const std::string& model_name,
+    const std::string& model_version,
+    const std::string& method_name) {
+  // TODO: calculate the version_token.
+  const std::string version_token = "VERTOKEN";
+  return model_name + ":" + model_version + ":" + method_name + ":" +
+      version_token;
+}
+
+std::string getNncKernelFuncName(
+    const std::string& model_name,
+    const std::string& model_version,
+    const std::string& method_name) {
+  return "nnc_" + model_name + "_" + model_version + "_" + method_name;
+}
+
+std::shared_ptr<Graph> preprocessGraphPasses(
+    std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
+  GRAPH_DEBUG("Before preprocessing graph passes: ", *graph);
+  torch::jit::RemoveTensorMutation(graph);
+  torch::jit::EliminateDeadCode(graph->block());
+  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);
+
+  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
+  torch::jit::OptimizeFrozenGraph(graph, true);
+  torch::jit::PropagateShapesOnGraph(graph);
+  torch::jit::PeepholeOptimize(graph, false);
+  torch::jit::ConstantPropagation(graph);
+  torch::jit::PropagateShapesOnGraph(graph);
+  torch::jit::PeepholeOptimize(graph, false);
+  torch::jit::ConstantPropagation(graph);
+
+  tensorexpr::removeUnusedSelfArgument(graph);
+
+  std::vector<c10::IValue> example_values;
+  example_values.reserve(example_inputs.size());
+  for (auto example_input : example_inputs) {
+    example_values.emplace_back(*example_input);
+  }
+  graph = TraceGraph(graph, example_values);
+  // TODO: Remove annotateInputShapes pass when TraceGraph can also capture
+  // input shapes
+  tensorexpr::annotateInputShapes(graph, example_inputs);
+
+  RemoveListMutation(graph);
+  RemoveTensorMutation(graph);
+  EliminateDeadCode(graph);
+  LowerAllTuples(graph);
+  GRAPH_DEBUG("After preprocessing graph passes: ", *graph);
+  return graph;
+}
+
+std::vector<c10::optional<at::Tensor>> generateExampleInputs(
+    const std::vector<std::vector<int64_t>>& inputShapes,
+    const std::vector<at::ScalarType>& inputTypes) {
+  std::vector<c10::optional<at::Tensor>> example_inputs;
+  example_inputs.reserve(inputShapes.size());
+  for (int i = 0; i < inputShapes.size(); ++i) {
+    example_inputs.emplace_back(
+        at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
+  }
+  return example_inputs;
+}
+
*method_spec.at("model_version").toString(); + std::string asmfile_name = *method_spec.at("asmfile").toString(); + GRAPH_DEBUG("Model name: ", model_name); + GRAPH_DEBUG("Model version: ", model_version); + GRAPH_DEBUG("Asm file name: ", asmfile_name); + + auto method = mod.get_method(method_name); + auto graph = toGraphFunction(method.function()).graph()->copy(); + + auto sizes = parseInputShapes(*method_spec.at("sizes").toString()); + auto types = parseInputTypes(*method_spec.at("types").toString()); + + auto example_inputs = generateExampleInputs(sizes, types); + graph = preprocessGraphPasses(graph, example_inputs); + + auto kernel_func_name = + getNncKernelFuncName(model_name, model_version, method_name); + auto compiled = torch::jit::mobile::nnc::aotCompile( + method_name, graph, sizes, types, kernel_func_name); + writeOutputLlvmAssembly(compiled.second, asmfile_name); + auto func = std::move(compiled.first); + func->set_nnc_kernel_id( + getNncKernelId(model_name, model_version, method_name)); + cu.register_function(std::move(func)); + } + return cu.serialize(); +} + +static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess); + } // namespace nnc } // namespace mobile } // namespace jit