[jit][edge] Enable CALL instruction in lite interpreter. (#65964)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65964

ghstack-source-id: 141425519

Test Plan: buck run xplat/caffe2:test_lite_interpreter

Reviewed By: cccclai

Differential Revision: D31326149

fbshipit-source-id: 8a599d92f3fa4e6c125100adb36d89592e71e547
This commit is contained in:
Zhengxu Chen
2021-10-25 14:43:08 -07:00
committed by Facebook GitHub Bot
parent b8dfb45ac2
commit 12daa4f663
6 changed files with 105 additions and 1 deletions

View File

@ -9,6 +9,7 @@
#include <torch/csrc/jit/mobile/backport.h>
#include <torch/csrc/jit/mobile/backport_manager.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/model_compatibility.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
@ -1371,5 +1372,91 @@ TEST(LiteInterpreterTest, OperatorCacheDifferentiatesDefaultArgs) {
testLiteModuleCompareResultTensors(m, inputs, "forward3");
}
TEST(RunTimeTest, RuntimeCall) {
// def call(x):
// return x + x
//
// def forward(a):
// x = a + call(a)
// y = a + call(x)
// return y
std::vector<IValue> instructionsCall{
to_tuple({"STORE", 1, 0}),
to_tuple({"LOAD", 1, 0}),
to_tuple({"MOVE", 1, 0}),
to_tuple({"LOADC", 0, 0}),
to_tuple({"OP", 0, 0}),
to_tuple({"RET", 0, 0}),
};
std::vector<IValue> instructionsFoo{
to_tuple({"STORE", 1, 0}),
to_tuple({"LOAD", 1, 0}),
to_tuple({"LOAD", 1, 0}),
to_tuple({"MOVE", 1, 0}),
to_tuple({"CALL", 0, 0}),
to_tuple({"LOADC", 0, 0}),
to_tuple({"OP", 0, 0}),
to_tuple({"CALL", 0, 0}),
to_tuple({"LOADC", 0, 0}),
to_tuple({"OP", 0, 0}),
to_tuple({"RET", 0, 0}),
};
std::vector<IValue> operatorsFoo{
to_tuple({"aten::add", "Tensor", 3}),
};
std::vector<IValue> constantsFoo{
1,
};
std::vector<IValue> operatorsCall{
to_tuple({"aten::add", "Tensor", 3}),
};
std::vector<IValue> constantsCall{
1,
};
int64_t model_version = caffe2::serialize::kProducedBytecodeVersion;
auto foo = std::make_unique<mobile::Function>(c10::QualifiedName("foo"));
c10::ivalue::TupleElements debug_handles_m_tuple;
parseInstructions(
"foo",
std::move(*c10::ivalue::Tuple::create(instructionsFoo)).elements(),
debug_handles_m_tuple,
foo.get());
parseOperators(
std::move(*c10::ivalue::Tuple::create(operatorsFoo)).elements(),
model_version,
1,
foo.get());
parseConstants(
std::move(*c10::ivalue::Tuple::create(constantsFoo)).elements(),
foo.get());
const size_t rsize = 5;
parseRegisterSize(rsize, foo.get());
auto call = std::make_unique<mobile::Function>(c10::QualifiedName("call"));
parseInstructions(
"call",
std::move(*c10::ivalue::Tuple::create(instructionsCall)).elements(),
debug_handles_m_tuple,
call.get());
parseOperators(
std::move(*c10::ivalue::Tuple::create(operatorsCall)).elements(),
model_version,
1,
call.get());
parseConstants(
std::move(*c10::ivalue::Tuple::create(constantsCall)).elements(),
call.get());
parseRegisterSize(rsize, call.get());
foo->append_function(*call);
std::vector<IValue> inputs{at::tensor(1)};
foo->run(inputs);
auto output = inputs[0];
ASSERT_EQ(output, at::tensor(7));
}
} // namespace jit
} // namespace torch

View File

@ -13,6 +13,8 @@ namespace mobile {
using Stack = std::vector<c10::IValue>;
using DebugHandle = int64_t;
class Function;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct Code {
std::vector<Instruction> instructions_;
@ -21,6 +23,11 @@ struct Code {
std::vector<std::function<void(Stack&)>> operators_;
std::vector<c10::IValue> constants_;
std::vector<c10::TypePtr> types_;
// TODO After we actually export CALL instructions we can remove this.
// We may need a two-stage importing scheme, where we firstly construct all
// function objects, and then append referenced function pointers. This could
// be done in parseMethods().
std::vector<mobile::Function*> functions_;
size_t register_size_; // Aggregated output size.
};

View File

@ -141,6 +141,10 @@ void Function::append_type(const at::TypePtr& type) {
code_->types_.push_back(type);
}
void Function::append_function(mobile::Function& function) {
code_->functions_.push_back(&function);
}
void Function::set_register_size(size_t size) {
code_->register_size_ = size;
}

View File

@ -29,6 +29,7 @@ class Function {
are removed */
void append_constant(const c10::IValue& constant);
void append_type(const c10::TypePtr& type);
TORCH_API void append_function(mobile::Function& func);
void set_register_size(size_t size);

View File

@ -124,6 +124,11 @@ bool InterpreterState::run(Stack& stack) {
code.operators_[inst.X](stack);
frame.step();
} break;
case CALL: {
auto& function = frame.getCode().functions_.at(inst.X);
frame.step();
enterFrame(*function->get_code());
} break;
case INTERFACE_CALL: {
torch::jit::Function& method =
peek(stack, 0, inst.N)

View File

@ -79,7 +79,7 @@ bool isOpSupportedInMobile(OpCode op) {
OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, JMP, LOOP,
RET, GET_ATTR, SET_ATTR, LIST_CONSTRUCT, TUPLE_CONSTRUCT, WARN,
INTERFACE_CALL, LIST_UNPACK, TUPLE_SLICE, DICT_CONSTRUCT,
NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE
NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE, CALL
};
// clang-format on