mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit][edge] Remove usage of shared_ptr<mobile::Code>. (#68037)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68037 Right now mobile::Code doesn't outlive its enclosing Function, and all accesses to Code happens inside interpreter loop which doesn't outlive the module, so we don't need to use std::shared_ptr here. This also should saves us 1-2 KB for binary size, because shared_ptr seems to bloat on arm64 android. ghstack-source-id: 145818696 Test Plan: eyes. Reviewed By: qihqi, tugsbayasgalan Differential Revision: D32264616 fbshipit-source-id: d83f538d6604cf75fd7728a25127b4849ce7ab2a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
39f65fee47
commit
d459e79500
@ -1517,8 +1517,8 @@ TEST(LiteInterpreterTest, OperatorSize1) {
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
const auto& func = bc.get_method("forward").function();
|
||||
ASSERT_EQ(
|
||||
func.get_code()->operator_input_sizes_.size(),
|
||||
func.get_code()->operators_.size());
|
||||
func.get_code().operator_input_sizes_.size(),
|
||||
func.get_code().operators_.size());
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
|
||||
@ -1552,8 +1552,8 @@ TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
const auto& func = bc.get_method("test_func").function();
|
||||
ASSERT_EQ(
|
||||
func.get_code()->operator_input_sizes_.size(),
|
||||
func.get_code()->operators_.size());
|
||||
func.get_code().operator_input_sizes_.size(),
|
||||
func.get_code().operators_.size());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1590,7 +1590,7 @@ TEST(LiteInterpreterUpgraderTest, DivTensorV2) {
|
||||
*/
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1629,7 +1629,7 @@ TEST(LiteInterpreterUpgraderTest, DivTensorOutV2) {
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1670,7 +1670,7 @@ TEST(LiteInterpreterUpgraderTest, DivTensorInplaceV2) {
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1710,7 +1710,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarFloatV2) {
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1750,7 +1750,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalFloatV2) {
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1791,7 +1791,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalIntV2) {
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1843,7 +1843,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarScalarV2) {
|
||||
*/
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1884,7 +1884,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarIntV2) {
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1924,7 +1924,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarInplaceFloatV2) {
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1964,7 +1964,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarInplaceIntV2) {
|
||||
mobile::Module m_module = _load_for_mobile(test_model_file);
|
||||
|
||||
auto intrsuction_list =
|
||||
m_module.get_method("forward").function().get_code()->instructions_;
|
||||
m_module.get_method("forward").function().get_code().instructions_;
|
||||
uint64_t number_of_call_instruction = 0;
|
||||
for (auto& instruction : intrsuction_list) {
|
||||
number_of_call_instruction += (instruction.op == OpCode::CALL);
|
||||
@ -1988,9 +1988,9 @@ TEST(LiteInterpreterUpgraderTest, Upgrader) {
|
||||
|
||||
for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
|
||||
ASSERT_EQ(
|
||||
byteCodeFunctionWithOperator.function.get_code()->operators_.size(),
|
||||
byteCodeFunctionWithOperator.function.get_code()->op_names_.size());
|
||||
if (byteCodeFunctionWithOperator.function.get_code()->operators_.empty()) {
|
||||
byteCodeFunctionWithOperator.function.get_code().operators_.size(),
|
||||
byteCodeFunctionWithOperator.function.get_code().op_names_.size());
|
||||
if (byteCodeFunctionWithOperator.function.get_code().operators_.empty()) {
|
||||
for (const auto& op : byteCodeFunctionWithOperator.operators) {
|
||||
byteCodeFunctionWithOperator.function.append_operator(
|
||||
op.name,
|
||||
|
@ -29,7 +29,7 @@ struct Code {
|
||||
// function objects, and then append referenced function pointers. This could
|
||||
// be done in parseMethods().
|
||||
std::vector<mobile::Function*> functions_;
|
||||
size_t register_size_; // Aggregated output size.
|
||||
size_t register_size_ = 0; // Aggregated output size.
|
||||
};
|
||||
|
||||
} // namespace mobile
|
||||
|
@ -13,12 +13,11 @@ namespace jit {
|
||||
|
||||
char const* toString(OpCode op);
|
||||
namespace mobile {
|
||||
Function::Function(c10::QualifiedName name)
|
||||
: name_(std::move(name)), code_(std::make_shared<Code>()) {}
|
||||
Function::Function(c10::QualifiedName name) : name_(std::move(name)) {}
|
||||
|
||||
Function::Function(
|
||||
c10::QualifiedName name,
|
||||
std::shared_ptr<Code> code,
|
||||
Code code,
|
||||
at::optional<c10::FunctionSchema> schema)
|
||||
: name_(std::move(name)),
|
||||
code_(std::move(code)),
|
||||
@ -33,8 +32,8 @@ void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
|
||||
isOpSupportedInMobile(op),
|
||||
toString(op),
|
||||
" is not supported in mobile module.");
|
||||
code_->instructions_.emplace_back(op, X, N);
|
||||
code_->debug_handles_.emplace_back(dbg_handle);
|
||||
code_.instructions_.emplace_back(op, X, N);
|
||||
code_.debug_handles_.emplace_back(dbg_handle);
|
||||
}
|
||||
|
||||
void Function::append_instruction(OpCode op, int X, int N) {
|
||||
@ -42,7 +41,7 @@ void Function::append_instruction(OpCode op, int X, int N) {
|
||||
isOpSupportedInMobile(op),
|
||||
toString(op),
|
||||
" is not supported in mobile module.");
|
||||
code_->instructions_.emplace_back(op, X, N);
|
||||
code_.instructions_.emplace_back(op, X, N);
|
||||
}
|
||||
|
||||
bool Function::append_operator(
|
||||
@ -52,39 +51,38 @@ bool Function::append_operator(
|
||||
int64_t model_version) { /* TODO: T90339189 deprecate all v3 when v3 models
|
||||
are removed */
|
||||
// Keep the original opname in code_
|
||||
code_->op_names_.emplace_back(name, overload_name);
|
||||
const auto& opname = code_->op_names_.back();
|
||||
code_->operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
|
||||
code_.op_names_.emplace_back(name, overload_name);
|
||||
const auto& opname = code_.op_names_.back();
|
||||
code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
|
||||
auto func = makeOperatorFunction(opname, num_specified_args, model_version);
|
||||
if (!func.has_value()) {
|
||||
return false;
|
||||
}
|
||||
code_->operators_.emplace_back(*func);
|
||||
code_.operators_.emplace_back(*func);
|
||||
return true;
|
||||
}
|
||||
|
||||
void Function::append_constant(const c10::IValue& constant) {
|
||||
code_->constants_.push_back(constant);
|
||||
code_.constants_.push_back(constant);
|
||||
}
|
||||
|
||||
void Function::append_type(const at::TypePtr& type) {
|
||||
code_->types_.push_back(type);
|
||||
code_.types_.push_back(type);
|
||||
}
|
||||
|
||||
void Function::append_function(mobile::Function& function) {
|
||||
code_->functions_.push_back(&function);
|
||||
code_.functions_.push_back(&function);
|
||||
}
|
||||
|
||||
void Function::set_register_size(size_t size) {
|
||||
code_->register_size_ = size;
|
||||
code_.register_size_ = size;
|
||||
}
|
||||
|
||||
int64_t Function::get_debug_handle(size_t pc) const {
|
||||
TORCH_CHECK(code_, "Valid code must exist.");
|
||||
TORCH_CHECK(
|
||||
pc < code_->debug_handles_.size(),
|
||||
pc < code_.debug_handles_.size(),
|
||||
"Module debug info index out of boundary.");
|
||||
return code_->debug_handles_[pc];
|
||||
return code_.debug_handles_[pc];
|
||||
}
|
||||
|
||||
torch::jit::Function& Function::setSchema(c10::FunctionSchema schema) {
|
||||
@ -105,7 +103,7 @@ void Function::run(Stack& stack) {
|
||||
getSchema().checkAndNormalizeInputs(
|
||||
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
|
||||
}
|
||||
InterpreterState interp_state(*code_);
|
||||
InterpreterState interp_state(code_);
|
||||
interp_state.run(stack);
|
||||
}
|
||||
|
||||
@ -119,11 +117,15 @@ size_t Function::num_inputs() const {
|
||||
}
|
||||
|
||||
bool Function::call(Stack&, c10::function_ref<void(const mobile::Code&)> f) {
|
||||
f(*code_);
|
||||
f(code_);
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Code> Function::get_code() const {
|
||||
const Code& Function::get_code() const {
|
||||
return code_;
|
||||
}
|
||||
|
||||
Code& Function::get_code() {
|
||||
return code_;
|
||||
}
|
||||
|
||||
|
@ -5,23 +5,22 @@
|
||||
#include <ATen/core/function.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <torch/csrc/jit/mobile/code.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
using Stack = std::vector<c10::IValue>;
|
||||
enum OpCode : uint8_t;
|
||||
struct Instruction;
|
||||
struct OperatorString;
|
||||
|
||||
namespace mobile {
|
||||
struct Code;
|
||||
|
||||
class TORCH_API Function : public torch::jit::Function {
|
||||
public:
|
||||
explicit Function(c10::QualifiedName name);
|
||||
Function(
|
||||
c10::QualifiedName name,
|
||||
std::shared_ptr<Code> code,
|
||||
Code code,
|
||||
at::optional<c10::FunctionSchema> schema);
|
||||
void run(Stack& stack) override;
|
||||
at::IValue operator()(Stack& stack);
|
||||
@ -48,7 +47,8 @@ class TORCH_API Function : public torch::jit::Function {
|
||||
void set_register_size(size_t size);
|
||||
|
||||
int64_t get_debug_handle(size_t pc) const;
|
||||
const std::shared_ptr<Code> get_code() const;
|
||||
const Code& get_code() const;
|
||||
Code& get_code();
|
||||
|
||||
torch::jit::Function& setSchema(c10::FunctionSchema schema) override;
|
||||
bool hasSchema() const;
|
||||
@ -67,7 +67,7 @@ class TORCH_API Function : public torch::jit::Function {
|
||||
|
||||
private:
|
||||
c10::QualifiedName name_;
|
||||
std::shared_ptr<Code> code_;
|
||||
Code code_;
|
||||
at::optional<c10::FunctionSchema> schema_; // (byte-code version 4+)
|
||||
};
|
||||
|
||||
|
@ -307,7 +307,7 @@ void BytecodeDeserializer::init_upgrader(mobile::Function* function) {
|
||||
// registerer size and etc), except operator. The operator function is also
|
||||
// static initialized and is available later. The oprator for the upgrader
|
||||
// function will be initialized when the first module is loaded.
|
||||
if (byteCodeFunctionWithOperator.function.get_code()->operators_.empty()) {
|
||||
if (byteCodeFunctionWithOperator.function.get_code().operators_.empty()) {
|
||||
for (const auto& op : byteCodeFunctionWithOperator.operators) {
|
||||
byteCodeFunctionWithOperator.function.append_operator(
|
||||
op.name,
|
||||
@ -669,12 +669,12 @@ std::set<std::string> _export_operator_list(
|
||||
std::set<std::string> operator_list;
|
||||
for (Method func : module.get_methods()) {
|
||||
const Function& function = func.function();
|
||||
const std::shared_ptr<Code> cptr = function.get_code();
|
||||
const auto& code = function.get_code();
|
||||
// op_names below isn't a list of unique operator names. In fact
|
||||
// it can contain the same operator name many many times, so we need
|
||||
// to de-dup the list by adding all the operator names into
|
||||
// an std::set<std::string>.
|
||||
std::vector<c10::OperatorName> const& op_names = cptr->op_names_;
|
||||
std::vector<c10::OperatorName> const& op_names = code.op_names_;
|
||||
for (auto& op_name : op_names) {
|
||||
operator_list.insert(toString(op_name));
|
||||
}
|
||||
|
@ -69,16 +69,16 @@ class OpCodeCache {
|
||||
} // namespace
|
||||
|
||||
void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
|
||||
std::shared_ptr<Code> code = function->get_code();
|
||||
const Code& code = function->get_code();
|
||||
auto& operator_version_map = getOperatorVersionMapForMobile();
|
||||
for (size_t i = 0; i < function->get_code()->instructions_.size(); i++) {
|
||||
Instruction& inst = function->get_code()->instructions_[i];
|
||||
for (size_t i = 0; i < function->get_code().instructions_.size(); i++) {
|
||||
Instruction& inst = function->get_code().instructions_[i];
|
||||
if (inst.op == OpCode::OP) {
|
||||
std::string op_name = function->get_code()->op_names_[inst.X].name;
|
||||
std::string operator_name = function->get_code()->op_names_[inst.X].name +
|
||||
(function->get_code()->op_names_[inst.X].overload_name.empty()
|
||||
std::string op_name = function->get_code().op_names_[inst.X].name;
|
||||
std::string operator_name = function->get_code().op_names_[inst.X].name +
|
||||
(function->get_code().op_names_[inst.X].overload_name.empty()
|
||||
? ""
|
||||
: "." + function->get_code()->op_names_[inst.X].overload_name);
|
||||
: "." + function->get_code().op_names_[inst.X].overload_name);
|
||||
|
||||
auto it = operator_version_map.find(operator_name);
|
||||
// Find out if there is an upgrader for this operator
|
||||
@ -92,13 +92,13 @@ void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
|
||||
if (operator_version <= upgrader.max_version &&
|
||||
operator_version >= upgrader.min_version) {
|
||||
auto func_name = function->get_code()
|
||||
->functions_[upgrader.index]
|
||||
.functions_[upgrader.index]
|
||||
->qualname()
|
||||
.qualifiedName();
|
||||
// If there exists a valid upgrader, change the instruction OP to
|
||||
// CALL, and the index will point to the according upgrader
|
||||
// function. All upgrader function are available in
|
||||
// function->get_code()->functions_. It's a vector of function
|
||||
// function->get_code().functions_. It's a vector of function
|
||||
// pointer and they are initialized in the same order as the global
|
||||
// vector kUpgraderBytecode.
|
||||
// Instruction new_inst = inst;
|
||||
|
@ -133,7 +133,7 @@ std::vector<std::unique_ptr<GraphFunction>> inlineFunctions(
|
||||
return inlined_functions;
|
||||
}
|
||||
|
||||
std::unique_ptr<mobile::Code> compileGraphToMobileCode(
|
||||
mobile::Code compileGraphToMobileCode(
|
||||
const std::string& name,
|
||||
const std::shared_ptr<Graph>& graph,
|
||||
const CompilationOptions& compilation_options,
|
||||
@ -144,9 +144,7 @@ std::unique_ptr<mobile::Code> compileGraphToMobileCode(
|
||||
compilation_options.enable_default_value_for_unspecified_arg,
|
||||
compilation_options.enable_default_args_before_out_args);
|
||||
|
||||
std::unique_ptr<mobile::Code> mobile_code_ptr =
|
||||
std::make_unique<mobile::Code>();
|
||||
mobile::Code& mobile_code = *mobile_code_ptr;
|
||||
mobile::Code mobile_code;
|
||||
|
||||
// operator names
|
||||
std::vector<std::string> method_names;
|
||||
@ -245,35 +243,35 @@ std::unique_ptr<mobile::Code> compileGraphToMobileCode(
|
||||
|
||||
mobile_code.types_ = code.type_table();
|
||||
mobile_code.register_size_ = code.register_size();
|
||||
return mobile_code_ptr;
|
||||
return mobile_code;
|
||||
}
|
||||
|
||||
std::unique_ptr<mobile::Function> convertJitFunctionToMobileFunction(
|
||||
const GraphFunction& function,
|
||||
const CompilationOptions& options) {
|
||||
BackendDebugInfoRecorder debug_handle;
|
||||
std::shared_ptr<mobile::Code> mobileCode = compileGraphToMobileCode(
|
||||
auto mobileCode = compileGraphToMobileCode(
|
||||
function.name(), function.graph(), options, debug_handle);
|
||||
const auto& schema = function.getSchema();
|
||||
return std::make_unique<mobile::Function>(
|
||||
function.qualname(), mobileCode, schema);
|
||||
function.qualname(), std::move(mobileCode), schema);
|
||||
}
|
||||
|
||||
IValue convertMobileFunctionToCodeTable(
|
||||
const mobile::Function& func,
|
||||
const CompilationOptions& compilation_options) {
|
||||
const std::shared_ptr<mobile::Code> code = func.get_code();
|
||||
auto code = func.get_code();
|
||||
std::vector<IValue> instructions;
|
||||
instructions.reserve(code->instructions_.size());
|
||||
for (Instruction ins : code->instructions_) {
|
||||
instructions.reserve(code.instructions_.size());
|
||||
for (Instruction ins : code.instructions_) {
|
||||
instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
|
||||
}
|
||||
|
||||
std::vector<IValue> operators;
|
||||
operators.reserve(code->op_names_.size());
|
||||
for (int i = 0; i < code->op_names_.size(); ++i) {
|
||||
const auto& opname = code->op_names_[i];
|
||||
const int size = code->operator_input_sizes_[i];
|
||||
operators.reserve(code.op_names_.size());
|
||||
for (int i = 0; i < code.op_names_.size(); ++i) {
|
||||
const auto& opname = code.op_names_[i];
|
||||
const int size = code.operator_input_sizes_[i];
|
||||
if (compilation_options.enable_default_value_for_unspecified_arg) {
|
||||
operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
|
||||
} else {
|
||||
@ -283,16 +281,16 @@ IValue convertMobileFunctionToCodeTable(
|
||||
}
|
||||
|
||||
std::vector<IValue> types;
|
||||
for (const TypePtr& t : code->types_) {
|
||||
for (const TypePtr& t : code.types_) {
|
||||
std::string type_str = t->annotation_str();
|
||||
types.emplace_back(type_str);
|
||||
}
|
||||
|
||||
auto register_size = static_cast<int>(code->register_size_);
|
||||
auto register_size = static_cast<int>(code.register_size_);
|
||||
auto codeTable = Table(
|
||||
{{"instructions", to_tuple(instructions)},
|
||||
{"operators", to_tuple(operators)},
|
||||
{"constants", to_tuple(code->constants_)},
|
||||
{"constants", to_tuple(code.constants_)},
|
||||
{"types", to_tuple(types)},
|
||||
{"register_size", register_size}});
|
||||
|
||||
@ -360,12 +358,12 @@ mobile::Module jitModuleToMobile(
|
||||
|
||||
for (const auto& func :
|
||||
inlineFunctions(methods_to_export, options.incl_interface_call)) {
|
||||
std::shared_ptr<mobile::Code> mobile_code_ptr = compileGraphToMobileCode(
|
||||
auto mobile_code = compileGraphToMobileCode(
|
||||
func->name(), func->graph(), options, debug_info_recorder);
|
||||
const auto& schema = func->getSchema();
|
||||
checkSchema(schema);
|
||||
auto mobile_func = std::make_unique<mobile::Function>(
|
||||
func->qualname(), mobile_code_ptr, schema);
|
||||
func->qualname(), std::move(mobile_code), schema);
|
||||
mcu->register_function(std::move(mobile_func));
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ TORCH_API mobile::Module jitModuleToMobile(
|
||||
const Module& module,
|
||||
const CompilationOptions& options);
|
||||
|
||||
std::unique_ptr<mobile::Code> compileGraphToMobileCode(
|
||||
mobile::Code compileGraphToMobileCode(
|
||||
const std::string& name,
|
||||
const std::shared_ptr<Graph>& graph,
|
||||
const CompilationOptions& compilation_options,
|
||||
|
@ -79,21 +79,21 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
const mobile::Function& func,
|
||||
BackendDebugInfoRecorder& debug_info_recorder,
|
||||
TypeNameUniquer& type_name_uniquer_) {
|
||||
const std::shared_ptr<mobile::Code> mobile_code_ptr = func.get_code();
|
||||
const auto& mobile_code = func.get_code();
|
||||
|
||||
// instructions
|
||||
std::vector<IValue> instructions;
|
||||
instructions.reserve(mobile_code_ptr->instructions_.size());
|
||||
for (Instruction ins : mobile_code_ptr->instructions_) {
|
||||
instructions.reserve(mobile_code.instructions_.size());
|
||||
for (Instruction ins : mobile_code.instructions_) {
|
||||
instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
|
||||
}
|
||||
|
||||
// operators
|
||||
std::vector<IValue> operators;
|
||||
operators.reserve(mobile_code_ptr->op_names_.size());
|
||||
for (int i = 0; i < mobile_code_ptr->op_names_.size(); ++i) {
|
||||
const auto& opname = mobile_code_ptr->op_names_[i];
|
||||
const int size = mobile_code_ptr->operator_input_sizes_[i];
|
||||
operators.reserve(mobile_code.op_names_.size());
|
||||
for (int i = 0; i < mobile_code.op_names_.size(); ++i) {
|
||||
const auto& opname = mobile_code.op_names_[i];
|
||||
const int size = mobile_code.operator_input_sizes_[i];
|
||||
if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
|
||||
operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
|
||||
} else {
|
||||
@ -104,11 +104,11 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
|
||||
// types
|
||||
std::vector<IValue> types;
|
||||
types.reserve(mobile_code_ptr->types_.size());
|
||||
types.reserve(mobile_code.types_.size());
|
||||
static const std::string torch_prefix("__torch__");
|
||||
static const std::string class_prefix("__torch__.torch.classes");
|
||||
|
||||
for (const TypePtr& t : mobile_code_ptr->types_) {
|
||||
for (const TypePtr& t : mobile_code.types_) {
|
||||
std::string type_str = t->annotation_str();
|
||||
if (t->kind() == TypeKind::TupleType) {
|
||||
TORCH_CHECK(
|
||||
@ -177,12 +177,12 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
|
||||
// since the register location is embedded into the bytecode, pass the
|
||||
// register size
|
||||
auto register_size = static_cast<int>(mobile_code_ptr->register_size_);
|
||||
auto register_size = static_cast<int>(mobile_code.register_size_);
|
||||
|
||||
auto codeTable = Table(
|
||||
{{"instructions", to_tuple(instructions)},
|
||||
{"operators", to_tuple(operators)},
|
||||
{"constants", to_tuple(mobile_code_ptr->constants_)},
|
||||
{"constants", to_tuple(mobile_code.constants_)},
|
||||
{"types", to_tuple(types)},
|
||||
{"register_size", register_size}});
|
||||
|
||||
@ -251,7 +251,7 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
// will correspond to {source_range, inlinedCallStackPtr} which we will
|
||||
// serialize separately.
|
||||
IValue module_debug_tuple =
|
||||
c10::ivalue::Tuple::create(mobile_code_ptr->debug_handles_);
|
||||
c10::ivalue::Tuple::create(mobile_code.debug_handles_);
|
||||
auto function_debug_info =
|
||||
Table({{"function_debug_handles", module_debug_tuple}});
|
||||
debug_info_vals = to_tuple({qn, function_debug_info});
|
||||
@ -805,7 +805,7 @@ namespace {
|
||||
void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
|
||||
mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
|
||||
for (const auto& method : mobile_m.get_methods()) {
|
||||
for (const auto& op : method.function().get_code()->op_names_) {
|
||||
for (const auto& op : method.function().get_code().op_names_) {
|
||||
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
|
||||
opnames.emplace(
|
||||
op.overload_name.empty() ? op.name
|
||||
|
Reference in New Issue
Block a user