mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit][edge] Load interface methods to corresponding ClassTypes. (#65971)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65971 ghstack-source-id: 141842335 We should be able to load methods into their ClassTypes. Right now mobile runtime only loads data member to ClassTypes but not for methods. To support interface call, we inject methods into ClassTypes when the methods are loaded. Test Plan: existing tests should all pass. Reviewed By: qihqi Differential Revision: D31326146 fbshipit-source-id: fb1dbea619910ef1f8fa26146da3ebab348fe902
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6259601c8a
commit
d6b15bfcbd
@ -3,6 +3,8 @@
|
||||
#include <torch/csrc/jit/mobile/parse_operators.h>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/ScopeExit.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <caffe2/serialize/inline_container.h>
|
||||
@ -171,6 +173,19 @@ bool isTensorInBytecodeArchive(
|
||||
|
||||
namespace {
|
||||
|
||||
void tryRegisterMethod(const std::vector<c10::Argument>& args, Function& func) {
|
||||
if (args.empty() || args[0].name() != "self") {
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto cls = args[0].type()->castRaw<ClassType>()) {
|
||||
if (C10_UNLIKELY(cls->findMethod(func.name()))) {
|
||||
return;
|
||||
}
|
||||
cls->addMethod(&func);
|
||||
}
|
||||
}
|
||||
|
||||
// The deserializer class which loads the bytecode package from bc files.
|
||||
class BytecodeDeserializer final {
|
||||
public:
|
||||
@ -227,7 +242,8 @@ void BytecodeDeserializer::parseFunctionSchema(
|
||||
mobile::Function* function) {
|
||||
// function schema
|
||||
if (schemaTable) { // (schema is optional for back compat)
|
||||
auto parseArgList = [this](c10::ivalue::TupleElements&& argTables) {
|
||||
auto parseArgList = [this,
|
||||
function](c10::ivalue::TupleElements&& argTables) {
|
||||
std::vector<c10::Argument> args;
|
||||
for (auto&& argTable : std::move(argTables)) {
|
||||
auto argTableElements =
|
||||
@ -249,6 +265,7 @@ void BytecodeDeserializer::parseFunctionSchema(
|
||||
c10::nullopt /*N*/,
|
||||
std::move(default_value));
|
||||
}
|
||||
tryRegisterMethod(args, *function);
|
||||
return args;
|
||||
};
|
||||
auto schemaTableElements =
|
||||
|
@ -36,7 +36,7 @@ Method Module::get_method(const std::string& name) const {
|
||||
}
|
||||
|
||||
c10::optional<Method> Module::find_method(const std::string& basename) const {
|
||||
for (auto& fn : cu_->methods()) {
|
||||
for (const auto& fn : cu_->methods()) {
|
||||
if (fn->name() == basename) {
|
||||
return c10::make_optional<Method>(Method(this, fn.get()));
|
||||
}
|
||||
|
Reference in New Issue
Block a user