mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 23:54:56 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29219 We added class constant in previous PRs, this PR allows access to class constant in the object API Test Plan: build/bin/test_jit python test/test_jit.py Imported from OSS Differential Revision: D18846851 fbshipit-source-id: 888a6517d5f747d1f8ced283c0c2c30b2f6c72c6
65 lines
1.9 KiB
C++
65 lines
1.9 KiB
C++
#include <test/cpp/jit/test_base.h>
|
|
#include <test/cpp/jit/test_utils.h>
|
|
#include <torch/torch.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
using namespace torch::jit::script;
|
|
|
|
void testModuleCloneInstance() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
auto attr_name = "attr";
|
|
cls->addAttribute(attr_name, IntType::get());
|
|
Module m(cu, cls);
|
|
auto v = IValue(2);
|
|
m.register_attribute(attr_name,
|
|
IntType::get(),
|
|
v,
|
|
false);
|
|
|
|
Module m2 = m.clone();
|
|
Module m3 = m.clone_instance();
|
|
// Make sure copy works
|
|
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(m3.attr(attr_name).toInt(), 2);
|
|
|
|
// clone will copy both type and data, therefore we'll have a
|
|
// different type
|
|
ASSERT_NE(m.type(), m2.type());
|
|
// clone_instance only copies data, type is shared
|
|
ASSERT_EQ(m.type(), m3.type());
|
|
|
|
// change value of copied instance
|
|
m3.register_attribute(attr_name,
|
|
IntType::get(),
|
|
IValue(3),
|
|
false);
|
|
// Verify value of original instance doesn't change
|
|
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
|
|
}
|
|
|
|
void testModuleConstant() {
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto cls = ClassType::create("foo.bar", cu, true);
|
|
auto attr_name = "attr";
|
|
auto const_name = "const";
|
|
cls->addAttribute(attr_name, IntType::get());
|
|
cls->addConstant(const_name, IValue(3));
|
|
Module m(cu, cls);
|
|
auto v = IValue(2);
|
|
m.register_attribute(attr_name,
|
|
IntType::get(),
|
|
v,
|
|
false);
|
|
ASSERT_TRUE(m.hasattr(attr_name));
|
|
ASSERT_TRUE(m.hasattr(const_name));
|
|
ASSERT_EQ(m.attr(attr_name).toInt(), 2);
|
|
ASSERT_EQ(m.attr(const_name).toInt(), 3);
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|