[jit][edge] Remove usage of shared_ptr<mobile::Code>. (#68037)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68037

Right now mobile::Code doesn't outlive its enclosing Function, and all accesses to Code happen inside the interpreter loop, which doesn't outlive the module, so we don't need to use std::shared_ptr here. This should also save us 1-2 KB of binary size, because shared_ptr seems to bloat on arm64 android.
ghstack-source-id: 145818696

Test Plan: eyes.

Reviewed By: qihqi, tugsbayasgalan

Differential Revision: D32264616

fbshipit-source-id: d83f538d6604cf75fd7728a25127b4849ce7ab2a
This commit is contained in:
Zhengxu Chen
2021-12-16 13:06:08 -08:00
committed by Facebook GitHub Bot
parent 39f65fee47
commit d459e79500
9 changed files with 88 additions and 88 deletions

View File

@ -1517,8 +1517,8 @@ TEST(LiteInterpreterTest, OperatorSize1) {
mobile::Module bc = _load_for_mobile(ss);
const auto& func = bc.get_method("forward").function();
ASSERT_EQ(
func.get_code()->operator_input_sizes_.size(),
func.get_code()->operators_.size());
func.get_code().operator_input_sizes_.size(),
func.get_code().operators_.size());
}
TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
@ -1552,8 +1552,8 @@ TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
mobile::Module bc = _load_for_mobile(ss);
const auto& func = bc.get_method("test_func").function();
ASSERT_EQ(
func.get_code()->operator_input_sizes_.size(),
func.get_code()->operators_.size());
func.get_code().operator_input_sizes_.size(),
func.get_code().operators_.size());
}
}
@ -1590,7 +1590,7 @@ TEST(LiteInterpreterUpgraderTest, DivTensorV2) {
*/
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1629,7 +1629,7 @@ TEST(LiteInterpreterUpgraderTest, DivTensorOutV2) {
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1670,7 +1670,7 @@ TEST(LiteInterpreterUpgraderTest, DivTensorInplaceV2) {
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1710,7 +1710,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarFloatV2) {
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1750,7 +1750,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalFloatV2) {
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1791,7 +1791,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalIntV2) {
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1843,7 +1843,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarScalarV2) {
*/
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1884,7 +1884,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarIntV2) {
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1924,7 +1924,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarInplaceFloatV2) {
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1964,7 +1964,7 @@ TEST(LiteInterpreterUpgraderTest, DivScalarInplaceIntV2) {
mobile::Module m_module = _load_for_mobile(test_model_file);
auto intrsuction_list =
m_module.get_method("forward").function().get_code()->instructions_;
m_module.get_method("forward").function().get_code().instructions_;
uint64_t number_of_call_instruction = 0;
for (auto& instruction : intrsuction_list) {
number_of_call_instruction += (instruction.op == OpCode::CALL);
@ -1988,9 +1988,9 @@ TEST(LiteInterpreterUpgraderTest, Upgrader) {
for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
ASSERT_EQ(
byteCodeFunctionWithOperator.function.get_code()->operators_.size(),
byteCodeFunctionWithOperator.function.get_code()->op_names_.size());
if (byteCodeFunctionWithOperator.function.get_code()->operators_.empty()) {
byteCodeFunctionWithOperator.function.get_code().operators_.size(),
byteCodeFunctionWithOperator.function.get_code().op_names_.size());
if (byteCodeFunctionWithOperator.function.get_code().operators_.empty()) {
for (const auto& op : byteCodeFunctionWithOperator.operators) {
byteCodeFunctionWithOperator.function.append_operator(
op.name,

View File

@ -29,7 +29,7 @@ struct Code {
// function objects, and then append referenced function pointers. This could
// be done in parseMethods().
std::vector<mobile::Function*> functions_;
size_t register_size_; // Aggregated output size.
size_t register_size_ = 0; // Aggregated output size.
};
} // namespace mobile

View File

@ -13,12 +13,11 @@ namespace jit {
char const* toString(OpCode op);
namespace mobile {
Function::Function(c10::QualifiedName name)
: name_(std::move(name)), code_(std::make_shared<Code>()) {}
Function::Function(c10::QualifiedName name) : name_(std::move(name)) {}
Function::Function(
c10::QualifiedName name,
std::shared_ptr<Code> code,
Code code,
at::optional<c10::FunctionSchema> schema)
: name_(std::move(name)),
code_(std::move(code)),
@ -33,8 +32,8 @@ void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
isOpSupportedInMobile(op),
toString(op),
" is not supported in mobile module.");
code_->instructions_.emplace_back(op, X, N);
code_->debug_handles_.emplace_back(dbg_handle);
code_.instructions_.emplace_back(op, X, N);
code_.debug_handles_.emplace_back(dbg_handle);
}
void Function::append_instruction(OpCode op, int X, int N) {
@ -42,7 +41,7 @@ void Function::append_instruction(OpCode op, int X, int N) {
isOpSupportedInMobile(op),
toString(op),
" is not supported in mobile module.");
code_->instructions_.emplace_back(op, X, N);
code_.instructions_.emplace_back(op, X, N);
}
bool Function::append_operator(
@ -52,39 +51,38 @@ bool Function::append_operator(
int64_t model_version) { /* TODO: T90339189 deprecate all v3 when v3 models
are removed */
// Keep the original opname in code_
code_->op_names_.emplace_back(name, overload_name);
const auto& opname = code_->op_names_.back();
code_->operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
code_.op_names_.emplace_back(name, overload_name);
const auto& opname = code_.op_names_.back();
code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
auto func = makeOperatorFunction(opname, num_specified_args, model_version);
if (!func.has_value()) {
return false;
}
code_->operators_.emplace_back(*func);
code_.operators_.emplace_back(*func);
return true;
}
void Function::append_constant(const c10::IValue& constant) {
code_->constants_.push_back(constant);
code_.constants_.push_back(constant);
}
void Function::append_type(const at::TypePtr& type) {
code_->types_.push_back(type);
code_.types_.push_back(type);
}
void Function::append_function(mobile::Function& function) {
code_->functions_.push_back(&function);
code_.functions_.push_back(&function);
}
void Function::set_register_size(size_t size) {
code_->register_size_ = size;
code_.register_size_ = size;
}
int64_t Function::get_debug_handle(size_t pc) const {
TORCH_CHECK(code_, "Valid code must exist.");
TORCH_CHECK(
pc < code_->debug_handles_.size(),
pc < code_.debug_handles_.size(),
"Module debug info index out of boundary.");
return code_->debug_handles_[pc];
return code_.debug_handles_[pc];
}
torch::jit::Function& Function::setSchema(c10::FunctionSchema schema) {
@ -105,7 +103,7 @@ void Function::run(Stack& stack) {
getSchema().checkAndNormalizeInputs(
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
}
InterpreterState interp_state(*code_);
InterpreterState interp_state(code_);
interp_state.run(stack);
}
@ -119,11 +117,15 @@ size_t Function::num_inputs() const {
}
bool Function::call(Stack&, c10::function_ref<void(const mobile::Code&)> f) {
f(*code_);
f(code_);
return true;
}
const std::shared_ptr<Code> Function::get_code() const {
const Code& Function::get_code() const {
return code_;
}
Code& Function::get_code() {
return code_;
}

View File

@ -5,23 +5,22 @@
#include <ATen/core/function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/code.h>
namespace torch {
namespace jit {
using Stack = std::vector<c10::IValue>;
enum OpCode : uint8_t;
struct Instruction;
struct OperatorString;
namespace mobile {
struct Code;
class TORCH_API Function : public torch::jit::Function {
public:
explicit Function(c10::QualifiedName name);
Function(
c10::QualifiedName name,
std::shared_ptr<Code> code,
Code code,
at::optional<c10::FunctionSchema> schema);
void run(Stack& stack) override;
at::IValue operator()(Stack& stack);
@ -48,7 +47,8 @@ class TORCH_API Function : public torch::jit::Function {
void set_register_size(size_t size);
int64_t get_debug_handle(size_t pc) const;
const std::shared_ptr<Code> get_code() const;
const Code& get_code() const;
Code& get_code();
torch::jit::Function& setSchema(c10::FunctionSchema schema) override;
bool hasSchema() const;
@ -67,7 +67,7 @@ class TORCH_API Function : public torch::jit::Function {
private:
c10::QualifiedName name_;
std::shared_ptr<Code> code_;
Code code_;
at::optional<c10::FunctionSchema> schema_; // (byte-code version 4+)
};

View File

@ -307,7 +307,7 @@ void BytecodeDeserializer::init_upgrader(mobile::Function* function) {
// registerer size and etc), except operator. The operator function is also
// static initialized and is available later. The oprator for the upgrader
// function will be initialized when the first module is loaded.
if (byteCodeFunctionWithOperator.function.get_code()->operators_.empty()) {
if (byteCodeFunctionWithOperator.function.get_code().operators_.empty()) {
for (const auto& op : byteCodeFunctionWithOperator.operators) {
byteCodeFunctionWithOperator.function.append_operator(
op.name,
@ -669,12 +669,12 @@ std::set<std::string> _export_operator_list(
std::set<std::string> operator_list;
for (Method func : module.get_methods()) {
const Function& function = func.function();
const std::shared_ptr<Code> cptr = function.get_code();
const auto& code = function.get_code();
// op_names below isn't a list of unique operator names. In fact
// it can contain the same operator name many many times, so we need
// to de-dup the list by adding all the operator names into
// an std::set<std::string>.
std::vector<c10::OperatorName> const& op_names = cptr->op_names_;
std::vector<c10::OperatorName> const& op_names = code.op_names_;
for (auto& op_name : op_names) {
operator_list.insert(toString(op_name));
}

View File

@ -69,16 +69,16 @@ class OpCodeCache {
} // namespace
void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
std::shared_ptr<Code> code = function->get_code();
const Code& code = function->get_code();
auto& operator_version_map = getOperatorVersionMapForMobile();
for (size_t i = 0; i < function->get_code()->instructions_.size(); i++) {
Instruction& inst = function->get_code()->instructions_[i];
for (size_t i = 0; i < function->get_code().instructions_.size(); i++) {
Instruction& inst = function->get_code().instructions_[i];
if (inst.op == OpCode::OP) {
std::string op_name = function->get_code()->op_names_[inst.X].name;
std::string operator_name = function->get_code()->op_names_[inst.X].name +
(function->get_code()->op_names_[inst.X].overload_name.empty()
std::string op_name = function->get_code().op_names_[inst.X].name;
std::string operator_name = function->get_code().op_names_[inst.X].name +
(function->get_code().op_names_[inst.X].overload_name.empty()
? ""
: "." + function->get_code()->op_names_[inst.X].overload_name);
: "." + function->get_code().op_names_[inst.X].overload_name);
auto it = operator_version_map.find(operator_name);
// Find out if there is an upgrader for this operator
@ -92,13 +92,13 @@ void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
if (operator_version <= upgrader.max_version &&
operator_version >= upgrader.min_version) {
auto func_name = function->get_code()
->functions_[upgrader.index]
.functions_[upgrader.index]
->qualname()
.qualifiedName();
// If there exists a valid upgrader, change the instruction OP to
// CALL, and the index will point to the according upgrader
// function. All upgrader function are available in
// function->get_code()->functions_. It's a vector of function
// function->get_code().functions_. It's a vector of function
// pointer and they are initialized in the same order as the global
// vector kUpgraderBytecode.
// Instruction new_inst = inst;

View File

@ -133,7 +133,7 @@ std::vector<std::unique_ptr<GraphFunction>> inlineFunctions(
return inlined_functions;
}
std::unique_ptr<mobile::Code> compileGraphToMobileCode(
mobile::Code compileGraphToMobileCode(
const std::string& name,
const std::shared_ptr<Graph>& graph,
const CompilationOptions& compilation_options,
@ -144,9 +144,7 @@ std::unique_ptr<mobile::Code> compileGraphToMobileCode(
compilation_options.enable_default_value_for_unspecified_arg,
compilation_options.enable_default_args_before_out_args);
std::unique_ptr<mobile::Code> mobile_code_ptr =
std::make_unique<mobile::Code>();
mobile::Code& mobile_code = *mobile_code_ptr;
mobile::Code mobile_code;
// operator names
std::vector<std::string> method_names;
@ -245,35 +243,35 @@ std::unique_ptr<mobile::Code> compileGraphToMobileCode(
mobile_code.types_ = code.type_table();
mobile_code.register_size_ = code.register_size();
return mobile_code_ptr;
return mobile_code;
}
std::unique_ptr<mobile::Function> convertJitFunctionToMobileFunction(
const GraphFunction& function,
const CompilationOptions& options) {
BackendDebugInfoRecorder debug_handle;
std::shared_ptr<mobile::Code> mobileCode = compileGraphToMobileCode(
auto mobileCode = compileGraphToMobileCode(
function.name(), function.graph(), options, debug_handle);
const auto& schema = function.getSchema();
return std::make_unique<mobile::Function>(
function.qualname(), mobileCode, schema);
function.qualname(), std::move(mobileCode), schema);
}
IValue convertMobileFunctionToCodeTable(
const mobile::Function& func,
const CompilationOptions& compilation_options) {
const std::shared_ptr<mobile::Code> code = func.get_code();
auto code = func.get_code();
std::vector<IValue> instructions;
instructions.reserve(code->instructions_.size());
for (Instruction ins : code->instructions_) {
instructions.reserve(code.instructions_.size());
for (Instruction ins : code.instructions_) {
instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
}
std::vector<IValue> operators;
operators.reserve(code->op_names_.size());
for (int i = 0; i < code->op_names_.size(); ++i) {
const auto& opname = code->op_names_[i];
const int size = code->operator_input_sizes_[i];
operators.reserve(code.op_names_.size());
for (int i = 0; i < code.op_names_.size(); ++i) {
const auto& opname = code.op_names_[i];
const int size = code.operator_input_sizes_[i];
if (compilation_options.enable_default_value_for_unspecified_arg) {
operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
} else {
@ -283,16 +281,16 @@ IValue convertMobileFunctionToCodeTable(
}
std::vector<IValue> types;
for (const TypePtr& t : code->types_) {
for (const TypePtr& t : code.types_) {
std::string type_str = t->annotation_str();
types.emplace_back(type_str);
}
auto register_size = static_cast<int>(code->register_size_);
auto register_size = static_cast<int>(code.register_size_);
auto codeTable = Table(
{{"instructions", to_tuple(instructions)},
{"operators", to_tuple(operators)},
{"constants", to_tuple(code->constants_)},
{"constants", to_tuple(code.constants_)},
{"types", to_tuple(types)},
{"register_size", register_size}});
@ -360,12 +358,12 @@ mobile::Module jitModuleToMobile(
for (const auto& func :
inlineFunctions(methods_to_export, options.incl_interface_call)) {
std::shared_ptr<mobile::Code> mobile_code_ptr = compileGraphToMobileCode(
auto mobile_code = compileGraphToMobileCode(
func->name(), func->graph(), options, debug_info_recorder);
const auto& schema = func->getSchema();
checkSchema(schema);
auto mobile_func = std::make_unique<mobile::Function>(
func->qualname(), mobile_code_ptr, schema);
func->qualname(), std::move(mobile_code), schema);
mcu->register_function(std::move(mobile_func));
}

View File

@ -27,7 +27,7 @@ TORCH_API mobile::Module jitModuleToMobile(
const Module& module,
const CompilationOptions& options);
std::unique_ptr<mobile::Code> compileGraphToMobileCode(
mobile::Code compileGraphToMobileCode(
const std::string& name,
const std::shared_ptr<Graph>& graph,
const CompilationOptions& compilation_options,

View File

@ -79,21 +79,21 @@ std::pair<IValue, IValue> getFunctionTuple(
const mobile::Function& func,
BackendDebugInfoRecorder& debug_info_recorder,
TypeNameUniquer& type_name_uniquer_) {
const std::shared_ptr<mobile::Code> mobile_code_ptr = func.get_code();
const auto& mobile_code = func.get_code();
// instructions
std::vector<IValue> instructions;
instructions.reserve(mobile_code_ptr->instructions_.size());
for (Instruction ins : mobile_code_ptr->instructions_) {
instructions.reserve(mobile_code.instructions_.size());
for (Instruction ins : mobile_code.instructions_) {
instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
}
// operators
std::vector<IValue> operators;
operators.reserve(mobile_code_ptr->op_names_.size());
for (int i = 0; i < mobile_code_ptr->op_names_.size(); ++i) {
const auto& opname = mobile_code_ptr->op_names_[i];
const int size = mobile_code_ptr->operator_input_sizes_[i];
operators.reserve(mobile_code.op_names_.size());
for (int i = 0; i < mobile_code.op_names_.size(); ++i) {
const auto& opname = mobile_code.op_names_[i];
const int size = mobile_code.operator_input_sizes_[i];
if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
} else {
@ -104,11 +104,11 @@ std::pair<IValue, IValue> getFunctionTuple(
// types
std::vector<IValue> types;
types.reserve(mobile_code_ptr->types_.size());
types.reserve(mobile_code.types_.size());
static const std::string torch_prefix("__torch__");
static const std::string class_prefix("__torch__.torch.classes");
for (const TypePtr& t : mobile_code_ptr->types_) {
for (const TypePtr& t : mobile_code.types_) {
std::string type_str = t->annotation_str();
if (t->kind() == TypeKind::TupleType) {
TORCH_CHECK(
@ -177,12 +177,12 @@ std::pair<IValue, IValue> getFunctionTuple(
// since the register location is embedded into the bytecode, pass the
// register size
auto register_size = static_cast<int>(mobile_code_ptr->register_size_);
auto register_size = static_cast<int>(mobile_code.register_size_);
auto codeTable = Table(
{{"instructions", to_tuple(instructions)},
{"operators", to_tuple(operators)},
{"constants", to_tuple(mobile_code_ptr->constants_)},
{"constants", to_tuple(mobile_code.constants_)},
{"types", to_tuple(types)},
{"register_size", register_size}});
@ -251,7 +251,7 @@ std::pair<IValue, IValue> getFunctionTuple(
// will correspond to {source_range, inlinedCallStackPtr} which we will
// serialize separately.
IValue module_debug_tuple =
c10::ivalue::Tuple::create(mobile_code_ptr->debug_handles_);
c10::ivalue::Tuple::create(mobile_code.debug_handles_);
auto function_debug_info =
Table({{"function_debug_handles", module_debug_tuple}});
debug_info_vals = to_tuple({qn, function_debug_info});
@ -805,7 +805,7 @@ namespace {
void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
for (const auto& method : mobile_m.get_methods()) {
for (const auto& op : method.function().get_code()->op_names_) {
for (const auto& op : method.function().get_code().op_names_) {
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
opnames.emplace(
op.overload_name.empty() ? op.name