[jit][edge] Load interface methods to corresponding ClassTypes. (#65971)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65971

ghstack-source-id: 141842335

We should be able to load methods into their ClassTypes. Right now mobile runtime only loads data member to ClassTypes but not for methods. To support interface call, we inject methods into ClassTypes when the methods are loaded.

Test Plan: existing tests should all pass.

Reviewed By: qihqi

Differential Revision: D31326146

fbshipit-source-id: fb1dbea619910ef1f8fa26146da3ebab348fe902
This commit is contained in:
Zhengxu Chen
2021-10-29 12:47:43 -07:00
committed by Facebook GitHub Bot
parent 6259601c8a
commit d6b15bfcbd
2 changed files with 19 additions and 2 deletions

View File

@ -3,6 +3,8 @@
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/inline_container.h>
@ -171,6 +173,19 @@ bool isTensorInBytecodeArchive(
namespace {
void tryRegisterMethod(const std::vector<c10::Argument>& args, Function& func) {
if (args.empty() || args[0].name() != "self") {
return;
}
if (auto cls = args[0].type()->castRaw<ClassType>()) {
if (C10_UNLIKELY(cls->findMethod(func.name()))) {
return;
}
cls->addMethod(&func);
}
}
// The deserializer class which loads the bytecode package from bc files.
class BytecodeDeserializer final {
public:
@ -227,7 +242,8 @@ void BytecodeDeserializer::parseFunctionSchema(
mobile::Function* function) {
// function schema
if (schemaTable) { // (schema is optional for back compat)
auto parseArgList = [this](c10::ivalue::TupleElements&& argTables) {
auto parseArgList = [this,
function](c10::ivalue::TupleElements&& argTables) {
std::vector<c10::Argument> args;
for (auto&& argTable : std::move(argTables)) {
auto argTableElements =
@ -249,6 +265,7 @@ void BytecodeDeserializer::parseFunctionSchema(
c10::nullopt /*N*/,
std::move(default_value));
}
tryRegisterMethod(args, *function);
return args;
};
auto schemaTableElements =

View File

@ -36,7 +36,7 @@ Method Module::get_method(const std::string& name) const {
}
c10::optional<Method> Module::find_method(const std::string& basename) const {
for (auto& fn : cu_->methods()) {
for (const auto& fn : cu_->methods()) {
if (fn->name() == basename) {
return c10::make_optional<Method>(Method(this, fn.get()));
}