[PyTorch Edge][tracing-based] Unify tracer between internal and external (#64152)

Summary:
As titled, this introduces `TracerRunner`, a file shared by the internal and external tracers. Its main function is
```
TracerResult trace_run(const std::string& input_module_path);
```
which takes the path to a model file and generates the trace result (see the usage sketch after the list below). The main differences between the external and internal tracers are:
1. The dependency on `<yaml-cpp/yaml.h>`.
2. The output YAML file from the internal tracer includes `model_version` and `model_asset`, which are only needed internally.
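
An external tracer binary consumes this API the same way the updated `tracer.cpp` does: call `trace_run` with the model path, then emit YAML from the returned `TracerResult`. Below is a minimal sketch of such a caller; it is not part of this commit, the model path is a placeholder, and YAML emission is omitted.
```
#include <iostream>
#include <string>

#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>

int main() {
  // Placeholder path to a lite-interpreter model with bundled inputs.
  const std::string model_path = "/tmp/model_with_bundled_input.ptl";

  // trace_run loads the module, exercises its bundled inputs, and collects
  // root operators, traced operators, kernel dtype tags, and enabled backends.
  torch::jit::mobile::TracerResult result =
      torch::jit::mobile::trace_run(model_path);

  std::cout << "Root operators: " << result.root_ops.size() << "\n"
            << "Traced operators: " << result.traced_operators.size() << "\n"
            << "Enabled backends: " << result.enabled_backends.size()
            << std::endl;
  return 0;
}
```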

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64152

ghstack-source-id: 140692467

Test Plan:
```
./build/bin/model_tracer --model_input_path "/Users/chenlai/Documents/pytorch/tracing/deeplabv3_scripted_with_bundled_input.ptl" --build_yaml_path "/Users/chenlai/Documents/pytorch/tracing/tmp.yaml"
```
```
./fbcode/caffe2/fb/model_tracer/run_model_with_bundled_inputs.sh ~/local/notebooks/prod_models/deeplabv3_scripted_with_bundled_input.ptl
```
Both commands produce the same operator output:

selected_operators.yaml (P460296279)
selected_mobile_ops.h (P460296258)

Reviewed By: dhruvbird

Differential Revision: D30632224

fbshipit-source-id: eb0321dbc0f1fcf6d2e05384695eebb59ac04f8c
Author: Chen Lai (2021-10-15 02:17:57 -07:00)
Committed by: Facebook GitHub Bot
Parent: 1e47181c47
Commit: 76efbccc3b
5 changed files with 261 additions and 196 deletions

View File

@@ -712,6 +712,10 @@ if(BUILD_LITE_INTERPRETER)
string(APPEND CMAKE_CXX_FLAGS " -DBUILD_LITE_INTERPRETER")
endif()
if(TRACING_BASED)
string(APPEND CMAKE_CXX_FLAGS " -DTRACING_BASED")
endif()
if(USE_PYTORCH_METAL)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL")
endif()

View File

@@ -448,6 +448,7 @@ libtorch_core_jit_sources = sorted(jit_sources_full)
torch_mobile_tracer_sources = [
"torch/csrc/jit/mobile/model_tracer/tracer.cpp",
"torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp",
"torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp",
"torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp",
"torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.cpp",
"torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.cpp",

View File

@@ -0,0 +1,208 @@
#include <ATen/Functions.h>
#include <ATen/core/dispatch/ObservedOperators.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/script.h>
namespace torch {
namespace jit {
namespace mobile {
// Fetched from caffe2/aten/src/ATen/native/metal/MetalAten.mm
// Diffusion Link: https://fburl.com/diffusion/atwwmax2
const std::vector<std::string> gpu_metal_operators = {
"aten::conv2d",
"aten::add.Tensor",
"aten::add_.Tensor",
"aten::addmm",
"aten::empty.memory_format",
"aten::empty_strided",
"aten::log_softmax.int",
"aten::max_pool2d",
"aten::mul.Tensor",
"aten::relu",
"aten::relu_",
"aten::sigmoid",
"aten::sub.Tensor",
"aten::upsample_nearest2d.vec",
"aten::view",
"aten::adaptive_avg_pool2d",
"aten::hardtanh_",
"aten::reshape",
"aten::flatten.using_ints",
};
/**
* These are a collection of some common ATen methods that are usually
* called outside of the Model's forward() run, and they need to be
* traced to ensure that the used operators are included in the build.
* If/When this list becomes too long, we can consider making it a
* per-model list.
*/
void call_setup_methods() {
at::zeros({2, 2});
at::ones({2, 2});
at::Tensor t1 = at::empty({7, 7});
at::Tensor t2 = t1.fill_(3);
at::narrow(t2, 1, 0, 1);
at::eq(t1, t2);
const volatile bool nz = at::zeros({1}).is_nonzero();
(void)nz;
// Create a byte tensor and copy it
auto zb = at::zeros({10}, at::kByte);
auto zf = at::zeros({10}, at::kFloat);
zb.copy_(zf);
t2.div(1);
// Typically, failures show up in CopyKernel.cpp, so enumerating
// common dtypes that may show up.
const auto all_dtypes_for_copy = {
at::kByte,
at::kFloat,
at::kInt,
at::kChar,
at::kDouble,
at::kShort,
at::kLong};
for (const auto dtype : all_dtypes_for_copy) {
auto tensor1 = at::empty({10}, dtype);
tensor1.copy_(at::zeros({10}, at::kFloat));
}
torch::zeros({0, 0}, torch::ScalarType::Float);
std::vector<float> storage(20, 1.0);
std::vector<int64_t> sizes({2, 10});
torch::from_blob(storage.data(), at::IntArrayRef(sizes), at::kFloat);
}
/**
* Call methods on the Tensor object that we expect to be called
* in production on this Tensor.
*/
void consume_tensor(const at::Tensor& t) {
const at::Tensor& c = t;
c.copy_(t.cpu());
}
void run_model(
const std::string& input_module_path,
std::set<std::string>& root_ops,
std::set<std::string>& enabled_backends,
KernelDTypeTracer::kernel_tags_type& called_kernel_tags) {
// Load the module on CPU with the flag to skip the operator exists check.
// This is needed so that we can load any TorchBind objects (custom classes)
// that this model refers to so that any operators being called from those
// TorchBind objects can be traced by the model tracer.
//
torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
root_ops = module_runner.get_root_operators();
std::cout << "Got " << root_ops.size() << " Root Operators." << std::endl;
if (torch::jit::mobile::MobileModelRunner::set_has_metal_gpu_operators(
root_ops)) {
std::cout << "Inferred Metal GPU Model." << std::endl;
root_ops.insert(gpu_metal_operators.begin(), gpu_metal_operators.end());
called_kernel_tags["__unused__"] = {"Float"};
enabled_backends.insert("Metal GPU");
// When we encounter a GPU model, we should call .cpu().copy_() on the
// tensors in the bundled inputs, since this is what will happen when
// such a model is executed on an iOS device (to copy the Tensor to Metal
// memory via a call to .metal()).
module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
} else {
std::cout << "Inferred CPU Model." << std::endl;
enabled_backends.insert("CPU");
torch::jit::mobile::MobileModelRunner mobile_module_runner(
input_module_path);
// When we encounter a CPU model, we should call .cpu().copy_() on the
// tensors in the bundled inputs, since this is what will happen when
// such a model is executed on an Android device, where the PyTorch JNI
// bindings call .cpu() in JIValue::newJIValueFromAtIValue().
module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
// If a user bundled inputs after the API was updated to accept bundled
// inputs for multiple methods, they should go down this route. Even if
// they only bundled inputs for forward, they will have the new-style
// bundled inputs. Since at this point the tracer does not know which
// functions have bundled inputs, it must call
// get_bundled_inputs_functions_and_info, if it exists, to get the set.
if (mobile_module_runner.has_new_style_bundled_inputs()) {
auto bundled_inputs_mapping =
mobile_module_runner.get_many_functions_bundled_inputs();
for (auto& entry : bundled_inputs_mapping) {
std::string function_name = entry.first;
std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
std::cout << "Got " << bundled_inputs.size() << " bundled input(s) for "
<< function_name << "\n\n";
std::vector<at::IValue> results =
mobile_module_runner.run_with_inputs(function_name, bundled_inputs);
for (auto& result : results) {
// Consume the result Tensor(s) when tracing on CPU since the
// Android/Java JNI bindings will do the same.
torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
}
}
// If get_bundled_inputs_functions_and_info does not exist, we default
// to assuming the inputs were bundled before that change was made. If no
// bundled inputs are found here either, an error will be thrown.
} else {
std::vector<std::vector<at::IValue>> bundled_inputs =
mobile_module_runner.get_all_bundled_inputs();
std::cout << "Got " << bundled_inputs.size() << " bundled input(s)\n\n";
std::vector<at::IValue> results =
mobile_module_runner.run_with_inputs(bundled_inputs);
for (auto& result : results) {
// Consume the result Tensor(s) when tracing on CPU since the
// Android/Java JNI bindings will do the same.
torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
}
}
}
}
TracerResult trace_run(const std::string& input_module_path) {
at::globalContext().setQEngine(at::QEngine::QNNPACK);
c10::ObservedOperators::getUnobservedOperatorList().clear();
torch::jit::mobile::OperatorCallTracer op_tracer;
torch::jit::mobile::KernelDTypeTracer kdtype_tracer;
call_setup_methods();
std::set<std::string> root_ops, traced_operators, enabled_backends;
torch::jit::mobile::KernelDTypeTracer::kernel_tags_type called_kernel_tags;
using torch::jit::MobileModuleLoadOptions;
// run with QNNPACK
run_model(input_module_path, root_ops, enabled_backends, called_kernel_tags);
at::globalContext().setQEngine(at::QEngine::FBGEMM);
run_model(input_module_path, root_ops, enabled_backends, called_kernel_tags);
traced_operators = op_tracer.getCalledOperators();
called_kernel_tags.insert(
kdtype_tracer.getCalledKernelTags().begin(),
kdtype_tracer.getCalledKernelTags().end());
traced_operators.insert(
always_included_traced_ops.begin(), always_included_traced_ops.end());
TracerResult tracer_result = {
root_ops, traced_operators, called_kernel_tags, enabled_backends};
return tracer_result;
}
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@@ -0,0 +1,26 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
namespace torch {
namespace jit {
namespace mobile {
const std::vector<std::string> always_included_traced_ops = {
// The following are called from setup sections.
"aten::resize_",
"aten::slice.Tensor",
};
struct TracerResult {
std::set<std::string> root_ops;
std::set<std::string> traced_operators;
KernelDTypeTracer::kernel_tags_type called_kernel_tags;
std::set<std::string> enabled_backends;
};
TracerResult trace_run(const std::string& input_module_path);
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@@ -23,6 +23,7 @@
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/script.h>
@@ -51,36 +52,6 @@ C10_DEFINE_string(
return 1; \
}
const std::vector<std::string> always_included_traced_ops = {
// The following are called from setup sections.
"aten::resize_",
"aten::slice.Tensor",
};
// Fetched from caffe2/aten/src/ATen/native/metal/MetalAten.mm
// Diffusion Link: https://fburl.com/diffusion/atwwmax2
const std::vector<std::string> gpu_metal_operators = {
"aten::conv2d",
"aten::add.Tensor",
"aten::add_.Tensor",
"aten::addmm",
"aten::empty.memory_format",
"aten::empty_strided",
"aten::log_softmax.int",
"aten::max_pool2d",
"aten::mul.Tensor",
"aten::relu",
"aten::relu_",
"aten::sigmoid",
"aten::sub.Tensor",
"aten::upsample_nearest2d.vec",
"aten::view",
"aten::adaptive_avg_pool2d",
"aten::hardtanh_",
"aten::reshape",
"aten::flatten.using_ints",
};
void printOpYAML(
std::ostream& out,
int indent,
@@ -111,140 +82,6 @@ void printOpsYAML(
}
}
/**
* These are a collection of some common ATen methods that are usually
* called outside of the Model's forward() run, and they need to be
* traced to ensure that the used operators are included in the build.
* If/When this list becomes too long, we can consider making it a
* per-model list.
*/
void call_setup_methods() {
at::zeros({2, 2});
at::ones({2, 2});
at::Tensor t1 = at::empty({7, 7});
at::Tensor t2 = t1.fill_(3);
at::narrow(t2, 1, 0, 1);
at::eq(t1, t2);
const volatile bool nz = at::zeros({1}).is_nonzero();
(void)nz;
// Create a byte tensor and copy it
auto zb = at::zeros({10}, at::kByte);
auto zf = at::zeros({10}, at::kFloat);
zb.copy_(zf);
t2.div(1);
// Typically, failures show up in CopyKernel.cpp, so enumerating
// common dtypes that may show up.
const auto all_dtypes_for_copy = {
at::kByte,
at::kFloat,
at::kInt,
at::kChar,
at::kDouble,
at::kShort,
at::kLong};
for (const auto dtype : all_dtypes_for_copy) {
auto tensor1 = at::empty({10}, dtype);
tensor1.copy_(at::zeros({10}, at::kFloat));
}
torch::zeros({0, 0}, torch::ScalarType::Float);
std::vector<float> storage(20, 1.0);
std::vector<int64_t> sizes({2, 10});
torch::from_blob(storage.data(), at::IntArrayRef(sizes), at::kFloat);
}
/**
* Call methods on the Tensor object that we expect to be called
* in production on this Tensor.
*/
void consume_tensor(const at::Tensor& t) {
const at::Tensor& c = t;
c.copy_(t.cpu());
}
void run_model(
const std::string& input_module_path,
std::set<std::string>& root_ops,
std::set<std::string>& enabled_backends,
torch::jit::mobile::KernelDTypeTracer::kernel_tags_type&
called_kernel_tags) {
// Load the module on CPU with the flag to skip the operator exists check.
// This is needed so that we can load any TorchBind objects (custom classes)
// that this model refers to so that any operators being called from those
// TorchBind objects can be traced by the model tracer.
//
torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
root_ops = module_runner.get_root_operators();
std::cout << "Got " << root_ops.size() << " Root Operators." << std::endl;
if (torch::jit::mobile::MobileModelRunner::set_has_metal_gpu_operators(
root_ops)) {
std::cout << "Inferred Metal GPU Model." << std::endl;
root_ops.insert(gpu_metal_operators.begin(), gpu_metal_operators.end());
called_kernel_tags["__unused__"] = {"Float"};
enabled_backends.insert("Metal GPU");
// When we encounter a GPU model, we should call .cpu().copy_() on the
// tensors in the bundled inputs, since this is what will happen when
// such a model is executed on an iOS device (to copy the Tensor to Metal
// memory via a call to .metal()).
module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
} else {
std::cout << "Inferred CPU Model." << std::endl;
enabled_backends.insert("CPU");
torch::jit::mobile::MobileModelRunner mobile_module_runner(
input_module_path);
// When we encounter a CPU model, we should call .cpu().copy_() on the
// tensors in the bundled inputs, since this is what will happen when
// such a model is executed on an Android device since the PyTorch JNI
// bindings call .cpu() in JIValue::newJIValueFromAtIValue().
module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
// If a user has bundled inputs since that api was updated to accept
// bundled inputs for multiple methods They should go down this route.
// Even if they only bundle inputs for forward they will have the new
// style bundled inputs. Since at this time in tracer.cpp we do not know
// what functions have bundled inputs we must call
// get_bundled_inputs_functions_and_info if it exists to get the set.
if (mobile_module_runner.has_new_style_bundled_inputs()) {
auto bundled_inputs_mapping =
mobile_module_runner.get_many_functions_bundled_inputs();
for (auto& entry : bundled_inputs_mapping) {
std::string function_name = entry.first;
std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
std::cout << "Got " << bundled_inputs.size() << " bundled input(s) for "
<< function_name << "\n\n";
std::vector<at::IValue> results =
mobile_module_runner.run_with_inputs(function_name, bundled_inputs);
for (auto& result : results) {
// Consume the result Tensor(s) when tracing on CPU since the
// Android/Java JNI bindings will do the same.
torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
}
}
// If get_bundled_inputs_functions_and_info does not exists we default
// to assuming they bundled before that change was made. If no bundled
// inputs are found here either an error will be thrown
} else {
std::vector<std::vector<at::IValue>> bundled_inputs =
mobile_module_runner.get_all_bundled_inputs();
std::cout << "Got " << bundled_inputs.size() << " bundled input(s)\n\n";
std::vector<at::IValue> results =
mobile_module_runner.run_with_inputs(bundled_inputs);
for (auto& result : results) {
// Consume the result Tensor(s) when tracing on CPU since the
// Android/Java JNI bindings will do the same.
torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
}
}
}
}
/**
* Converts a pytorch model (full/lite) to lite interpreter model for
* mobile, and additionally writes out a list of root and called
@@ -265,47 +102,36 @@ int main(int argc, char* argv[]) {
std::cout << "Processing: " << input_module_path << std::endl;
std::cout << "Output: " << FLAGS_build_yaml_path << std::endl;
torch::jit::mobile::TracerResult tracer_result;
try {
tracer_result = torch::jit::mobile::trace_run(FLAGS_model_input_path);
} catch (std::exception& ex) {
std::cerr
<< "ModelTracer has not been able to load the module for the following reasons:\n"
<< ex.what()
<< "\nPlease consider posting to the PyTorch with the error message."
<< std::endl;
at::globalContext().setQEngine(at::QEngine::QNNPACK);
c10::ObservedOperators::getUnobservedOperatorList().clear();
throw ex;
}
torch::jit::mobile::OperatorCallTracer op_tracer;
torch::jit::mobile::KernelDTypeTracer kdtype_tracer;
call_setup_methods();
std::set<std::string> root_ops, traced_operators, enabled_backends;
torch::jit::mobile::KernelDTypeTracer::kernel_tags_type called_kernel_tags;
using torch::jit::MobileModuleLoadOptions;
// run with QNNPACK
run_model(input_module_path, root_ops, enabled_backends, called_kernel_tags);
at::globalContext().setQEngine(at::QEngine::FBGEMM);
run_model(input_module_path, root_ops, enabled_backends, called_kernel_tags);
traced_operators = op_tracer.getCalledOperators();
called_kernel_tags.insert(
kdtype_tracer.getCalledKernelTags().begin(),
kdtype_tracer.getCalledKernelTags().end());
traced_operators.insert(
always_included_traced_ops.begin(), always_included_traced_ops.end());
if (traced_operators.size() <= always_included_traced_ops.size()) {
if (tracer_result.traced_operators.size() <=
torch::jit::mobile::always_included_traced_ops.size()) {
std::cerr
<< c10::str(
"Error traced_operators size: ",
traced_operators.size(),
tracer_result.traced_operators.size(),
". Expected the traced operator list to be bigger then the default size ",
always_included_traced_ops.size(),
torch::jit::mobile::always_included_traced_ops.size(),
". Please report a bug in PyTorch.")
<< std::endl;
}
// If the op exists in both traced_ops and root_ops, leave it in root_ops only
for (const auto& root_op : root_ops) {
if (traced_operators.find(root_op) != traced_operators.end()) {
traced_operators.erase(root_op);
for (const auto& root_op : tracer_result.root_ops) {
if (tracer_result.traced_operators.find(root_op) !=
tracer_result.traced_operators.end()) {
tracer_result.traced_operators.erase(root_op);
}
}
@@ -313,13 +139,13 @@ int main(int argc, char* argv[]) {
yaml_out << "operators:" << std::endl;
printOpsYAML(
yaml_out,
root_ops,
tracer_result.root_ops,
false /* is_used_for_training */,
true /* is_root_operator */,
false /* include_all_overloads */);
printOpsYAML(
yaml_out,
traced_operators,
tracer_result.traced_operators,
false /* is_used_for_training */,
false /* is_root_operator */,
false /* include_all_overloads */);