mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor AotCompile to return a pair (#65707)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65707 Refactors aotCompile to return a pair of the compiled function and the LLVM assembly, instead of updating an incoming string argument with the assembly code. Testing: Gives expected results when compiled and run ``` (pytorch) ~/local/pytorch refactor_aot └─ $ build/bin/aot_model_compiler --model mobilenetv3.pt --model_name=pytorch_dev_mobilenetv3 --model_version=v1 --input_dims="2,2,2" The compiled model was saved to mobilenetv3.compiled.pt ``` Test Plan: Imported from OSS Reviewed By: qihqi Differential Revision: D31220452 Pulled By: priyaramani fbshipit-source-id: f957c53ba83f876a2e7dbdd4b4571a760b3b6a9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e9327ed2ce
commit
63bb7c6dba
@ -112,10 +112,10 @@ c10::IValue preprocess(
|
||||
auto sizes = getInputSizesForMethod(method_compile_spec, method_name);
|
||||
|
||||
std::string llvm_asm_code;
|
||||
auto func =
|
||||
torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes, &llvm_asm_code);
|
||||
writeOutputLlvmAssembly(llvm_asm_code);
|
||||
auto compiled = torch::jit::mobile::nnc::aotCompile(method_name, graph, sizes);
|
||||
writeOutputLlvmAssembly(compiled.second);
|
||||
|
||||
auto func = std::move(compiled.first);
|
||||
func->set_nnc_kernel_id(getNncKernelId(method_name));
|
||||
|
||||
torch::jit::mobile::nnc::CompilationUnit cu;
|
||||
|
@ -33,7 +33,7 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
|
||||
return r;
|
||||
}
|
||||
|
||||
void getCompiledFunction(
|
||||
void compileFunction(
|
||||
std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
|
||||
Function* func) {
|
||||
std::vector<at::Tensor> parameters;
|
||||
@ -66,11 +66,10 @@ void getCompiledFunction(
|
||||
func->set_output_specs(out_spec);
|
||||
}
|
||||
|
||||
std::unique_ptr<Function> aotCompile(
|
||||
std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
|
||||
const std::string& method_name,
|
||||
std::shared_ptr<Graph>& g,
|
||||
const std::vector<int64_t>& sizes,
|
||||
std::string* compiled_assembly) {
|
||||
const std::vector<int64_t>& sizes) {
|
||||
auto g2 = g->copy();
|
||||
GRAPH_DEBUG("Input sizes ", sizes);
|
||||
|
||||
@ -90,7 +89,7 @@ std::unique_ptr<Function> aotCompile(
|
||||
|
||||
std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
|
||||
std::make_shared<tensorexpr::TensorExprKernel>(g);
|
||||
*compiled_assembly = kernel->getCodeText();
|
||||
const std::string compiled_assembly = kernel->getCodeText();
|
||||
|
||||
g = g2;
|
||||
|
||||
@ -102,8 +101,8 @@ std::unique_ptr<Function> aotCompile(
|
||||
input.dtype_ = c10::ScalarType::Float;
|
||||
func->set_input_specs({input});
|
||||
|
||||
getCompiledFunction(kernel, func.get());
|
||||
return func;
|
||||
compileFunction(kernel, func.get());
|
||||
return std::make_pair(std::move(func), compiled_assembly);
|
||||
}
|
||||
|
||||
} // namespace nnc
|
||||
|
@ -11,11 +11,10 @@ namespace nnc {
|
||||
|
||||
// Performs Ahead Of Time compilation of a given method in a model
|
||||
// returning the compiled function and LLVM assembly code
|
||||
TORCH_API std::unique_ptr<Function> aotCompile(
|
||||
TORCH_API std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
|
||||
const std::string& method_name,
|
||||
std::shared_ptr<Graph>& subgraph,
|
||||
const std::vector<int64_t>& sizes,
|
||||
std::string* compiled_assembly);
|
||||
const std::vector<int64_t>& sizes);
|
||||
|
||||
} // namespace nnc
|
||||
} // namespace mobile
|
||||
|
Reference in New Issue
Block a user