mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D31514095: Use kernel_func_name from aotCompiler
Test Plan: revert-hammer
Differential Revision: D31514095 (7b55dc8340)
Original commit changeset: b70c8e2c7336
fbshipit-source-id: ad4d828f33506e612b51c276149fa0e12b0565d5
This commit is contained in:
committed by Facebook GitHub Bot
parent 313939c9c6
commit b6fa998892
@ -90,10 +90,6 @@ std::string getNncKernelId() {
|
|||||||
":" + version_token;
|
":" + version_token;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string getNncKernelFuncName(const std::string& method_name) {
|
|
||||||
return "nnc_" + FLAGS_model_name + "_" + FLAGS_model_version + "_" + method_name;
|
|
||||||
}
|
|
||||||
|
|
||||||
void writeOutputLlvmAssembly(const std::string& asm_code) {
|
void writeOutputLlvmAssembly(const std::string& asm_code) {
|
||||||
std::string output_llvm_file_name = FLAGS_output_llvm;
|
std::string output_llvm_file_name = FLAGS_output_llvm;
|
||||||
if (output_llvm_file_name.empty()) {
|
if (output_llvm_file_name.empty()) {
|
||||||
@ -112,13 +108,18 @@ c10::IValue preprocess(
|
|||||||
const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
|
const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
|
||||||
const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
|
const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
|
||||||
|
|
||||||
|
std::string output_llvm_file_name = FLAGS_output_llvm;
|
||||||
|
if (output_llvm_file_name.empty()) {
|
||||||
|
output_llvm_file_name =
|
||||||
|
FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
|
||||||
|
}
|
||||||
|
|
||||||
auto method = mod.get_method(FLAGS_method_name);
|
auto method = mod.get_method(FLAGS_method_name);
|
||||||
auto graph = method.function().graph()->copy();
|
auto graph = method.function().graph()->copy();
|
||||||
auto sizes = getInputSizes(method_compile_spec);
|
auto sizes = getInputSizes(method_compile_spec);
|
||||||
auto kernel_func_name = getNncKernelFuncName(FLAGS_method_name);
|
|
||||||
|
|
||||||
auto compiled = torch::jit::mobile::nnc::aotCompile(
|
std::string llvm_asm_code;
|
||||||
FLAGS_method_name, graph, sizes, kernel_func_name);
|
auto compiled = torch::jit::mobile::nnc::aotCompile(FLAGS_method_name, graph, sizes);
|
||||||
writeOutputLlvmAssembly(compiled.second);
|
writeOutputLlvmAssembly(compiled.second);
|
||||||
|
|
||||||
auto func = std::move(compiled.first);
|
auto func = std::move(compiled.first);
|
||||||
@ -140,8 +141,8 @@ int main(int argc, char** argv) {
|
|||||||
" --model=<model file>"
|
" --model=<model file>"
|
||||||
" --model_name=<model name>"
|
" --model_name=<model name>"
|
||||||
" --model_version=<model version>"
|
" --model_version=<model version>"
|
||||||
" --input_dims=<input dimensions like '1,3,224,224;2,2'>"
|
" --input_dims='1,3,224,224'"
|
||||||
" [--method_name=<method name>]"
|
" [--method_name=<mehhod name>]"
|
||||||
" [--output_llvm=<llvm assembly output file path>]"
|
" [--output_llvm=<llvm assembly output file path>]"
|
||||||
" [--output_model=<output model file path>]");
|
" [--output_model=<output model file path>]");
|
||||||
|
|
||||||
@ -152,9 +153,6 @@ int main(int argc, char** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
|
CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
|
||||||
CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
|
|
||||||
CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
|
|
||||||
CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
|
|
||||||
|
|
||||||
std::string output_model_name = FLAGS_output_model;
|
std::string output_model_name = FLAGS_output_model;
|
||||||
if (output_model_name.empty()) {
|
if (output_model_name.empty()) {
|
||||||
|
@ -87,8 +87,7 @@ std::unique_ptr<Function> compileMethod(
|
|||||||
std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
|
std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
|
||||||
const std::string& method_name,
|
const std::string& method_name,
|
||||||
std::shared_ptr<Graph>& g,
|
std::shared_ptr<Graph>& g,
|
||||||
const std::vector<std::vector<int64_t>>& sizes,
|
const std::vector<std::vector<int64_t>>& sizes) {
|
||||||
const std::string& kernel_func_name) {
|
|
||||||
GRAPH_DEBUG("Input sizes ", sizes);
|
GRAPH_DEBUG("Input sizes ", sizes);
|
||||||
GRAPH_DEBUG("Method name ", method_name);
|
GRAPH_DEBUG("Method name ", method_name);
|
||||||
|
|
||||||
@ -112,9 +111,7 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
|
|||||||
GRAPH_DUMP("graph after shape propagation ", g);
|
GRAPH_DUMP("graph after shape propagation ", g);
|
||||||
|
|
||||||
std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
|
std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
|
||||||
std::make_shared<tensorexpr::TensorExprKernel>(
|
std::make_shared<tensorexpr::TensorExprKernel>(g);
|
||||||
TensorExprKernel(g, {}, false, kernel_func_name));
|
|
||||||
|
|
||||||
const std::string compiled_assembly = kernel->getCodeText();
|
const std::string compiled_assembly = kernel->getCodeText();
|
||||||
|
|
||||||
auto func = compileMethod(kernel, method_name, sizes);
|
auto func = compileMethod(kernel, method_name, sizes);
|
||||||
|
@ -14,8 +14,7 @@ namespace nnc {
|
|||||||
TORCH_API std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
|
TORCH_API std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
|
||||||
const std::string& method_name,
|
const std::string& method_name,
|
||||||
std::shared_ptr<Graph>& subgraph,
|
std::shared_ptr<Graph>& subgraph,
|
||||||
const std::vector<std::vector<int64_t>>& sizes,
|
const std::vector<std::vector<int64_t>>& sizes);
|
||||||
const std::string& kernel_func_name = "func");
|
|
||||||
|
|
||||||
} // namespace nnc
|
} // namespace nnc
|
||||||
} // namespace mobile
|
} // namespace mobile
|
||||||
|
@ -1172,19 +1172,17 @@ void TensorExprKernel::compile() {
|
|||||||
stmt,
|
stmt,
|
||||||
bufferArgs_,
|
bufferArgs_,
|
||||||
device_,
|
device_,
|
||||||
kernel_func_name_);
|
SubgraphUtils::generateNameForGraph(graph_));
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorExprKernel::TensorExprKernel(
|
TensorExprKernel::TensorExprKernel(
|
||||||
const std::shared_ptr<Graph>& subgraph,
|
const std::shared_ptr<Graph>& subgraph,
|
||||||
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
|
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
|
||||||
bool pre_alloc /*= false*/,
|
bool pre_alloc /*= false*/)
|
||||||
const std::string& kernel_func_name)
|
|
||||||
: graph_(subgraph),
|
: graph_(subgraph),
|
||||||
code_(subgraph, ""),
|
code_(subgraph, ""),
|
||||||
custom_lowerings_(std::move(custom_lowerings)),
|
custom_lowerings_(std::move(custom_lowerings)),
|
||||||
pre_alloc_(pre_alloc),
|
pre_alloc_(pre_alloc) {
|
||||||
kernel_func_name_(kernel_func_name) {
|
|
||||||
allow_fallback_ = fallbackAllowed();
|
allow_fallback_ = fallbackAllowed();
|
||||||
if (!allow_fallback_) {
|
if (!allow_fallback_) {
|
||||||
compile();
|
compile();
|
||||||
|
@ -93,8 +93,7 @@ class TORCH_API TensorExprKernel {
|
|||||||
const std::shared_ptr<Graph>& subgraph,
|
const std::shared_ptr<Graph>& subgraph,
|
||||||
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
|
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
|
||||||
{},
|
{},
|
||||||
bool pre_alloc = false,
|
bool pre_alloc = false);
|
||||||
const std::string& kernel_func_name = "func");
|
|
||||||
|
|
||||||
void run(Stack& stack);
|
void run(Stack& stack);
|
||||||
void runFast(
|
void runFast(
|
||||||
@ -236,7 +235,6 @@ class TORCH_API TensorExprKernel {
|
|||||||
|
|
||||||
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
|
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
|
||||||
bool pre_alloc_{false};
|
bool pre_alloc_{false};
|
||||||
const std::string& kernel_func_name_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TORCH_API int& getTECudaPointwiseLoopLevels();
|
TORCH_API int& getTECudaPointwiseLoopLevels();
|
||||||
|
Reference in New Issue
Block a user