mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 14:59:34 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42121 This PR changes the Module API to allow registering a module with a module interface type, and therefore allows Module::clone to work in the case where a module interface type is shared by two submodules. The interface type will be shared by the new cloned instance in the same compilation unit, because it only contains a list of FunctionSchema and, unlike ClassType, does not involve any attributes. Fixes https://github.com/pytorch/pytorch/issues/41882 Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D22781205 Pulled By: wanchaol fbshipit-source-id: f97f4b75970f0b434e38b5a1f778eda2c4e5109b
296 lines
9.3 KiB
C++
296 lines
9.3 KiB
C++
#include <test/cpp/jit/test_base.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
|
|
#include <ATen/core/qualified_name.h>
|
|
#include <torch/csrc/jit/frontend/resolver.h>
|
|
#include <torch/csrc/jit/serialization/import.h>
|
|
#include <torch/csrc/jit/serialization/import_source.h>
|
|
#include <torch/torch.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
static const auto moduleInterfaceSrc = R"JIT(
|
|
class OneInterface(ModuleInterface):
|
|
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
|
pass
|
|
)JIT";
|
|
|
|
static const std::vector<std::string> subModuleMethodsSrc = {R"JIT(
|
|
def one(self, x: Tensor, y: Tensor) -> Tensor:
|
|
return self.attr * x + y + 1
|
|
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
return self.attr + x
|
|
)JIT"};
|
|
|
|
static const auto parentForward = R"JIT(
|
|
def forward(self, x: Tensor) -> Tensor:
|
|
return self.subMod1.one(x, x) + self.subMod2.one(x, x)
|
|
)JIT";
|
|
|
|
static void import_libs(
|
|
std::shared_ptr<CompilationUnit> cu,
|
|
const std::string& class_name,
|
|
const std::shared_ptr<Source>& src,
|
|
const std::vector<at::IValue>& tensor_table) {
|
|
SourceImporter si(
|
|
cu,
|
|
&tensor_table,
|
|
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
|
|
/*version=*/2);
|
|
si.loadType(QualifiedName(class_name));
|
|
}
|
|
|
|
void testModuleClone() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
// creating child module
|
|
auto child = ClassType::create("child", cu, true);
|
|
auto attr_name = "attr";
|
|
child->addAttribute(attr_name, IntType::get());
|
|
Module c1(cu, child);
|
|
auto v1 = IValue(2);
|
|
c1.register_attribute(attr_name, IntType::get(), v1, false);
|
|
Module c2(cu, child);
|
|
auto v2 = IValue(3);
|
|
c2.register_attribute(attr_name, IntType::get(), v2, false);
|
|
|
|
// attach two child module instance to parent that shares
|
|
// ClassType
|
|
auto parent = ClassType::create("parent", cu, true);
|
|
Module p(cu, parent);
|
|
p.register_attribute("c1", c1.type(), c1._ivalue(), false);
|
|
p.register_attribute("c2", c2.type(), c2._ivalue(), false);
|
|
|
|
// clone parent
|
|
Module p2 = p.clone();
|
|
// check the two child module has the same ClassType
|
|
ASSERT_EQ(p2.attr("c1").type(), p2.attr("c2").type());
|
|
// but different instances
|
|
ASSERT_EQ(Module(p2.attr("c1").toObject()).attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3);
|
|
}
|
|
|
|
void testModuleCloneWithModuleInterface() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
|
|
// define a initial module with two submods share same interface
|
|
Module parentMod("parentMod", cu);
|
|
Module subMod1("subMod1", cu);
|
|
Module subMod2("subMod2", cu);
|
|
|
|
std::vector<at::IValue> constantTable;
|
|
import_libs(
|
|
cu,
|
|
"__torch__.OneInterface",
|
|
std::make_shared<Source>(moduleInterfaceSrc),
|
|
constantTable);
|
|
|
|
auto v1 = IValue(2);
|
|
subMod1.register_attribute("attr", IntType::get(), v1, false);
|
|
|
|
auto v2 = IValue(4);
|
|
subMod2.register_attribute("attr", IntType::get(), v2, false);
|
|
|
|
for (const std::string& method : subModuleMethodsSrc) {
|
|
subMod1.define(method, nativeResolver());
|
|
subMod2.define(method, nativeResolver());
|
|
}
|
|
|
|
parentMod.register_attribute(
|
|
"subMod1",
|
|
cu->get_interface("__torch__.OneInterface"),
|
|
subMod1._ivalue());
|
|
parentMod.register_attribute(
|
|
"subMod2",
|
|
cu->get_interface("__torch__.OneInterface"),
|
|
subMod2._ivalue());
|
|
|
|
parentMod.define(parentForward, nativeResolver());
|
|
|
|
Module clonedMod = parentMod.clone();
|
|
|
|
// clone will copy both type and data, therefore we'll have a
|
|
// different type
|
|
ASSERT_NE(clonedMod.type(), parentMod.type());
|
|
}
|
|
|
|
void testModuleCopy() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
auto attr_name = "attr";
|
|
cls->addAttribute(attr_name, IntType::get());
|
|
Module m(cu, cls);
|
|
auto v = IValue(2);
|
|
m.register_attribute(attr_name, IntType::get(), v, false);
|
|
|
|
Module m2 = m.clone();
|
|
Module m3 = m.copy();
|
|
|
|
// Make sure copy works
|
|
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(m3.attr(attr_name).toInt(), 2);
|
|
|
|
// clone will copy both type and data, therefore we'll have a
|
|
// different type
|
|
ASSERT_NE(m.type(), m2.type());
|
|
// copy only copies data, type is shared
|
|
ASSERT_EQ(m.type(), m3.type());
|
|
|
|
// change value of copied instance
|
|
m3.register_attribute(attr_name, IntType::get(), IValue(3), false);
|
|
// Verify value of original instance doesn't change
|
|
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
|
|
}
|
|
|
|
void testModuleDeepcopy() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
auto str_attr = "str_attr";
|
|
auto int_attr = "int_attr";
|
|
auto tensor_attr = "tensor_attr";
|
|
auto tensor_list_attr = "tensor_list_attr";
|
|
cls->addAttribute(int_attr, IntType::get());
|
|
cls->addAttribute(str_attr, StringType::get());
|
|
cls->addAttribute(tensor_attr, TensorType::get());
|
|
cls->addAttribute(tensor_list_attr, ListType::ofTensors());
|
|
Module m(cu, cls);
|
|
c10::List<at::Tensor> list({at::rand(5), at::rand(5)});
|
|
m.setattr(int_attr, IValue(2));
|
|
m.setattr(str_attr, IValue("str"));
|
|
m.setattr(tensor_attr, at::randn(5));
|
|
m.setattr(tensor_list_attr, list);
|
|
|
|
Module m2 = m.deepcopy();
|
|
Module m3 = m.copy();
|
|
// Make sure copy works
|
|
ASSERT_EQ(m2.attr(int_attr).toInt(), 2);
|
|
ASSERT_EQ(m3.attr(int_attr).toInt(), 2);
|
|
|
|
// Test overlaps
|
|
ASSERT_TRUE(!IValue(m2._ivalue()).overlaps(IValue(m._ivalue())));
|
|
ASSERT_TRUE(IValue(m3._ivalue()).overlaps(IValue(m._ivalue())));
|
|
|
|
// Both deepcopy and copy will preserve the type
|
|
ASSERT_EQ(m.type(), m2.type());
|
|
ASSERT_EQ(m.type(), m3.type());
|
|
|
|
// change int value of copied instances
|
|
m2.setattr(int_attr, IValue(3));
|
|
m3.setattr(int_attr, IValue(4));
|
|
|
|
// Verify value of original instance doesn't change
|
|
ASSERT_EQ(m.attr(int_attr).toInt(), 2);
|
|
ASSERT_EQ(m2.attr(int_attr).toInt(), 3);
|
|
ASSERT_EQ(m3.attr(int_attr).toInt(), 4);
|
|
|
|
// change Tensor value of copied instances
|
|
at::Tensor t1 = m.attr(tensor_attr).toTensor();
|
|
at::Tensor t2 =
|
|
m2.attr(tensor_attr).toTensor(); // deepcopy will copy the Tensor
|
|
at::Tensor t3 =
|
|
m3.attr(tensor_attr).toTensor(); // copy will not copy the Tensor
|
|
// check copy works
|
|
ASSERT_TRUE(t1.equal(t2));
|
|
ASSERT_TRUE(t1.equal(t3));
|
|
|
|
// zero out t1
|
|
t1.zero_();
|
|
// check that t2 is not affected because it is a deep copy
|
|
ASSERT_TRUE(!t1.equal(t2));
|
|
// check that t3 is the same as t1 since it is a shallow copy
|
|
ASSERT_TRUE(t1.equal(t3));
|
|
}
|
|
|
|
void testModuleDeepcopyString() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
auto attr1 = "attr1";
|
|
cls->addAttribute(attr1, StringType::get());
|
|
std::string str = "str";
|
|
Module m(cu, cls);
|
|
m.setattr(attr1, str);
|
|
auto copied = m.deepcopy();
|
|
auto original_str = str;
|
|
ASSERT_EQ(copied.attr(attr1).toString()->string(), original_str);
|
|
// check string mutation is not reflected in the copied module
|
|
str += "str";
|
|
ASSERT_EQ(copied.attr(attr1).toString()->string(), original_str);
|
|
}
|
|
|
|
void testModuleDeepcopyAliasing() {
|
|
// check deepcopy preserves aliasing
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
auto attr1 = "attr1";
|
|
auto attr2 = "attr2";
|
|
auto attr3 = "attr3";
|
|
auto attr4 = "attr4";
|
|
cls->addAttribute(attr1, ListType::ofTensors());
|
|
cls->addAttribute(attr2, ListType::ofTensors());
|
|
cls->addAttribute(attr3, TensorType::get());
|
|
cls->addAttribute(attr4, TensorType::get());
|
|
Module m(cu, cls);
|
|
auto t1 = at::rand(5);
|
|
auto t2 = at::rand(5);
|
|
auto t3 = at::rand(5);
|
|
auto t4 = at::rand({5, 2});
|
|
c10::List<at::Tensor> list1({t1, t2});
|
|
c10::List<at::Tensor> list2({t1, t3});
|
|
// first element of attr1 and attr2 are aliased
|
|
m.setattr(attr1, list1);
|
|
m.setattr(attr2, list2);
|
|
m.setattr(attr3, t4);
|
|
m.setattr(attr4, t4.view(-1));
|
|
|
|
auto copied = m.deepcopy();
|
|
// test tensor aliasing
|
|
auto copied_attr1_t1 = copied.attr(attr1).toList().get(0);
|
|
auto copied_attr2_t1 = copied.attr(attr2).toList().get(0);
|
|
ASSERT_TRUE(copied_attr1_t1.isAliasOf(copied_attr2_t1));
|
|
|
|
// test aliasing from view
|
|
auto copied_attr3 = copied.attr(attr3);
|
|
auto copied_attr4 = copied.attr(attr3);
|
|
ASSERT_TRUE(copied_attr3.isAliasOf(copied_attr4));
|
|
}
|
|
|
|
void testModuleConstant() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
auto attr_name = "attr";
|
|
auto const_name = "const";
|
|
cls->addAttribute(attr_name, IntType::get());
|
|
cls->addConstant(const_name, IValue(3));
|
|
Module m(cu, cls);
|
|
auto v = IValue(2);
|
|
m.register_attribute(attr_name, IntType::get(), v, false);
|
|
ASSERT_TRUE(m.hasattr(attr_name));
|
|
ASSERT_TRUE(m.hasattr(const_name));
|
|
ASSERT_EQ(m.attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(m.attr(const_name).toInt(), 3);
|
|
}
|
|
|
|
void testModuleParameter() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
Module m(cu, cls);
|
|
// Tensor parameter
|
|
m.register_parameter(
|
|
"tensor_param", at::empty({3}, at::kFloat), /* is_buffer */ false);
|
|
// None parameter
|
|
m.register_attribute(
|
|
"none_param", NoneType::get(), IValue(), /* is_param */ true);
|
|
m.register_attribute(
|
|
"none_param2", NoneType::get(), IValue(), /* is_param */ true);
|
|
auto param_list = m.parameters();
|
|
ASSERT_EQ(param_list.size(), 1);
|
|
ASSERT_TRUE(m.hasattr("tensor_param"));
|
|
ASSERT_TRUE(m.hasattr("none_param"));
|
|
ASSERT_TRUE(m.hasattr("none_param2"));
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|