mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851 Rationale and context described in #33828. Script to reproduce the move: https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9 ghstack-source-id: 99079645 Test Plan: Make sure CI passes Reviewed By: jamesr66a Differential Revision: D20133869 fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
82 lines
2.2 KiB
C++
82 lines
2.2 KiB
C++
|
|
#include <test/cpp/jit/test_base.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
|
|
#include <ATen/core/qualified_name.h>
|
|
#include <torch/csrc/jit/serialization/import.h>
|
|
#include <torch/csrc/jit/serialization/import_source.h>
|
|
#include <torch/csrc/jit/frontend/resolver.h>
|
|
#include <torch/torch.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using namespace torch::jit::script;
|
|
|
|
static const std::vector<std::string> subMethodSrcs = {R"JIT(
|
|
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
|
return x + y + 1
|
|
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
return x
|
|
)JIT"};
|
|
static const auto parentForward = R"JIT(
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
return self.subMod.forward(x)
|
|
)JIT";
|
|
|
|
static const auto moduleInterfaceSrc = R"JIT(
|
|
class OneForward(ModuleInterface):
|
|
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
|
pass
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
pass
|
|
)JIT";
|
|
|
|
static void import_libs(
|
|
std::shared_ptr<CompilationUnit> cu,
|
|
const std::string& class_name,
|
|
const std::shared_ptr<Source>& src,
|
|
const std::vector<at::Tensor>& tensor_table) {
|
|
SourceImporter si(
|
|
cu,
|
|
&tensor_table,
|
|
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
|
|
/*version=*/2);
|
|
si.loadNamedType(QualifiedName(class_name));
|
|
}
|
|
|
|
void testModuleInterfaceSerialization() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
Module parentMod("parentMod", cu);
|
|
Module subMod("subMod", cu);
|
|
|
|
std::vector<at::Tensor> constantTable;
|
|
import_libs(
|
|
cu,
|
|
"__torch__.OneForward",
|
|
std::make_shared<Source>(moduleInterfaceSrc),
|
|
constantTable);
|
|
|
|
for (const std::string& method : subMethodSrcs) {
|
|
subMod.define(method, nativeResolver());
|
|
}
|
|
parentMod.register_attribute(
|
|
"subMod",
|
|
cu->get_interface("__torch__.OneForward"),
|
|
subMod._ivalue(),
|
|
/*is_parameter=*/false);
|
|
parentMod.define(parentForward, nativeResolver());
|
|
ASSERT_TRUE(parentMod.hasattr("subMod"));
|
|
std::stringstream ss;
|
|
parentMod.save(ss);
|
|
Module reloaded_mod = jit::load(ss);
|
|
ASSERT_TRUE(reloaded_mod.hasattr("subMod"));
|
|
InterfaceTypePtr submodType =
|
|
reloaded_mod.type()->getAttribute("subMod")->cast<InterfaceType>();
|
|
ASSERT_TRUE(submodType->is_module());
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|