Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Unify namespace of script::Module (#18378)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18378
ghimport-source-id: 55c29bb436a2153d29ff2f4488d99d8863c187b1

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18379 Enforce single parent for script submodules
* **#18378 Unify namespace of script::Module**
* #18314 Add ability to specialize class types to ArgumentSpec
* #18226 Add Slot type to abstract the raw pointers being used for slots.

This replaces the individual OrderedDicts with a single unified namespace for everything in a script::Module. It removes a whole class of bugs where, for instance, a method and a parameter could end up with the same name. Since we no longer have to expose OrderedDict::Item objects, a lot of downstream code can be simplified. We also no longer double-store names (both in the key of the dictionary and in the object itself).

Differential Revision: D14603723

fbshipit-source-id: b5f7551b3074679623edd6ea70269830353b4d4c
Committed by: Facebook Github Bot
Parent: 773ce4fbd0
Commit: 0512e4e323
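Before the diff, here is a minimal, self-contained sketch of the single-namespace bookkeeping the change introduces in script::Module. `EntityType`, `Entry`, the per-kind vectors, and the `insert()` helper follow the module.h hunk below; the `Namespace` class, the `std::string` payloads, and `main()` are illustrative stand-ins only, not the real torch::jit::script types.

```cpp
// Sketch: one name -> Entry map shared by all kinds of members, plus one
// flat vector per kind. A duplicate name of *any* kind is rejected up front,
// which is the class of bugs the unified namespace removes.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD };

struct Entry {
  EntityType type;  // which per-kind list owns this name
  size_t offset;    // index into that list
};

class Namespace {
 public:
  // Single insert path for every kind of member.
  template <typename T>
  T& insert(const std::string& name, std::vector<T>& list, EntityType type, T value) {
    auto it = dict_.find(name);
    if (it != dict_.end()) {
      // A parameter and a method can no longer silently share a name:
      // any existing entry, of any kind, is an error.
      throw std::runtime_error("'" + name + "' already defined");
    }
    dict_[name] = Entry{type, list.size()};
    list.emplace_back(std::move(value));
    return list.back();
  }

  void add_parameter(const std::string& name) { insert(name, parameters_, EntityType::PARAMETER, name); }
  void add_method(const std::string& name) { insert(name, methods_, EntityType::METHOD, name); }

 private:
  std::vector<std::string> parameters_;
  std::vector<std::string> methods_;
  std::unordered_map<std::string, Entry> dict_;  // the unified namespace
};

int main() {
  Namespace ns;
  ns.add_parameter("weight");
  try {
    ns.add_method("weight");  // previously two separate OrderedDicts would both accept this
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";  // prints: 'weight' already defined
  }
}
```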
@@ -15,10 +15,10 @@ void check_all_parameters(
     const torch::jit::script::Module& module,
     Predicate predicate) {
   for (const auto& parameter : module.get_parameters()) {
-    AT_ASSERT(predicate(parameter->slot()->toTensor()));
+    AT_ASSERT(predicate(parameter.slot()->toTensor()));
   }
   for (const auto& child : module.get_modules()) {
-    check_all_parameters(*child->module, predicate);
+    check_all_parameters(*child.module, predicate);
   }
 }
 } // namespace helpers
@@ -8330,7 +8330,7 @@ a")
            ''')

    def test_duplicate(self):
-        with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
+        with self.assertRaisesRegex(RuntimeError, 'method \'test\' already defined'):
            cu = torch.jit.CompilationUnit('''
            def test():
                return 1
@@ -454,8 +454,7 @@ void GraphEncoder::EncodeTensor(
   } else {
     AT_ASSERT(t.is_contiguous());
     tensor_proto->set_raw_data(std::string(
-        static_cast<char*>(t.data_ptr()),
-        t.element_size() * t.numel()));
+        static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
   }
 }

@@ -665,8 +664,7 @@ void ScriptModuleSerializer::convertAndWriteTensor(

   tensor_proto->set_requires_grad(tensor.requires_grad());

-  uint64_t record_size =
-      tensor.element_size() * tensor.storage().size();
+  uint64_t record_size = tensor.element_size() * tensor.storage().size();
   auto* key = tensor.storage().unsafeGetStorageImpl();

   auto storage_it = storageMap.find(key);
@@ -686,8 +684,7 @@ void ScriptModuleSerializer::convertAndWriteTensor(
             /* stride = */ {1})
             .cpu();
     AT_ASSERT(
-        storage_tensor.element_size() *
-            storage_tensor.storage().size() ==
+        storage_tensor.element_size() * storage_tensor.storage().size() ==
         record_size);
   }
   std::string name = "tensors/" + std::to_string(tensor_id);
@@ -733,11 +730,10 @@ void ScriptModuleSerializer::convertModule(
   module_def->set_optimize(module.is_optimized());
   for (const auto& elem : module.get_parameters()) {
     torch::ParameterDef* param_def = module_def->add_parameters();
-    convertParameter(elem.value(), param_def, /*is_buffer=*/false);
+    convertParameter(elem, param_def, /*is_buffer=*/false);
   }

-  for (const auto& item : module.get_attributes()) {
-    auto& attribute = item.value();
+  for (const auto& attribute : module.get_attributes()) {
     // Add attribute to ModuleDef
     torch::AttributeDef* attribute_def = module_def->add_attributes();
     attribute_def->set_name(attribute.name());
@@ -773,7 +769,7 @@ void ScriptModuleSerializer::convertModule(

   for (const auto& elem : module.get_modules()) {
     torch::ModuleDef* sub_def = module_def->add_submodules();
-    convertModule(*elem->module, module_name.str(), elem.key(), sub_def);
+    convertModule(*elem.module, module_name.str(), elem.name, sub_def);
   }
 }

@@ -1,9 +1,9 @@
+#include <torch/csrc/jit/passes/python_print.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/attributes.h>
 #include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/ir_views.h>
-#include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/script/module.h>
@@ -131,17 +131,15 @@ void createTensorToParameterNameMap(
     const script::Module& module,
     const QualifiedNamePtr& prefix,
     std::unordered_map<script::Slot, QualifiedNamePtr>& result) {
-  for (const auto& elem : module.get_parameters()) {
-    const script::NamedIValue& param = elem.value();
+  for (const auto& param : module.get_parameters()) {
     result[param.slot()] = QualifiedName::create(prefix, param.name());
   }
-  for (const auto& elem : module.get_attributes()) {
-    const script::NamedIValue& param = elem.value();
+  for (const auto& param : module.get_attributes()) {
     result[param.slot()] = QualifiedName::create(prefix, param.name());
   }
   for (const auto& elem : module.get_modules()) {
     createTensorToParameterNameMap(
-        *elem->module, QualifiedName::create(prefix, elem.key()), result);
+        *elem.module, QualifiedName::create(prefix, elem.name), result);
   }
 }

@@ -1122,10 +1120,12 @@ struct PythonPrintPass {
   void printMethod(
       script::Method& method,
       bool is_class,
-      const std::unordered_map<script::Slot, QualifiedNamePtr>& extra_ivalue_names) {
-    std::vector<std::string> ivalue_names = fmap(
-        method.initial_ivalues(),
-        [&](const script::Slot& slot) { return extra_ivalue_names.at(slot)->str(); });
+      const std::unordered_map<script::Slot, QualifiedNamePtr>&
+          extra_ivalue_names) {
+    std::vector<std::string> ivalue_names =
+        fmap(method.initial_ivalues(), [&](const script::Slot& slot) {
+          return extra_ivalue_names.at(slot)->str();
+        });
     const std::string& name = method.name();
     Graph& graph = *method.graph();
     auto defaults = fmap(
@@ -1138,14 +1138,14 @@ struct PythonPrintPass {
     createTensorToParameterNameMap(
         module, QualifiedName::create("self"), extra_ivalue_names);
     for (auto& method : module.get_methods()) {
-      const std::string& name = method.value()->name();
+      const std::string& name = method->name();
       // we skip __forked_functions because they actually get inlined into their
       // callers, exporting them again will lead to more code generated on each
       // export
       if (name.find("__forked_function") == 0) {
         continue;
       }
-      printMethod(*method.value(), /*is_class=*/false, extra_ivalue_names);
+      printMethod(*method, /*is_class=*/false, extra_ivalue_names);
     }
   }

@@ -1,6 +1,6 @@
+#include <torch/csrc/jit/script/builtin_functions.h>
 #include <torch/csrc/api/include/torch/jit.h>
 #include <torch/csrc/jit/code_template.h>
-#include <torch/csrc/jit/script/builtin_functions.h>

 namespace torch {
 namespace jit {
@@ -67,8 +67,8 @@ struct BuiltinFunctionRegistry {
         module, source, script::nativeResolver, /*self=*/c10::nullopt);
     modules.push_back(module);
     for (auto& method : module->get_methods()) {
-      builtins_by_name[Symbol::fromQualString("aten::" + method.key())]
-          .push_back(method->get());
+      builtins_by_name[Symbol::fromQualString("aten::" + method->name())]
+          .push_back(method.get());
     }
   }
   void loadBuiltinFunctions() {
@@ -10,10 +10,9 @@ Method* ClassType::getMethod(const std::string& name) const {
 }

 std::vector<Method*> ClassType::methods() const {
-  const auto& methods = module_->get_methods();
   std::vector<Method*> ret;
-  for (const auto& pr : methods.items()) {
-    ret.push_back(pr.value().get());
+  for (const auto& pr : module_->get_methods()) {
+    ret.push_back(pr.get());
   }
   return ret;
 }
@@ -273,10 +273,10 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
       size_t n_binders) override {
     // Add all module parameters as inputs to the graph
     std::vector<Value*> params;
-    const auto& param_list = module_->get_parameters().items();
+    const auto& param_list = module_->get_parameters();
     for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
       auto& param = *it;
-      params.push_back(caller.get_or_add_parameter(param->slot()));
+      params.push_back(caller.get_or_add_parameter(param.slot()));
     }
     auto list = caller.graph()->createList(TensorType::get(), params);
     caller.graph()->insertNode(list);
@@ -606,15 +606,15 @@ static void gatherParametersAndBuffers(
     std::vector<Slot>& values,
     const Module& m) {
   for (auto& param : m.get_parameters()) {
-    values.push_back(param->slot());
+    values.push_back(param.slot());
   }
   for (auto& param : m.get_attributes()) {
-    if (param->type()->isSubtypeOf(TensorType::get())) {
-      values.push_back(param->slot());
+    if (param.type()->isSubtypeOf(TensorType::get())) {
+      values.push_back(param.slot());
     }
   }
   for (const auto& sub : m.get_modules()) {
-    gatherParametersAndBuffers(values, *sub->module);
+    gatherParametersAndBuffers(values, *sub.module);
   }
 }

@@ -767,38 +767,38 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "_get_modules",
           [](Module& self) -> py::tuple {
-            auto& modules = self.get_modules();
+            auto modules = self.get_modules();
             py::tuple result(modules.size());
             for (size_t i = 0; i < modules.size(); ++i) {
               auto& item = modules[i];
-              result[i] = std::make_pair(item.key(), item.value().module);
+              result[i] = std::make_pair(item.name, item.module);
             }
             return result;
           })
       .def(
           "_get_parameters",
           [](Module& self) -> py::tuple {
-            auto& parameters = self.get_parameters();
+            auto parameters = self.get_parameters();
             py::tuple result(parameters.size());
             for (size_t i = 0; i < parameters.size(); ++i) {
               auto& p = parameters[i];
               py::tuple r(2);
               result[i] = std::make_tuple(
-                  p.key(), autograd::as_variable_ref(p->slot()->toTensor()));
+                  p.name(), autograd::as_variable_ref(p.slot()->toTensor()));
             }
             return result;
           })
       .def(
           "_get_attributes",
           [](Module& self) -> py::tuple {
-            auto& attributes = self.get_attributes();
+            auto attributes = self.get_attributes();
             py::tuple result(attributes.size());
             for (size_t i = 0; i < attributes.size(); ++i) {
               auto& buffer = attributes[i];
               py::tuple r(3);
-              IValue v = *buffer->slot();
+              IValue v = *buffer.slot();
               result[i] = std::make_tuple(
-                  buffer.key(), buffer->type(), toPyObject(std::move(v)));
+                  buffer.name(), buffer.type(), toPyObject(std::move(v)));
             }
             return result;
           })
@@ -830,10 +830,9 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "_method_names",
           [](Module& self) {
-            using Item =
-                torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
-            return fmap(self.get_methods(), [](const Item& item) {
-              return (*item)->name();
+            return fmap(
+                self.get_methods(), [](const std::unique_ptr<Method>& method) {
+                  return method->name();
             });
           })
       .def(
@@ -976,7 +975,9 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "propagate_and_assign_input_and_output_shapes",
           &Method::propagate_and_assign_input_and_output_shapes)
-      .def("initial_ivalues",[](Method& m) {
+      .def(
+          "initial_ivalues",
+          [](Method& m) {
             std::vector<at::Tensor> tensors;
             for (auto& t : m.initial_ivalues()) {
               tensors.push_back(t->toTensor());
@@ -996,16 +997,16 @@ void initJitScriptBindings(PyObject* module) {
           &Method::debugDisableAutodiffSubgraphInlining)
       .def("schema", &Method::getSchema)
       .def("pretty_print_schema", &Method::pretty_print_schema)
-      .def("python_print", [](Method& m) {
+      .def(
+          "python_print",
+          [](Method& m) {
             std::ostringstream oss;
             std::vector<at::Tensor> constants;
             std::vector<ClassTypePtr> classes;
             PythonPrint(oss, m, constants, classes, true);
             return std::make_pair(oss.str(), std::move(constants));
           })
-      .def_property_readonly(
-          "code",
-          [](Method& self) {
+      .def_property_readonly("code", [](Method& self) {
         std::ostringstream ss;
         std::vector<at::Tensor> tensors;
         std::vector<ClassTypePtr> classes;
@@ -1127,9 +1128,10 @@ void initJitScriptBindings(PyObject* module) {
       py::arg("checks_file"),
       py::arg("graph"));

-  m.def("_logging_set_logger", [](logging::LoggerBase* logger) {
-    return logging::setLogger(logger);
-  }, py::return_value_policy::reference);
+  m.def(
+      "_logging_set_logger",
+      [](logging::LoggerBase* logger) { return logging::setLogger(logger); },
+      py::return_value_policy::reference);
   py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
       m, "LoggerBase");
   py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
@@ -1148,7 +1150,6 @@ void initJitScriptBindings(PyObject* module) {
       logging::LoggerBase,
       std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
       .def(py::init<>());
-
 }
 } // namespace script
 } // namespace jit
|
@ -122,13 +122,13 @@ void Module::to_impl(
|
||||
const c10::optional<at::ScalarType>& dtype,
|
||||
bool non_blocking) {
|
||||
// First call `to()` on every child module.
|
||||
for (auto& child : modules) {
|
||||
child->module->to_impl(device, dtype, non_blocking);
|
||||
for (auto& child : get_modules()) {
|
||||
child.module->to_impl(device, dtype, non_blocking);
|
||||
}
|
||||
// Then convert every of our parameters.
|
||||
for (auto& parameter : parameters) {
|
||||
for (auto& parameter : get_parameters()) {
|
||||
// Need to access the `at::Tensor` as a `Variable` here.
|
||||
autograd::Variable variable = parameter.value().slot()->toTensor();
|
||||
autograd::Variable variable = parameter.slot()->toTensor();
|
||||
at::Tensor data = variable.data();
|
||||
// Use the data's original device or dtype if not supplied here.
|
||||
auto new_data = data.to(
|
||||
|
@ -388,12 +388,7 @@ struct NamedIValue {
|
||||
|
||||
struct Module {
|
||||
TH_DISALLOW_COPY_AND_ASSIGN(Module);
|
||||
Module()
|
||||
: modules("Module"),
|
||||
parameters("Parameter"),
|
||||
attributes("Attributes"),
|
||||
methods("Method"),
|
||||
optimize(true) {}
|
||||
Module() : optimize(true) {}
|
||||
|
||||
// note this doesn't change the flags of existing methods just ones
|
||||
// added afterward.
|
||||
@@ -410,12 +405,16 @@ struct Module {
   }

   void register_buffer(const std::string& name, autograd::Variable v) {
-    if (auto b = attributes.find(name)) {
+    if (auto b = find_attribute(name)) {
       AT_ASSERT(b->type()->isSubtypeOf(TensorType::get()));
       *b->slot() = v;
       return;
     }
-    attributes.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
+    insert(
+        name,
+        attributes_,
+        EntityType::ATTRIBUTE,
+        NamedIValue(name, TensorType::get(), std::move(v)));
   }
   void register_parameter(
       const std::string& name,
@@ -425,22 +424,30 @@ struct Module {
       register_buffer(name, std::move(v));
       return;
     }
-    if (auto p = parameters.find(name)) {
+    if (auto p = find_parameter(name)) {
       *p->slot() = v;
       return;
     }
-    parameters.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
+    insert(
+        name,
+        parameters_,
+        EntityType::PARAMETER,
+        NamedIValue(name, TensorType::get(), std::move(v)));
   }
   void register_attribute(
       const std::string& name,
       const TypePtr type,
       IValue ivalue) {
-    attributes.insert(name, NamedIValue(name, type, ivalue));
+    insert(
+        name,
+        attributes_,
+        EntityType::ATTRIBUTE,
+        NamedIValue(name, type, ivalue));
   }
   void register_module(
       const std::string& name,
       std::shared_ptr<Module> module) {
-    modules.insert(name, {name, std::move(module)});
+    insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
   }

   Method& create_method(
@@ -455,7 +462,7 @@ struct Module {
         std::move(graph),
         std::move(member_inputs),
         nullptr));
-    return *methods.insert(name, std::move(method));
+    return *insert(name, methods_, EntityType::METHOD, std::move(method));
   }

   Method& create_method(
@@ -468,11 +475,11 @@ struct Module {
         std::make_shared<Graph>(),
         {},
         std::move(creator)));
-    return *methods.insert(name, std::move(method));
+    return *insert(name, methods_, EntityType::METHOD, std::move(method));
   }

   Slot parameter_slot(const std::string& name) const {
-    return parameters[name].slot();
+    return parameters_[get_offset(name, EntityType::PARAMETER)].slot();
   }

   void set_parameter(const std::string& name, at::Tensor v) {
@@ -482,69 +489,71 @@ struct Module {
   autograd::Variable get_parameter(const std::string& name) const {
     return autograd::as_variable_ref(parameter_slot(name)->toTensor());
   }
-  autograd::Variable get_buffer(const std::string& name) const {
-    return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor());
-  }

   IValue get_attribute(const std::string& name) const {
-    return *attributes.find(name)->slot();
+    return *attributes_[get_offset(name, EntityType::ATTRIBUTE)].slot();
   }
+
+  autograd::Variable get_buffer(const std::string& name) const {
+    return autograd::as_variable_ref(get_attribute(name).toTensor());
+  }
+
   // each module owns its method. The reference returned here
   // is guarenteed to stay valid until this module has been destroyed
   Method& get_method(const std::string& name) const {
-    return *methods[name];
+    return *methods_[get_offset(name, EntityType::METHOD)];
   }

   std::shared_ptr<Module> get_module(const std::string& name) const {
-    return modules[name].module;
+    return modules_[get_offset(name, EntityType::MODULE)].module;
   }

-  const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
-    return modules;
+  c10::ArrayRef<NamedModule> get_modules() const {
+    return modules_;
   }
-  const torch::OrderedDict<std::string, NamedIValue>& get_parameters() const {
-    return parameters;
+  c10::ArrayRef<NamedIValue> get_parameters() const {
+    return parameters_;
   }
-  const torch::OrderedDict<std::string, NamedIValue>& get_attributes() const {
-    return attributes;
+  c10::ArrayRef<NamedIValue> get_attributes() const {
+    return attributes_;
   }
-  const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
-      const {
-    return methods;
+  c10::ArrayRef<std::unique_ptr<Method>> get_methods() const {
+    return methods_;
   }

   NamedIValue* find_parameter(const std::string& name) {
-    return parameters.find(name);
+    auto offset = find_offset(name, EntityType::PARAMETER);
+    return offset ? &parameters_[*offset] : nullptr;
   }
   NamedIValue* find_attribute(const std::string& name) {
-    return attributes.find(name);
+    auto offset = find_offset(name, EntityType::ATTRIBUTE);
+    return offset ? &attributes_[*offset] : nullptr;
   }
   NamedIValue* find_buffer(const std::string& name) {
-    auto b = attributes.find(name);
-    if (b && b->type()->isSubtypeOf(TensorType::get())) {
-      return b;
+    auto iv = find_attribute(name);
+    if (iv && iv->type()->isSubtypeOf(TensorType::get())) {
+      return iv;
     }
     return nullptr;
   }
   NamedModule* find_module(const std::string& name) {
-    return modules.find(name);
+    auto offset = find_offset(name, EntityType::MODULE);
+    return offset ? &modules_[*offset] : nullptr;
   }
   Method* find_method(const std::string& name) {
-    if (auto* pm = methods.find(name)) {
-      return pm->get();
-    }
-    return nullptr;
+    auto offset = find_offset(name, EntityType::METHOD);
+    return offset ? methods_[*offset].get() : nullptr;
   }
   void apply(std::function<void(Module&)> fn) {
     for (auto& submod : get_modules()) {
-      submod.value().module->apply(fn);
+      submod.module->apply(fn);
     }
     fn(*this);
   }
   /// Enables "training" mode.
   void train(bool on = true) {
     for (auto& submod : get_modules()) {
-      submod->module->train(on);
+      submod.module->train(on);
     }
     register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
   }
@@ -622,51 +631,139 @@ struct Module {
       std::unordered_map<Slot, Slot>& parameter_remap,
       std::vector<std::string> names = {}) const {
     auto curr = module_lookup(names);
-    for (auto& kv : parameters) {
+    for (auto& param : get_parameters()) {
       curr->register_parameter(
-          kv.key(),
-          kv.value().slot()->toTensor(),
+          param.name(),
+          param.slot()->toTensor(),
           /*is_buffer=*/false);
-      parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
+      parameter_remap[param.slot()] = curr->parameter_slot(param.name());
     }
-    for (auto& kv : attributes) {
-      if (!kv.value().type()->isSubtypeOf(TensorType::get())) {
+    for (auto& attr : get_attributes()) {
+      if (!attr.type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
-      curr->register_buffer(kv.key(), kv.value().slot()->toTensor());
-      parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
+      curr->register_buffer(attr.name(), attr.slot()->toTensor());
+      parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
     }
-    for (auto& kv : modules) {
-      names.push_back(kv.key());
+    for (auto& mod : get_modules()) {
+      names.push_back(mod.name);
       // Submodules must be translated first, otherwise parameter_remap entries
       // will not be filled in for methods of this module.
-      kv.value().module->copy_into(module_lookup, parameter_remap, names);
+      mod.module->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
-    for (auto& kv : methods) {
+    for (auto& method : get_methods()) {
       std::vector<Slot> initial_ivalues;
-      for (auto& p : kv.value()->initial_ivalues()) {
+      for (auto& p : method->initial_ivalues()) {
         initial_ivalues.push_back(parameter_remap.at(p));
       }
       curr->create_method(
-          kv.key(), kv.value()->graph()->copy(), initial_ivalues);
+          method->name(), method->graph()->copy(), initial_ivalues);
     }
   }

+  enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD };
+
+  at::optional<EntityType> kind_of(const std::string& name) const {
+    auto it = dict_.find(name);
+    if (it == dict_.end())
+      return at::nullopt;
+    return it->second.type;
+  }
+
  private:
   void to_impl(
       const c10::optional<at::Device>& device,
       const c10::optional<at::ScalarType>& dtype,
       bool non_blocking);

+  // modules have a single namespace, but spread over 4 different concepts:
+  // parameters, attributes, methods, and sub-modules
+  // we store individual lists of each concept, and a single map to
+  // unify the namespace and ensure fast lookup
+
   // invariant: to ensure initial_ivalues of Methods stay valid,
   // it is only legal to _add_ new modules and parameters.
   // removing them will allow initial_ivalues to point to invalid parameters
   // no such restriction exists for methods
-  torch::OrderedDict<std::string, NamedModule> modules;
-  torch::OrderedDict<std::string, NamedIValue> parameters;
-  torch::OrderedDict<std::string, NamedIValue> attributes;
-  torch::OrderedDict<std::string, std::unique_ptr<Method>> methods;
+  std::vector<NamedModule> modules_;
+  std::vector<NamedIValue> parameters_;
+  std::vector<NamedIValue> attributes_;
+  std::vector<std::unique_ptr<Method>> methods_;
+
+  static const char* toString(EntityType t) {
+    switch (t) {
+      case EntityType::MODULE:
+        return "module";
+      case EntityType::PARAMETER:
+        return "parameter";
+      case EntityType::ATTRIBUTE:
+        return "attribute";
+      case EntityType::METHOD:
+        return "method";
+    }
+    return nullptr;
+  }
+
+  struct Entry {
+    EntityType type;
+    size_t offset;
+  };
+
+  size_t get_offset(const std::string& name, EntityType expected_type) const {
+    auto it = dict_.find(name);
+    if (it == dict_.end()) {
+      AT_ERROR(toString(expected_type), " '", name, "' is not defined.");
+    }
+    if (it->second.type != expected_type) {
+      AT_ERROR(
+          "The field '",
+          name,
+          "' is a ",
+          toString(it->second.type),
+          " but this call is"
+          " trying to use it as a ",
+          toString(expected_type));
+    }
+    return it->second.offset;
+  }
+  at::optional<size_t> find_offset(
+      const std::string& name,
+      EntityType expected_type) const {
+    auto it = dict_.find(name);
+    if (it == dict_.end() || it->second.type != expected_type) {
+      return at::nullopt;
+    }
+    return it->second.offset;
+  }
+
+  template <typename T>
+  T& insert(
+      const std::string& name,
+      std::vector<T>& list,
+      EntityType type,
+      T value) {
+    auto it = dict_.find(name);
+    if (it != dict_.end()) {
+      if (type != it->second.type) {
+        AT_ERROR(
+            "attempting to add ",
+            toString(type),
+            " '",
+            name,
+            "' but it already exists as a ",
+            toString(it->second.type));
+      } else {
+        AT_ERROR(toString(type), " '", name, "' already defined.");
+      }
+    }
+    dict_[name] = Entry{type, list.size()};
+    list.emplace_back(std::move(value));
+    return list.back();
+  }
+
+  std::unordered_map<std::string, Entry> dict_;

   bool optimize;
 };
@@ -929,11 +929,10 @@ bool isHelperFunction(const std::string& method_name) {
 }

 void loadModule(const std::shared_ptr<script::Module>& module) {
-  for (const auto& method_ : module->get_methods()) {
-    if (isHelperFunction(method_.key()))
+  for (const auto& method : module->get_methods()) {
+    if (isHelperFunction(method->name()))
       continue;

-    const auto& method = method_.value();
     GradientPair pair;
     pair.forward = method->graph();
