mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
Add clone_instance for Module (#30168)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30168 Previous implementation of `clone` in `script::Module` copies both the module instance and the class type, after we enabled type sharing https://github.com/pytorch/pytorch/pull/26666 we also need to have a function to clone instance only and share the underlying class type. Test Plan: tbd Imported from OSS Differential Revision: D18631324 fbshipit-source-id: dbadcf19695faee0f755f45093b24618c047b9d1
This commit is contained in:
committed by
Facebook Github Bot
parent
2c1c6de122
commit
1bba0eb35b
45
test/cpp/jit/test_module_api.cpp
Normal file
45
test/cpp/jit/test_module_api.cpp
Normal file
@ -0,0 +1,45 @@
|
||||
#include <test/cpp/jit/test_base.h>
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
using namespace torch::jit::script;
|
||||
|
||||
void testModuleCloneInstance() {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create("foo.bar", cu, true);
|
||||
auto attr_name = "attr";
|
||||
cls->addAttribute(attr_name, IntType::get());
|
||||
Module m(cu, cls);
|
||||
auto v = IValue(2);
|
||||
m.register_attribute(attr_name,
|
||||
IntType::get(),
|
||||
v,
|
||||
false);
|
||||
|
||||
Module m2 = m.clone();
|
||||
Module m3 = m.clone_instance();
|
||||
// Make sure copy works
|
||||
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
||||
ASSERT_EQ(m3.attr(attr_name).toInt(), 2);
|
||||
|
||||
// clone will copy both type and data, therefore we'll have a
|
||||
// different type
|
||||
ASSERT_NE(m.type(), m2.type());
|
||||
// clone_instance only copies data, type is shared
|
||||
ASSERT_EQ(m.type(), m3.type());
|
||||
|
||||
// change value of copied instance
|
||||
m3.register_attribute(attr_name,
|
||||
IntType::get(),
|
||||
IValue(3),
|
||||
false);
|
||||
// Verify value of original instance doesn't change
|
||||
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
||||
ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
Reference in New Issue
Block a user