Enforce single parent for script submodules (#18379)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18379
ghimport-source-id: 9895ecc1ff7897e98853dc00675341f36726e7c7

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18379 Enforce single parent for script submodules**
* #18378 Unify namespace of script::Module
* #18314 Add ability to specialize class types to ArgumentSpec
* #18226 Add Slot type to abstract the raw pointers being used for slots.

The assumption that a ScriptModule has a single parent is present in
our serialization format, and likely a few other places. It is not
enforced on creation of script module hierarchies though, meaning that
problems associated with (e.g. replicating a module twice in the output
format) will not be caught until much later in the development cycle.

This patch enforces the property when a submodule is registered.
It also removes NamedModule since it is no longer necessary in this regime.
This will also allow the easy discover of a modules fully-qualified name
without needing to traverse the Module hierarchy.

Differential Revision: D14603722

fbshipit-source-id: 63ab5d0cccf7d66c7833e0adf9023024ca9607cb
This commit is contained in:
Zachary DeVito
2019-04-03 20:21:27 -07:00
committed by Facebook Github Bot
parent b80a4fa201
commit 7e59c60454
9 changed files with 60 additions and 50 deletions

View File

@ -18,7 +18,7 @@ void check_all_parameters(
AT_ASSERT(predicate(parameter.slot()->toTensor()));
}
for (const auto& child : module.get_modules()) {
check_all_parameters(*child.module, predicate);
check_all_parameters(module, predicate);
}
}
} // namespace helpers

View File

@ -340,6 +340,7 @@ class JitTestCase(TestCase):
self.assertMultiLineEqual(main_module_code, main_module_2_code)
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
buffer = io.BytesIO()
torch.jit.save(m, buffer)
@ -9929,7 +9930,7 @@ a")
def __init__(self):
super(OtherStrong, self).__init__()
self.weak = weak
self.weak2 = weak
self.weak2 = Weak()
@torch.jit.script_method
def forward(self, x):
@ -9946,7 +9947,7 @@ a")
other_strong_mod = OtherStrong()
self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
self.assertIsNot(other_strong_mod.weak, other_strong_mod.weak2)
with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
strong_mod = Strong()

View File

