mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enforce single parent for script submodules (#18379)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18379 ghimport-source-id: 9895ecc1ff7897e98853dc00675341f36726e7c7 Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18379 Enforce single parent for script submodules** * #18378 Unify namespace of script::Module * #18314 Add ability to specialize class types to ArgumentSpec * #18226 Add Slot type to abstract the raw pointers being used for slots. The assumption that a ScriptModule has a single parent is present in our serialization format, and likely in a few other places. It is not enforced on creation of script module hierarchies though, meaning that problems caused by violating it (e.g. replicating a module twice in the output format) will not be caught until much later in the development cycle. This patch enforces the property when a submodule is registered. It also removes NamedModule since it is no longer necessary in this regime. This will also allow the easy discovery of a module's fully-qualified name without needing to traverse the Module hierarchy. Differential Revision: D14603722 fbshipit-source-id: 63ab5d0cccf7d66c7833e0adf9023024ca9607cb
This commit is contained in:
committed by
Facebook Github Bot
parent
b80a4fa201
commit
7e59c60454
@ -18,7 +18,7 @@ void check_all_parameters(
|
||||
AT_ASSERT(predicate(parameter.slot()->toTensor()));
|
||||
}
|
||||
for (const auto& child : module.get_modules()) {
|
||||
check_all_parameters(*child.module, predicate);
|
||||
check_all_parameters(module, predicate);
|
||||
}
|
||||
}
|
||||
} // namespace helpers
|
||||
|
@ -340,6 +340,7 @@ class JitTestCase(TestCase):
|
||||
|
||||
self.assertMultiLineEqual(main_module_code, main_module_2_code)
|
||||
|
||||
|
||||
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(m, buffer)
|
||||
@ -9929,7 +9930,7 @@ a")
|
||||
def __init__(self):
|
||||
super(OtherStrong, self).__init__()
|
||||
self.weak = weak
|
||||
self.weak2 = weak
|
||||
self.weak2 = Weak()
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
@ -9946,7 +9947,7 @@ a")
|
||||
|
||||
other_strong_mod = OtherStrong()
|
||||
|
||||
self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
|
||||
self.assertIsNot(other_strong_mod.weak, other_strong_mod.weak2)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
|
||||
strong_mod = Strong()
|
||||
|
@ -51,9 +51,8 @@ void InputArchive::read(
|
||||
}
|
||||
|
||||
void InputArchive::read(const std::string& key, InputArchive& archive) {
|
||||
if (auto* named_module = module_->find_module(key)) {
|
||||
AT_ASSERT(named_module->module != nullptr);
|
||||
archive.module_ = std::move(named_module->module);
|
||||
if (auto named_module = module_->find_module(key)) {
|
||||
archive.module_ = std::move(named_module);
|
||||
} else {
|
||||
AT_ERROR("No such serialized submodule: '", key, "'");
|
||||
}
|
||||
|
@ -769,7 +769,7 @@ void ScriptModuleSerializer::convertModule(
|
||||
|
||||
for (const auto& elem : module.get_modules()) {
|
||||
torch::ModuleDef* sub_def = module_def->add_submodules();
|
||||
convertModule(*elem.module, module_name.str(), elem.name, sub_def);
|
||||
convertModule(*elem, module_name.str(), elem->name(), sub_def);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -19,8 +19,8 @@ struct ModuleAccessorValue : public SugaredValue {
|
||||
const SourceRange& loc,
|
||||
Method& m,
|
||||
const std::string& field) override {
|
||||
if (NamedModule* v = module->find_module(field)) {
|
||||
return std::make_shared<ModuleAccessorValue>(v->module);
|
||||
if (std::shared_ptr<Module> v = module->find_module(field)) {
|
||||
return std::make_shared<ModuleAccessorValue>(std::move(v));
|
||||
} else if (NamedIValue* v = module->find_parameter(field)) {
|
||||
return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
|
||||
} else if (NamedIValue* v = module->find_buffer(field)) {
|
||||
|
@ -139,7 +139,7 @@ void createTensorToParameterNameMap(
|
||||
}
|
||||
for (const auto& elem : module.get_modules()) {
|
||||
createTensorToParameterNameMap(
|
||||
*elem.module, QualifiedName::create(prefix, elem.name), result);
|
||||
*elem, QualifiedName::create(prefix, elem->name()), result);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -324,8 +324,8 @@ struct ModuleValue : public SugaredValue {
|
||||
return std::make_shared<SimpleValue>(the_bool);
|
||||
}
|
||||
|
||||
if (NamedModule* v = module->find_module(field)) {
|
||||
return std::make_shared<ModuleValue>(v->module);
|
||||
if (std::shared_ptr<Module> v = module->find_module(field)) {
|
||||
return std::make_shared<ModuleValue>(v);
|
||||
} else if (Method* v = module->find_method(field)) {
|
||||
return std::make_shared<MethodValue>(shared_from_this(), *v);
|
||||
} else if (NamedIValue* v = module->find_parameter(field)) {
|
||||
@ -614,7 +614,7 @@ static void gatherParametersAndBuffers(
|
||||
}
|
||||
}
|
||||
for (const auto& sub : m.get_modules()) {
|
||||
gatherParametersAndBuffers(values, *sub.module);
|
||||
gatherParametersAndBuffers(values, *sub);
|
||||
}
|
||||
}
|
||||
|
||||
@ -771,7 +771,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||
py::tuple result(modules.size());
|
||||
for (size_t i = 0; i < modules.size(); ++i) {
|
||||
auto& item = modules[i];
|
||||
result[i] = std::make_pair(item.name, item.module);
|
||||
result[i] = std::make_pair(item->name(), item);
|
||||
}
|
||||
return result;
|
||||
})
|
||||
|
@ -123,7 +123,7 @@ void Module::to_impl(
|
||||
bool non_blocking) {
|
||||
// First call `to()` on every child module.
|
||||
for (auto& child : get_modules()) {
|
||||
child.module->to_impl(device, dtype, non_blocking);
|
||||
child->to_impl(device, dtype, non_blocking);
|
||||
}
|
||||
// Then convert every of our parameters.
|
||||
for (auto& parameter : get_parameters()) {
|
||||
|
@ -359,11 +359,6 @@ struct Method {
|
||||
|
||||
struct Module;
|
||||
|
||||
struct NamedModule {
|
||||
std::string name;
|
||||
std::shared_ptr<Module> module;
|
||||
};
|
||||
|
||||
struct NamedIValue {
|
||||
NamedIValue(std::string name, TypePtr type, IValue ivalue)
|
||||
: name_(name),
|
||||
@ -388,16 +383,20 @@ struct NamedIValue {
|
||||
|
||||
struct Module {
|
||||
TH_DISALLOW_COPY_AND_ASSIGN(Module);
|
||||
Module() : optimize(true) {}
|
||||
Module() : name_("__main__"), optimize_(true) {}
|
||||
|
||||
const std::string& name() const {
|
||||
return name_;
|
||||
}
|
||||
|
||||
// note this doesn't change the flags of existing methods just ones
|
||||
// added afterward.
|
||||
void set_optimized(bool o) {
|
||||
optimize = o;
|
||||
optimize_ = o;
|
||||
}
|
||||
|
||||
bool is_optimized() const {
|
||||
return optimize;
|
||||
return optimize_;
|
||||
}
|
||||
|
||||
IValue forward(std::vector<IValue> inputs) {
|
||||
@ -447,7 +446,15 @@ struct Module {
|
||||
void register_module(
|
||||
const std::string& name,
|
||||
std::shared_ptr<Module> module) {
|
||||
insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
|
||||
if (module->parent_) {
|
||||
AT_ERROR(
|
||||
"Attempting to assign submodule '",
|
||||
name,
|
||||
"' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'");
|
||||
}
|
||||
module->parent_ = this;
|
||||
module->name_ = name;
|
||||
insert(name, modules_, EntityType::MODULE, std::move(module));
|
||||
}
|
||||
|
||||
Method& create_method(
|
||||
@ -458,7 +465,7 @@ struct Module {
|
||||
std::unique_ptr<Method> method(new Method(
|
||||
this,
|
||||
name,
|
||||
optimize,
|
||||
optimize_,
|
||||
std::move(graph),
|
||||
std::move(member_inputs),
|
||||
nullptr));
|
||||
@ -471,7 +478,7 @@ struct Module {
|
||||
std::unique_ptr<Method> method(new Method(
|
||||
this,
|
||||
name,
|
||||
optimize,
|
||||
optimize_,
|
||||
std::make_shared<Graph>(),
|
||||
{},
|
||||
std::move(creator)));
|
||||
@ -505,10 +512,10 @@ struct Module {
|
||||
}
|
||||
|
||||
std::shared_ptr<Module> get_module(const std::string& name) const {
|
||||
return modules_[get_offset(name, EntityType::MODULE)].module;
|
||||
return modules_[get_offset(name, EntityType::MODULE)];
|
||||
}
|
||||
|
||||
c10::ArrayRef<NamedModule> get_modules() const {
|
||||
c10::ArrayRef<std::shared_ptr<Module>> get_modules() const {
|
||||
return modules_;
|
||||
}
|
||||
c10::ArrayRef<NamedIValue> get_parameters() const {
|
||||
@ -536,9 +543,9 @@ struct Module {
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
NamedModule* find_module(const std::string& name) {
|
||||
std::shared_ptr<Module> find_module(const std::string& name) {
|
||||
auto offset = find_offset(name, EntityType::MODULE);
|
||||
return offset ? &modules_[*offset] : nullptr;
|
||||
return offset ? modules_[*offset] : nullptr;
|
||||
}
|
||||
Method* find_method(const std::string& name) {
|
||||
auto offset = find_offset(name, EntityType::METHOD);
|
||||
@ -546,14 +553,14 @@ struct Module {
|
||||
}
|
||||
void apply(std::function<void(Module&)> fn) {
|
||||
for (auto& submod : get_modules()) {
|
||||
submod.module->apply(fn);
|
||||
submod->apply(fn);
|
||||
}
|
||||
fn(*this);
|
||||
}
|
||||
/// Enables "training" mode.
|
||||
void train(bool on = true) {
|
||||
for (auto& submod : get_modules()) {
|
||||
submod.module->train(on);
|
||||
submod->train(on);
|
||||
}
|
||||
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
|
||||
}
|
||||
@ -646,10 +653,10 @@ struct Module {
|
||||
parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
|
||||
}
|
||||
for (auto& mod : get_modules()) {
|
||||
names.push_back(mod.name);
|
||||
names.push_back(mod->name());
|
||||
// Submodules must be translated first, otherwise parameter_remap entries
|
||||
// will not be filled in for methods of this module.
|
||||
mod.module->copy_into(module_lookup, parameter_remap, names);
|
||||
mod->copy_into(module_lookup, parameter_remap, names);
|
||||
names.pop_back();
|
||||
}
|
||||
for (auto& method : get_methods()) {
|
||||
@ -677,20 +684,6 @@ struct Module {
|
||||
const c10::optional<at::ScalarType>& dtype,
|
||||
bool non_blocking);
|
||||
|
||||
// modules have a single namespace, but spread over 4 different concepts:
|
||||
// parameters, attributes, methods, and sub-modules
|
||||
// we store individual lists of each concept, and a single map to
|
||||
// unify the namespace and ensure fast lookup
|
||||
|
||||
// invariant: to ensure initial_ivalues of Methods stay valid,
|
||||
// it is only legal to _add_ new modules and parameters.
|
||||
// removing them will allow initial_ivalues to point to invalid parameters
|
||||
// no such restriction exists for methods
|
||||
std::vector<NamedModule> modules_;
|
||||
std::vector<NamedIValue> parameters_;
|
||||
std::vector<NamedIValue> attributes_;
|
||||
std::vector<std::unique_ptr<Method>> methods_;
|
||||
|
||||
static const char* toString(EntityType t) {
|
||||
switch (t) {
|
||||
case EntityType::MODULE:
|
||||
@ -762,9 +755,26 @@ struct Module {
|
||||
return list.back();
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, Entry> dict_;
|
||||
// modules have a single namespace, but spread over 4 different concepts:
|
||||
// parameters, attributes, methods, and sub-modules
|
||||
// we store individual lists of each concept, and a single map to
|
||||
// unify the namespace and ensure fast lookup
|
||||
|
||||
bool optimize;
|
||||
// invariant: to ensure initial_ivalues of Methods stay valid,
|
||||
// it is only legal to _add_ new modules and parameters.
|
||||
// removing them will allow initial_ivalues to point to invalid parameters
|
||||
// no such restriction exists for methods
|
||||
std::vector<std::shared_ptr<Module>> modules_;
|
||||
std::vector<NamedIValue> parameters_;
|
||||
std::vector<NamedIValue> attributes_;
|
||||
std::vector<std::unique_ptr<Method>> methods_;
|
||||
|
||||
std::unordered_map<std::string, Entry> dict_;
|
||||
std::string name_;
|
||||
|
||||
// back reference to parent of this Module if present
|
||||
Module* parent_ = nullptr;
|
||||
bool optimize_;
|
||||
};
|
||||
|
||||
// returns nullptr and fills in failure_messages if the callee does not
|
||||
|
Reference in New Issue
Block a user