mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[jit][edge] Enable CALL instruction in lite interpreter. (#65964)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65964 ghstack-source-id: 141425519 Test Plan: buck run xplat/caffe2:test_lite_interpreter Reviewed By: cccclai Differential Revision: D31326149 fbshipit-source-id: 8a599d92f3fa4e6c125100adb36d89592e71e547
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b8dfb45ac2
commit
12daa4f663
@ -9,6 +9,7 @@
|
||||
#include <torch/csrc/jit/mobile/backport.h>
|
||||
#include <torch/csrc/jit/mobile/backport_manager.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||
#include <torch/csrc/jit/mobile/model_compatibility.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/mobile/parse_bytecode.h>
|
||||
@ -1371,5 +1372,91 @@ TEST(LiteInterpreterTest, OperatorCacheDifferentiatesDefaultArgs) {
|
||||
testLiteModuleCompareResultTensors(m, inputs, "forward3");
|
||||
}
|
||||
|
||||
TEST(RunTimeTest, RuntimeCall) {
|
||||
// def call(x):
|
||||
// return x + x
|
||||
//
|
||||
// def forward(a):
|
||||
// x = a + call(a)
|
||||
// y = a + call(x)
|
||||
// return y
|
||||
|
||||
std::vector<IValue> instructionsCall{
|
||||
to_tuple({"STORE", 1, 0}),
|
||||
to_tuple({"LOAD", 1, 0}),
|
||||
to_tuple({"MOVE", 1, 0}),
|
||||
to_tuple({"LOADC", 0, 0}),
|
||||
to_tuple({"OP", 0, 0}),
|
||||
to_tuple({"RET", 0, 0}),
|
||||
};
|
||||
std::vector<IValue> instructionsFoo{
|
||||
to_tuple({"STORE", 1, 0}),
|
||||
to_tuple({"LOAD", 1, 0}),
|
||||
to_tuple({"LOAD", 1, 0}),
|
||||
to_tuple({"MOVE", 1, 0}),
|
||||
to_tuple({"CALL", 0, 0}),
|
||||
to_tuple({"LOADC", 0, 0}),
|
||||
to_tuple({"OP", 0, 0}),
|
||||
to_tuple({"CALL", 0, 0}),
|
||||
to_tuple({"LOADC", 0, 0}),
|
||||
to_tuple({"OP", 0, 0}),
|
||||
to_tuple({"RET", 0, 0}),
|
||||
};
|
||||
std::vector<IValue> operatorsFoo{
|
||||
to_tuple({"aten::add", "Tensor", 3}),
|
||||
};
|
||||
std::vector<IValue> constantsFoo{
|
||||
1,
|
||||
};
|
||||
std::vector<IValue> operatorsCall{
|
||||
to_tuple({"aten::add", "Tensor", 3}),
|
||||
};
|
||||
std::vector<IValue> constantsCall{
|
||||
1,
|
||||
};
|
||||
int64_t model_version = caffe2::serialize::kProducedBytecodeVersion;
|
||||
|
||||
auto foo = std::make_unique<mobile::Function>(c10::QualifiedName("foo"));
|
||||
c10::ivalue::TupleElements debug_handles_m_tuple;
|
||||
parseInstructions(
|
||||
"foo",
|
||||
std::move(*c10::ivalue::Tuple::create(instructionsFoo)).elements(),
|
||||
debug_handles_m_tuple,
|
||||
foo.get());
|
||||
parseOperators(
|
||||
std::move(*c10::ivalue::Tuple::create(operatorsFoo)).elements(),
|
||||
model_version,
|
||||
1,
|
||||
foo.get());
|
||||
parseConstants(
|
||||
std::move(*c10::ivalue::Tuple::create(constantsFoo)).elements(),
|
||||
foo.get());
|
||||
const size_t rsize = 5;
|
||||
parseRegisterSize(rsize, foo.get());
|
||||
|
||||
auto call = std::make_unique<mobile::Function>(c10::QualifiedName("call"));
|
||||
parseInstructions(
|
||||
"call",
|
||||
std::move(*c10::ivalue::Tuple::create(instructionsCall)).elements(),
|
||||
debug_handles_m_tuple,
|
||||
call.get());
|
||||
parseOperators(
|
||||
std::move(*c10::ivalue::Tuple::create(operatorsCall)).elements(),
|
||||
model_version,
|
||||
1,
|
||||
call.get());
|
||||
parseConstants(
|
||||
std::move(*c10::ivalue::Tuple::create(constantsCall)).elements(),
|
||||
call.get());
|
||||
parseRegisterSize(rsize, call.get());
|
||||
|
||||
foo->append_function(*call);
|
||||
|
||||
std::vector<IValue> inputs{at::tensor(1)};
|
||||
foo->run(inputs);
|
||||
auto output = inputs[0];
|
||||
ASSERT_EQ(output, at::tensor(7));
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -13,6 +13,8 @@ namespace mobile {
|
||||
using Stack = std::vector<c10::IValue>;
|
||||
using DebugHandle = int64_t;
|
||||
|
||||
class Function;
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
struct Code {
|
||||
std::vector<Instruction> instructions_;
|
||||
@ -21,6 +23,11 @@ struct Code {
|
||||
std::vector<std::function<void(Stack&)>> operators_;
|
||||
std::vector<c10::IValue> constants_;
|
||||
std::vector<c10::TypePtr> types_;
|
||||
// TODO After we actually export CALL instructions we can remove this.
|
||||
// We may need a two-stage importing scheme, where we firstly construct all
|
||||
// function objects, and then append referenced function pointers. This could
|
||||
// be done in parseMethods().
|
||||
std::vector<mobile::Function*> functions_;
|
||||
size_t register_size_; // Aggregated output size.
|
||||
};
|
||||
|
||||
|
@ -141,6 +141,10 @@ void Function::append_type(const at::TypePtr& type) {
|
||||
code_->types_.push_back(type);
|
||||
}
|
||||
|
||||
void Function::append_function(mobile::Function& function) {
|
||||
code_->functions_.push_back(&function);
|
||||
}
|
||||
|
||||
void Function::set_register_size(size_t size) {
|
||||
code_->register_size_ = size;
|
||||
}
|
||||
|
@ -29,6 +29,7 @@ class Function {
|
||||
are removed */
|
||||
void append_constant(const c10::IValue& constant);
|
||||
void append_type(const c10::TypePtr& type);
|
||||
TORCH_API void append_function(mobile::Function& func);
|
||||
|
||||
void set_register_size(size_t size);
|
||||
|
||||
|
@ -124,6 +124,11 @@ bool InterpreterState::run(Stack& stack) {
|
||||
code.operators_[inst.X](stack);
|
||||
frame.step();
|
||||
} break;
|
||||
case CALL: {
|
||||
auto& function = frame.getCode().functions_.at(inst.X);
|
||||
frame.step();
|
||||
enterFrame(*function->get_code());
|
||||
} break;
|
||||
case INTERFACE_CALL: {
|
||||
torch::jit::Function& method =
|
||||
peek(stack, 0, inst.N)
|
||||
|
@ -79,7 +79,7 @@ bool isOpSupportedInMobile(OpCode op) {
|
||||
OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, JMP, LOOP,
|
||||
RET, GET_ATTR, SET_ATTR, LIST_CONSTRUCT, TUPLE_CONSTRUCT, WARN,
|
||||
INTERFACE_CALL, LIST_UNPACK, TUPLE_SLICE, DICT_CONSTRUCT,
|
||||
NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE
|
||||
NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE, CALL
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
|
Reference in New Issue
Block a user