Revert D14603722: Enforce single parent for script submodules

Differential Revision:
D14603722

Original commit changeset: 63ab5d0cccf7

fbshipit-source-id: 2c4174def102eda4589e08c4dbd67ce8af975199
This commit is contained in:
Zachary DeVito
2019-04-04 10:22:27 -07:00
committed by Facebook Github Bot
parent 52a3a51490
commit f97eb8d9e4
9 changed files with 49 additions and 58 deletions

View File

@ -18,7 +18,7 @@ void check_all_parameters(
AT_ASSERT(predicate(parameter.slot()->toTensor()));
}
for (const auto& child : module.get_modules()) {
check_all_parameters(module, predicate);
check_all_parameters(*child.module, predicate);
}
}
} // namespace helpers

View File

@ -9956,7 +9956,7 @@ a")
def __init__(self):
super(OtherStrong, self).__init__()
self.weak = weak
self.weak2 = Weak()
self.weak2 = weak
@torch.jit.script_method
def forward(self, x):
@ -9973,7 +9973,7 @@ a")
other_strong_mod = OtherStrong()
self.assertIsNot(other_strong_mod.weak, other_strong_mod.weak2)
self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
strong_mod = Strong()

View File

@ -51,8 +51,9 @@ void InputArchive::read(
}
void InputArchive::read(const std::string& key, InputArchive& archive) {
if (auto named_module = module_->find_module(key)) {
archive.module_ = std::move(named_module);
if (auto* named_module = module_->find_module(key)) {
AT_ASSERT(named_module->module != nullptr);
archive.module_ = std::move(named_module->module);
} else {
AT_ERROR("No such serialized submodule: '", key, "'");
}

View File

@ -769,7 +769,7 @@ void ScriptModuleSerializer::convertModule(
for (const auto& elem : module.get_modules()) {
torch::ModuleDef* sub_def = module_def->add_submodules();
convertModule(*elem, module_name.str(), elem->name(), sub_def);
convertModule(*elem.module, module_name.str(), elem.name, sub_def);
}
}

View File

@ -19,8 +19,8 @@ struct ModuleAccessorValue : public SugaredValue {
const SourceRange& loc,
Method& m,
const std::string& field) override {
if (std::shared_ptr<Module> v = module->find_module(field)) {
return std::make_shared<ModuleAccessorValue>(std::move(v));
if (NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleAccessorValue>(v->module);
} else if (NamedIValue* v = module->find_parameter(field)) {
return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
} else if (NamedIValue* v = module->find_buffer(field)) {

View File

@ -139,7 +139,7 @@ void createTensorToParameterNameMap(
}
for (const auto& elem : module.get_modules()) {
createTensorToParameterNameMap(
*elem, QualifiedName::create(prefix, elem->name()), result);
*elem.module, QualifiedName::create(prefix, elem.name), result);
}
}

View File

@ -324,8 +324,8 @@ struct ModuleValue : public SugaredValue {
return std::make_shared<SimpleValue>(the_bool);
}
if (std::shared_ptr<Module> v = module->find_module(field)) {
return std::make_shared<ModuleValue>(v);
if (NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleValue>(v->module);
} else if (Method* v = module->find_method(field)) {
return std::make_shared<MethodValue>(shared_from_this(), *v);
} else if (NamedIValue* v = module->find_parameter(field)) {
@ -614,7 +614,7 @@ static void gatherParametersAndBuffers(
}
}
for (const auto& sub : m.get_modules()) {
gatherParametersAndBuffers(values, *sub);
gatherParametersAndBuffers(values, *sub.module);
}
}
@ -771,7 +771,7 @@ void initJitScriptBindings(PyObject* module) {
py::tuple result(modules.size());
for (size_t i = 0; i < modules.size(); ++i) {
auto& item = modules[i];
result[i] = std::make_pair(item->name(), item);
result[i] = std::make_pair(item.name, item.module);
}
return result;
})

View File

@ -123,7 +123,7 @@ void Module::to_impl(
bool non_blocking) {
// First call `to()` on every child module.
for (auto& child : get_modules()) {
child->to_impl(device, dtype, non_blocking);
child.module->to_impl(device, dtype, non_blocking);
}
// Then convert every of our parameters.
for (auto& parameter : get_parameters()) {

View File

@ -359,6 +359,11 @@ struct Method {
struct Module;
struct NamedModule {
std::string name;
std::shared_ptr<Module> module;
};
struct NamedIValue {
NamedIValue(std::string name, TypePtr type, IValue ivalue)
: name_(name),
@ -383,20 +388,16 @@ struct NamedIValue {
struct Module {
TH_DISALLOW_COPY_AND_ASSIGN(Module);
Module() : name_("__main__"), optimize_(true) {}
const std::string& name() const {
return name_;
}
Module() : optimize(true) {}
// note this doesn't change the flags of existing methods just ones
// added afterward.
void set_optimized(bool o) {
optimize_ = o;
optimize = o;
}
bool is_optimized() const {
return optimize_;
return optimize;
}
IValue forward(std::vector<IValue> inputs) {
@ -446,15 +447,7 @@ struct Module {
void register_module(
const std::string& name,
std::shared_ptr<Module> module) {
if (module->parent_) {
AT_ERROR(
"Attempting to assign submodule '",
name,
"' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'");
}
module->parent_ = this;
module->name_ = name;
insert(name, modules_, EntityType::MODULE, std::move(module));
insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
}
Method& create_method(
@ -465,7 +458,7 @@ struct Module {
std::unique_ptr<Method> method(new Method(
this,
name,
optimize_,
optimize,
std::move(graph),
std::move(member_inputs),
nullptr));
@ -478,7 +471,7 @@ struct Module {
std::unique_ptr<Method> method(new Method(
this,
name,
optimize_,
optimize,
std::make_shared<Graph>(),
{},
std::move(creator)));
@ -512,10 +505,10 @@ struct Module {
}
std::shared_ptr<Module> get_module(const std::string& name) const {
return modules_[get_offset(name, EntityType::MODULE)];
return modules_[get_offset(name, EntityType::MODULE)].module;
}
c10::ArrayRef<std::shared_ptr<Module>> get_modules() const {
c10::ArrayRef<NamedModule> get_modules() const {
return modules_;
}
c10::ArrayRef<NamedIValue> get_parameters() const {
@ -543,9 +536,9 @@ struct Module {
}
return nullptr;
}
std::shared_ptr<Module> find_module(const std::string& name) {
NamedModule* find_module(const std::string& name) {
auto offset = find_offset(name, EntityType::MODULE);
return offset ? modules_[*offset] : nullptr;
return offset ? &modules_[*offset] : nullptr;
}
Method* find_method(const std::string& name) {
auto offset = find_offset(name, EntityType::METHOD);
@ -553,14 +546,14 @@ struct Module {
}
void apply(std::function<void(Module&)> fn) {
for (auto& submod : get_modules()) {
submod->apply(fn);
submod.module->apply(fn);
}
fn(*this);
}
/// Enables "training" mode.
void train(bool on = true) {
for (auto& submod : get_modules()) {
submod->train(on);
submod.module->train(on);
}
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
}
@ -653,10 +646,10 @@ struct Module {
parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
}
for (auto& mod : get_modules()) {
names.push_back(mod->name());
names.push_back(mod.name);
// Submodules must be translated first, otherwise parameter_remap entries
// will not be filled in for methods of this module.
mod->copy_into(module_lookup, parameter_remap, names);
mod.module->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
for (auto& method : get_methods()) {
@ -684,6 +677,20 @@ struct Module {
const c10::optional<at::ScalarType>& dtype,
bool non_blocking);
// modules have a single namespace, but spread over 4 different concepts:
// parameters, attributes, methods, and sub-modules
// we store individual lists of each concept, and a single map to
// unify the namespace and ensure fast lookup
// invariant: to ensure initial_ivalues of Methods stay valid,
// it is only legal to _add_ new modules and parameters.
// removing them will allow initial_ivalues to point to invalid parameters
// no such restriction exists for methods
std::vector<NamedModule> modules_;
std::vector<NamedIValue> parameters_;
std::vector<NamedIValue> attributes_;
std::vector<std::unique_ptr<Method>> methods_;
static const char* toString(EntityType t) {
switch (t) {
case EntityType::MODULE:
@ -755,26 +762,9 @@ struct Module {
return list.back();
}
// modules have a single namespace, but spread over 4 different concepts:
// parameters, attributes, methods, and sub-modules
// we store individual lists of each concept, and a single map to
// unify the namespace and ensure fast lookup
// invariant: to ensure initial_ivalues of Methods stay valid,
// it is only legal to _add_ new modules and parameters.
// removing them will allow initial_ivalues to point to invalid parameters
// no such restriction exists for methods
std::vector<std::shared_ptr<Module>> modules_;
std::vector<NamedIValue> parameters_;
std::vector<NamedIValue> attributes_;
std::vector<std::unique_ptr<Method>> methods_;
std::unordered_map<std::string, Entry> dict_;
std::string name_;
// back reference to parent of this Module if present
Module* parent_ = nullptr;
bool optimize_;
bool optimize;
};
// returns nullptr and fills in failure_messages if the callee does not