@ -51,9 +51,8 @@ void InputArchive::read(
}
void InputArchive::read(const std::string& key, InputArchive& archive) {
if (auto* named_module = module_->find_module(key)) {
AT_ASSERT(named_module->module != nullptr);
archive.module_ = std::move(named_module->module);
if (auto named_module = module_->find_module(key)) {
archive.module_ = std::move(named_module);
} else {
AT_ERROR("No such serialized submodule: '", key, "'");
}

View File

@ -769,7 +769,7 @@ void ScriptModuleSerializer::convertModule(
for (const auto& elem : module.get_modules()) {
torch::ModuleDef* sub_def = module_def->add_submodules();
convertModule(*elem.module, module_name.str(), elem.name, sub_def);
convertModule(*elem, module_name.str(), elem->name(), sub_def);
}
}

View File

@ -19,8 +19,8 @@ struct ModuleAccessorValue : public SugaredValue {
const SourceRange& loc,
Method& m,
const std::string& field) override {
if (NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleAccessorValue>(v->module);
if (std::shared_ptr<Module> v = module->find_module(field)) {
return std::make_shared<ModuleAccessorValue>(std::move(v));
} else if (NamedIValue* v = module->find_parameter(field)) {
return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
} else if (NamedIValue* v = module->find_buffer(field)) {

View File

@ -139,7 +139,7 @@ void createTensorToParameterNameMap(
}
for (const auto& elem : module.get_modules()) {
createTensorToParameterNameMap(
*elem.module, QualifiedName::create(prefix, elem.name), result);
*elem, QualifiedName::create(prefix, elem->name()), result);
}
}

View File

@ -324,8 +324,8 @@ struct ModuleValue : public SugaredValue {
return std::make_shared<SimpleValue>(the_bool);
}
if (NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleValue>(v->module);
if (std::shared_ptr<Module> v = module->find_module(field)) {
return std::make_shared<ModuleValue>(v);
} else if (Method* v = module->find_method(field)) {
return std::make_shared<MethodValue>(shared_from_this(), *v);
} else if (NamedIValue* v = module->find_parameter(field)) {
@ -614,7 +614,7 @@ static void gatherParametersAndBuffers(
}
}
for (const auto& sub : m.get_modules()) {
gatherParametersAndBuffers(values, *sub.module);
gatherParametersAndBuffers(values, *sub);
}
}
@ -771,7 +771,7 @@ void initJitScriptBindings(PyObject* module) {
py::tuple result(modules.size());
for (size_t i = 0; i < modules.size(); ++i) {
auto& item = modules[i];
result[i] = std::make_pair(item.name, item.module);
result[i] = std::make_pair(item->name(), item);
}
return result;
})

View File

@ -123,7 +123,7 @@ void Module::to_impl(
bool non_blocking) {
// First call `to()` on every child module.
for (auto& child : get_modules()) {
child.module->to_impl(device, dtype, non_blocking);
child->to_impl(device, dtype, non_blocking);
}
// Then convert every of our parameters.
for (auto& parameter : get_parameters()) {

View File

@ -359,11 +359,6 @@ struct Method {
struct Module;
struct NamedModule {
std::string name;
std::shared_ptr<Module> module;
};
struct NamedIValue {
NamedIValue(std::string name, TypePtr type, IValue ivalue)
: name_(name),
@ -388,16 +383,20 @@ struct NamedIValue {
struct Module {
TH_DISALLOW_COPY_AND_ASSIGN(Module);
Module() : optimize(true) {}
Module() : name_("__main__"), optimize_(true) {}
const std::string& name() const {
return name_;
}
// note this doesn't change the flags of existing methods just ones
// added afterward.
void set_optimized(bool o) {
optimize = o;
optimize_ = o;
}
bool is_optimized() const {
return optimize;
return optimize_;
}
IValue forward(std::vector<IValue> inputs) {
@ -447,7 +446,15 @@ struct Module {
void register_module(
const std::string& name,
std::shared_ptr<Module> module) {
insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
if (module->parent_) {
AT_ERROR(
"Attempting to assign submodule '",
name,
"' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'");
}
module->parent_ = this;
module->name_ = name;
insert(name, modules_, EntityType::MODULE, std::move(module));
}
Method& create_method(
@ -458,7 +465,7 @@ struct Module {
std::unique_ptr<Method> method(new Method(
this,
name,
optimize,
optimize_,
std::move(graph),
std::move(member_inputs),
nullptr));
@ -471,7 +478,7 @@ struct Module {
std::unique_ptr<Method> method(new Method(
this,
name,
optimize,
optimize_,
std::make_shared<Graph>(),
{},
std::move(creator)));
@ -505,10 +512,10 @@ struct Module {
}
std::shared_ptr<Module> get_module(const std::string& name) const {
return modules_[get_offset(name, EntityType::MODULE)].module;
return modules_[get_offset(name, EntityType::MODULE)];
}
c10::ArrayRef<NamedModule> get_modules() const {
c10::ArrayRef<std::shared_ptr<Module>> get_modules() const {
return modules_;
}
c10::ArrayRef<NamedIValue> get_parameters() const {
@ -536,9 +543,9 @@ struct Module {
}
return nullptr;
}
NamedModule* find_module(const std::string& name) {
std::shared_ptr<Module> find_module(const std::string& name) {
auto offset = find_offset(name, EntityType::MODULE);
return offset ? &modules_[*offset] : nullptr;
return offset ? modules_[*offset] : nullptr;
}
Method* find_method(const std::string& name) {
auto offset = find_offset(name, EntityType::METHOD);
@ -546,14 +553,14 @@ struct Module {
}
void apply(std::function<void(Module&)> fn) {
for (auto& submod : get_modules()) {
submod.module->apply(fn);
submod->apply(fn);
}
fn(*this);
}
/// Enables "training" mode.
void train(bool on = true) {
for (auto& submod : get_modules()) {
submod.module->train(on);
submod->train(on);
}
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
}
@ -646,10 +653,10 @@ struct Module {
parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
}
for (auto& mod : get_modules()) {
names.push_back(mod.name);
names.push_back(mod->name());
// Submodules must be translated first, otherwise parameter_remap entries
// will not be filled in for methods of this module.
mod.module->copy_into(module_lookup, parameter_remap, names);
mod->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
for (auto& method : get_methods()) {
@ -677,20 +684,6 @@ struct Module {
const c10::optional<at::ScalarType>& dtype,
bool non_blocking);
// modules have a single namespace, but spread over 4 different concepts:
// parameters, attributes, methods, and sub-modules
// we store individual lists of each concept, and a single map to
// unify the namespace and ensure fast lookup
// invariant: to ensure initial_ivalues of Methods stay valid,
// it is only legal to _add_ new modules and parameters.
// removing them will allow initial_ivalues to point to invalid parameters
// no such restriction exists for methods
std::vector<NamedModule> modules_;
std::vector<NamedIValue> parameters_;
std::vector<NamedIValue> attributes_;
std::vector<std::unique_ptr<Method>> methods_;
static const char* toString(EntityType t) {
switch (t) {
case EntityType::MODULE:
@ -762,9 +755,26 @@ struct Module {
return list.back();
}
std::unordered_map<std::string, Entry> dict_;
// modules have a single namespace, but spread over 4 different concepts:
// parameters, attributes, methods, and sub-modules
// we store individual lists of each concept, and a single map to
// unify the namespace and ensure fast lookup
bool optimize;
// invariant: to ensure initial_ivalues of Methods stay valid,
// it is only legal to _add_ new modules and parameters.
// removing them will allow initial_ivalues to point to invalid parameters
// no such restriction exists for methods
std::vector<std::shared_ptr<Module>> modules_;
std::vector<NamedIValue> parameters_;
std::vector<NamedIValue> attributes_;
std::vector<std::unique_ptr<Method>> methods_;
std::unordered_map<std::string, Entry> dict_;
std::string name_;
// back reference to parent of this Module if present
Module* parent_ = nullptr;
bool optimize_;
};
// returns nullptr and fills in failure_messages if the callee does not