mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Implement more of of the nn.Module API (#28828)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28828 This updates torch::script::Module to more closely match the behavior of nn.Module. In particular, it implements the (optionally recurisive) iterators that retrieve submodules, parameters, and buffers and makes their names match the python versions. This also removes the individual accessors for Parameter, Module, Buffer, etc. and replaces them with a single `attr` function which is equivalent to writing `a.foo` in Python (`setattr` emulates `a.foo = v`). As we build out the user-facing API for TorchScript values this will end up matching how an attribute is accessed on general objects. This PR preservers the python bindings for script::Module by emulating the old API at the binding level. A followup will clean up the usage to more directly match the C++ API. Test Plan: Imported from OSS Differential Revision: D18197611 Pulled By: zdevito fbshipit-source-id: 7ee4dcbb258605d1c988314b05d938423f1ccee5
This commit is contained in:
committed by
Facebook Github Bot
parent
509d9630ca
commit
796363147f
@ -67,11 +67,11 @@ void testModuleInterfaceSerialization() {
|
||||
subMod.module_object(),
|
||||
/*is_parameter=*/false);
|
||||
parentMod.define(parentForward, nativeResolver());
|
||||
ASSERT_TRUE(parentMod.find_module("subMod").has_value());
|
||||
ASSERT_TRUE(parentMod.hasattr("subMod"));
|
||||
std::stringstream ss;
|
||||
parentMod.save(ss);
|
||||
Module reloaded_mod = jit::load(ss);
|
||||
ASSERT_TRUE(reloaded_mod.find_module("subMod").has_value());
|
||||
ASSERT_TRUE(reloaded_mod.hasattr("subMod"));
|
||||
InterfaceTypePtr submodType =
|
||||
reloaded_mod.type()->getAttribute("subMod")->cast<InterfaceType>();
|
||||
ASSERT_TRUE(submodType->is_module());
|
||||
|
Reference in New Issue
Block a user