Refactor AotCompile to return a pair (#65707)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65707

Refactors aotCompile to return a pair of the compiled function and the LLVM assembly text, instead of writing the assembly into a caller-supplied output string.
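A minimal before/after sketch of the calling convention change (standalone, with hypothetical stand-in names such as `compileOld`, `compileNew`, and a dummy `Function`, not the actual PyTorch declarations):

```cpp
#include <memory>
#include <string>
#include <utility>

struct Function {};  // stand-in for torch::jit::mobile::nnc::Function

// Before: the LLVM assembly came back through an out-parameter.
std::unique_ptr<Function> compileOld(std::string* asm_out) {
  *asm_out = "<llvm asm>";
  return std::make_unique<Function>();
}

// After: both results travel in the return value, so the caller
// cannot forget to supply (or to read) the assembly string.
std::pair<std::unique_ptr<Function>, const std::string> compileNew() {
  auto func = std::make_unique<Function>();
  const std::string asm_text = "<llvm asm>";  // stands in for kernel->getCodeText()
  return std::make_pair(std::move(func), asm_text);
}

int main() {
  auto compiled = compileNew();
  auto func = std::move(compiled.first);  // take ownership of the function
  const std::string& asm_text = compiled.second;
  (void)func;
  (void)asm_text;
}
```

One consequence of declaring the pair's second member as `const std::string`, as this diff does, is that callers can read or copy the assembly text out of the pair but never move it out; a non-const `std::string` member would also permit a move.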

Testing: Gives expected results when compiled and run
```
(pytorch)  ~/local/pytorch refactor_aot
└─ $ build/bin/aot_model_compiler --model mobilenetv3.pt --model_name=pytorch_dev_mobilenetv3 --model_version=v1 --input_dims="2,2,2"
The compiled model was saved to mobilenetv3.compiled.pt
```

Test Plan: Imported from OSS

Reviewed By: qihqi

Differential Revision: D31220452

Pulled By: priyaramani

fbshipit-source-id: f957c53ba83f876a2e7dbdd4b4571a760b3b6a9a
commit 63bb7c6dba (parent e9327ed2ce)
Author: Priya Ramani
Date: 2021-09-27 18:54:20 -07:00
Committed by: Facebook GitHub Bot
3 changed files with 11 additions and 13 deletions


```diff
@@ -112,10 +112,10 @@ c10::IValue preprocess(
   auto sizes = getInputSizesForMethod(method_compile_spec, method_name);
 
-  std::string llvm_asm_code;
-  auto func =
-      torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes, &llvm_asm_code);
-  writeOutputLlvmAssembly(llvm_asm_code);
+  auto compiled = torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes);
+  writeOutputLlvmAssembly(compiled.second);
+  auto func = std::move(compiled.first);
 
   func->set_nnc_kernel_id(getNncKernelId(method_name));
 
   torch::jit::mobile::nnc::CompilationUnit cu;
```
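A caller built as C++17 or later could unpack the pair with a structured binding instead; this is a sketch only (PyTorch targeted an earlier standard at the time of this commit), with names mirroring the hunk above:

```cpp
// Hypothetical call site; method_name, graph, and sizes as in the hunk above.
auto [func, llvm_asm_code] =
    torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes);
writeOutputLlvmAssembly(llvm_asm_code);
func->set_nnc_kernel_id(getNncKernelId(method_name));
```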


```diff
@@ -33,7 +33,7 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
   return r;
 }
 
-void getCompiledFunction(
+void compileFunction(
     std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
     Function* func) {
   std::vector<at::Tensor> parameters;
@@ -66,11 +66,10 @@ void getCompiledFunction(
   func->set_output_specs(out_spec);
 }
 
-std::unique_ptr<Function> aotCompile(
+std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& g,
-    const std::vector<int64_t>& sizes,
-    std::string* compiled_assembly) {
+    const std::vector<int64_t>& sizes) {
   auto g2 = g->copy();
 
   GRAPH_DEBUG("Input sizes ", sizes);
@@ -90,7 +89,7 @@ std::unique_ptr<Function> aotCompile(
   std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
       std::make_shared<tensorexpr::TensorExprKernel>(g);
 
-  *compiled_assembly = kernel->getCodeText();
+  const std::string compiled_assembly = kernel->getCodeText();
 
   g = g2;
@@ -102,8 +101,8 @@ std::unique_ptr<Function> aotCompile(
   input.dtype_ = c10::ScalarType::Float;
   func->set_input_specs({input});
 
-  getCompiledFunction(kernel, func.get());
-  return func;
+  compileFunction(kernel, func.get());
+  return std::make_pair(std::move(func), compiled_assembly);
 }
 
 } // namespace nnc
```


```diff
@@ -11,11 +11,10 @@ namespace nnc {
 
 // Performs Ahead Of Time compilation of a given method in a model
 // returning the compiled function and LLVM assembly code
-TORCH_API std::unique_ptr<Function> aotCompile(
+TORCH_API std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& subgraph,
-    const std::vector<int64_t>& sizes,
-    std::string* compiled_assembly);
+    const std::vector<int64_t>& sizes);
 
 } // namespace nnc
 } // namespace mobile
```