mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
introduce module interface declaration (#28408)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28408 This enable interface to defined on a nn.Module, and the InterfaceType now have a field of is_module_ to distinguish if it's a module interface or a normal interface (This is similar to what ClassType distinguish on module and torchscript classes). The module interface can be assigned with any ScriptModule that has the compatible signatures on schemas. A normal object that is not a ScriptModule will not be able to assigned to an module interface and will error out when user explicitly doing so. Assigning a ScriptModule to class interface will make it only available in attribute_list, not module_list. More details on subtyping relationship documented in the jit_type.h If you declare an module interface inside an nn.Module that is being compiled to a ScriptModule, behavior to our internal compilation will be: 1. ConcreteModuleType will record it as an module attribute and add to the attributes_ list. 2. JitType that is created from the ConcreteModuleType will record it as an attribute and pre-genenerate the slot. The slot will be marked as EntityType::MODULE still to make sure JitType record it as a Module slot 3. cpp_module will also register it as a Module as the Slot type is the source of truth Since JitType will record it as attribute as store its type, it will behave normally as the class interface attribute behave now. This means the submodule assigned to this module interface is not getting inlined into the graph as the normal `Module::attr` behave, it will generate interface callMethod and allow us to later swap this with another ScriptModule that implicitly implements this module interface. Test Plan: Imported from OSS Differential Revision: D18284311 fbshipit-source-id: e0b8f6e8c34b2087fab337a969e5ea3fb37ec209
This commit is contained in:
committed by
Facebook Github Bot
parent
1e904049ca
commit
e95dc9814e
81
test/cpp/jit/test_interface.cpp
Normal file
81
test/cpp/jit/test_interface.cpp
Normal file
@ -0,0 +1,81 @@
|
||||
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <torch/csrc/jit/import.h>
|
||||
#include <torch/csrc/jit/import_source.h>
|
||||
#include <torch/csrc/jit/script/resolver.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using namespace torch::jit::script;
|
||||
|
||||
static const std::vector<std::string> subMethodSrcs = {R"JIT(
|
||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
return x + y + 1
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x
|
||||
)JIT"};
|
||||
static const auto parentForward = R"JIT(
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.subMod.forward(x)
|
||||
)JIT";
|
||||
|
||||
static const auto moduleInterfaceSrc = R"JIT(
|
||||
class OneForward(ModuleInterface):
|
||||
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
pass
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
pass
|
||||
)JIT";
|
||||
|
||||
static void import_libs(
|
||||
std::shared_ptr<CompilationUnit> cu,
|
||||
const std::string& class_name,
|
||||
const std::shared_ptr<Source>& src,
|
||||
const std::vector<at::Tensor>& tensor_table) {
|
||||
SourceImporter si(
|
||||
cu,
|
||||
&tensor_table,
|
||||
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
|
||||
/*version=*/2);
|
||||
si.loadNamedType(QualifiedName(class_name));
|
||||
}
|
||||
|
||||
void testModuleInterfaceSerialization() {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
Module parentMod("parentMod", cu);
|
||||
Module subMod("subMod", cu);
|
||||
|
||||
std::vector<at::Tensor> constantTable;
|
||||
import_libs(
|
||||
cu,
|
||||
"__torch__.OneForward",
|
||||
std::make_shared<Source>(moduleInterfaceSrc),
|
||||
constantTable);
|
||||
|
||||
for (const std::string& method : subMethodSrcs) {
|
||||
subMod.define(method, nativeResolver());
|
||||
}
|
||||
parentMod.register_attribute(
|
||||
"subMod",
|
||||
cu->get_interface("__torch__.OneForward"),
|
||||
subMod.module_object(),
|
||||
/*is_parameter=*/false);
|
||||
parentMod.define(parentForward, nativeResolver());
|
||||
ASSERT_TRUE(parentMod.find_module("subMod").has_value());
|
||||
std::stringstream ss;
|
||||
parentMod.save(ss);
|
||||
Module reloaded_mod = jit::load(ss);
|
||||
ASSERT_TRUE(reloaded_mod.find_module("subMod").has_value());
|
||||
InterfaceTypePtr submodType =
|
||||
reloaded_mod.type()->getAttribute("subMod")->cast<InterfaceType>();
|
||||
ASSERT_TRUE(submodType->is_module());
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
Reference in New Issue
Block a user