Unify namespace of script::Module (#18378)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18378
ghimport-source-id: 55c29bb436a2153d29ff2f4488d99d8863c187b1

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18379 Enforce single parent for script submodules
* **#18378 Unify namespace of script::Module**
* #18314 Add ability to specialize class types to ArgumentSpec
* #18226 Add Slot type to abstract the raw pointers being used for slots.

This removes the individual OrderedDicts in favor of a single unified
namespace for everything in a script::Module. It removes a whole
class of bugs where, for instance, a method and a parameter could end
up with the same name.

Since we no longer have to expose OrderedDict::Item objects, a lot of
downstream code can be simplified.

We also no longer double-store names (both in the key of the dictionary
and in the object itself).
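
The layout introduced in module.h (further down in this diff) keeps one map from
name to (entity kind, offset into a per-kind vector), so registering a second
entity under an existing name fails regardless of kind. A rough standalone sketch
of that pattern follows; the names and types here are illustrative only, not the
actual script::Module API:

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Sketch of the single-namespace scheme: one map owns all names,
// per-kind vectors own the entries themselves.
enum class EntityType { Parameter, Method };

struct Namespace {
  struct Entry {
    EntityType type;
    size_t offset;  // index into the per-kind vector
  };

  void add_parameter(const std::string& name, double value) {
    insert(name, EntityType::Parameter, parameters_, value);
  }
  void add_method(const std::string& name, std::string body) {
    insert(name, EntityType::Method, methods_, std::move(body));
  }

 private:
  template <typename T>
  void insert(const std::string& name, EntityType type, std::vector<T>& list, T value) {
    // A single dict guards every kind, so "forward" cannot be both
    // a method and a parameter.
    if (dict_.count(name)) {
      throw std::runtime_error("'" + name + "' is already defined");
    }
    dict_[name] = Entry{type, list.size()};
    list.push_back(std::move(value));
  }

  std::unordered_map<std::string, Entry> dict_;
  std::vector<double> parameters_;
  std::vector<std::string> methods_;
};

int main() {
  Namespace m;
  m.add_parameter("weight", 1.0);
  m.add_method("forward", "return self.weight");
  try {
    m.add_parameter("forward", 0.0);  // clashes with the method above
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";  // 'forward' is already defined
  }
}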

Differential Revision: D14603723

fbshipit-source-id: b5f7551b3074679623edd6ea70269830353b4d4c
Author: Zachary DeVito
Date: 2019-04-03 15:58:08 -07:00
Committed by: Facebook Github Bot
Parent: 773ce4fbd0
Commit: 0512e4e323
10 changed files with 234 additions and 142 deletions

View File

@ -15,10 +15,10 @@ void check_all_parameters(
const torch::jit::script::Module& module,
Predicate predicate) {
for (const auto& parameter : module.get_parameters()) {
AT_ASSERT(predicate(parameter->slot()->toTensor()));
AT_ASSERT(predicate(parameter.slot()->toTensor()));
}
for (const auto& child : module.get_modules()) {
check_all_parameters(*child->module, predicate);
check_all_parameters(*child.module, predicate);
}
}
} // namespace helpers

View File

@ -8330,7 +8330,7 @@ a")
''')
def test_duplicate(self):
with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
with self.assertRaisesRegex(RuntimeError, 'method \'test\' already defined'):
cu = torch.jit.CompilationUnit('''
def test():
return 1

View File

@ -454,8 +454,7 @@ void GraphEncoder::EncodeTensor(
} else {
AT_ASSERT(t.is_contiguous());
tensor_proto->set_raw_data(std::string(
static_cast<char*>(t.data_ptr()),
t.element_size() * t.numel()));
static_cast<char*>(t.data_ptr()), t.element_size() * t.numel()));
}
}
@ -665,8 +664,7 @@ void ScriptModuleSerializer::convertAndWriteTensor(
tensor_proto->set_requires_grad(tensor.requires_grad());
uint64_t record_size =
tensor.element_size() * tensor.storage().size();
uint64_t record_size = tensor.element_size() * tensor.storage().size();
auto* key = tensor.storage().unsafeGetStorageImpl();
auto storage_it = storageMap.find(key);
@ -686,8 +684,7 @@ void ScriptModuleSerializer::convertAndWriteTensor(
/* stride = */ {1})
.cpu();
AT_ASSERT(
storage_tensor.element_size() *
storage_tensor.storage().size() ==
storage_tensor.element_size() * storage_tensor.storage().size() ==
record_size);
}
std::string name = "tensors/" + std::to_string(tensor_id);
@ -733,11 +730,10 @@ void ScriptModuleSerializer::convertModule(
module_def->set_optimize(module.is_optimized());
for (const auto& elem : module.get_parameters()) {
torch::ParameterDef* param_def = module_def->add_parameters();
convertParameter(elem.value(), param_def, /*is_buffer=*/false);
convertParameter(elem, param_def, /*is_buffer=*/false);
}
for (const auto& item : module.get_attributes()) {
auto& attribute = item.value();
for (const auto& attribute : module.get_attributes()) {
// Add attribute to ModuleDef
torch::AttributeDef* attribute_def = module_def->add_attributes();
attribute_def->set_name(attribute.name());
@ -773,7 +769,7 @@ void ScriptModuleSerializer::convertModule(
for (const auto& elem : module.get_modules()) {
torch::ModuleDef* sub_def = module_def->add_submodules();
convertModule(*elem->module, module_name.str(), elem.key(), sub_def);
convertModule(*elem.module, module_name.str(), elem.name, sub_def);
}
}

View File

@ -1,9 +1,9 @@
#include <torch/csrc/jit/passes/python_print.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/attributes.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/ir_views.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/module.h>
@ -131,17 +131,15 @@ void createTensorToParameterNameMap(
const script::Module& module,
const QualifiedNamePtr& prefix,
std::unordered_map<script::Slot, QualifiedNamePtr>& result) {
for (const auto& elem : module.get_parameters()) {
const script::NamedIValue& param = elem.value();
for (const auto& param : module.get_parameters()) {
result[param.slot()] = QualifiedName::create(prefix, param.name());
}
for (const auto& elem : module.get_attributes()) {
const script::NamedIValue& param = elem.value();
for (const auto& param : module.get_attributes()) {
result[param.slot()] = QualifiedName::create(prefix, param.name());
}
for (const auto& elem : module.get_modules()) {
createTensorToParameterNameMap(
*elem->module, QualifiedName::create(prefix, elem.key()), result);
*elem.module, QualifiedName::create(prefix, elem.name), result);
}
}
@ -1122,10 +1120,12 @@ struct PythonPrintPass {
void printMethod(
script::Method& method,
bool is_class,
const std::unordered_map<script::Slot, QualifiedNamePtr>& extra_ivalue_names) {
std::vector<std::string> ivalue_names = fmap(
method.initial_ivalues(),
[&](const script::Slot& slot) { return extra_ivalue_names.at(slot)->str(); });
const std::unordered_map<script::Slot, QualifiedNamePtr>&
extra_ivalue_names) {
std::vector<std::string> ivalue_names =
fmap(method.initial_ivalues(), [&](const script::Slot& slot) {
return extra_ivalue_names.at(slot)->str();
});
const std::string& name = method.name();
Graph& graph = *method.graph();
auto defaults = fmap(
@ -1138,14 +1138,14 @@ struct PythonPrintPass {
createTensorToParameterNameMap(
module, QualifiedName::create("self"), extra_ivalue_names);
for (auto& method : module.get_methods()) {
const std::string& name = method.value()->name();
const std::string& name = method->name();
// we skip __forked_functions because they actually get inlined into their
// callers; exporting them again would lead to more code generated on each
// export
if (name.find("__forked_function") == 0) {
continue;
}
printMethod(*method.value(), /*is_class=*/false, extra_ivalue_names);
printMethod(*method, /*is_class=*/false, extra_ivalue_names);
}
}

View File

@ -1,6 +1,6 @@
#include <torch/csrc/jit/script/builtin_functions.h>
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/script/builtin_functions.h>
namespace torch {
namespace jit {
@ -67,8 +67,8 @@ struct BuiltinFunctionRegistry {
module, source, script::nativeResolver, /*self=*/c10::nullopt);
modules.push_back(module);
for (auto& method : module->get_methods()) {
builtins_by_name[Symbol::fromQualString("aten::" + method.key())]
.push_back(method->get());
builtins_by_name[Symbol::fromQualString("aten::" + method->name())]
.push_back(method.get());
}
}
void loadBuiltinFunctions() {

View File

@ -10,10 +10,9 @@ Method* ClassType::getMethod(const std::string& name) const {
}
std::vector<Method*> ClassType::methods() const {
const auto& methods = module_->get_methods();
std::vector<Method*> ret;
for (const auto& pr : methods.items()) {
ret.push_back(pr.value().get());
for (const auto& pr : module_->get_methods()) {
ret.push_back(pr.get());
}
return ret;
}

View File

@ -273,10 +273,10 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
size_t n_binders) override {
// Add all module parameters as inputs to the graph
std::vector<Value*> params;
const auto& param_list = module_->get_parameters().items();
const auto& param_list = module_->get_parameters();
for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
auto& param = *it;
params.push_back(caller.get_or_add_parameter(param->slot()));
params.push_back(caller.get_or_add_parameter(param.slot()));
}
auto list = caller.graph()->createList(TensorType::get(), params);
caller.graph()->insertNode(list);
@ -606,15 +606,15 @@ static void gatherParametersAndBuffers(
std::vector<Slot>& values,
const Module& m) {
for (auto& param : m.get_parameters()) {
values.push_back(param->slot());
values.push_back(param.slot());
}
for (auto& param : m.get_attributes()) {
if (param->type()->isSubtypeOf(TensorType::get())) {
values.push_back(param->slot());
if (param.type()->isSubtypeOf(TensorType::get())) {
values.push_back(param.slot());
}
}
for (const auto& sub : m.get_modules()) {
gatherParametersAndBuffers(values, *sub->module);
gatherParametersAndBuffers(values, *sub.module);
}
}
@ -767,38 +767,38 @@ void initJitScriptBindings(PyObject* module) {
.def(
"_get_modules",
[](Module& self) -> py::tuple {
auto& modules = self.get_modules();
auto modules = self.get_modules();
py::tuple result(modules.size());
for (size_t i = 0; i < modules.size(); ++i) {
auto& item = modules[i];
result[i] = std::make_pair(item.key(), item.value().module);
result[i] = std::make_pair(item.name, item.module);
}
return result;
})
.def(
"_get_parameters",
[](Module& self) -> py::tuple {
auto& parameters = self.get_parameters();
auto parameters = self.get_parameters();
py::tuple result(parameters.size());
for (size_t i = 0; i < parameters.size(); ++i) {
auto& p = parameters[i];
py::tuple r(2);
result[i] = std::make_tuple(
p.key(), autograd::as_variable_ref(p->slot()->toTensor()));
p.name(), autograd::as_variable_ref(p.slot()->toTensor()));
}
return result;
})
.def(
"_get_attributes",
[](Module& self) -> py::tuple {
auto& attributes = self.get_attributes();
auto attributes = self.get_attributes();
py::tuple result(attributes.size());
for (size_t i = 0; i < attributes.size(); ++i) {
auto& buffer = attributes[i];
py::tuple r(3);
IValue v = *buffer->slot();
IValue v = *buffer.slot();
result[i] = std::make_tuple(
buffer.key(), buffer->type(), toPyObject(std::move(v)));
buffer.name(), buffer.type(), toPyObject(std::move(v)));
}
return result;
})
@ -830,10 +830,9 @@ void initJitScriptBindings(PyObject* module) {
.def(
"_method_names",
[](Module& self) {
using Item =
torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
return fmap(self.get_methods(), [](const Item& item) {
return (*item)->name();
return fmap(
self.get_methods(), [](const std::unique_ptr<Method>& method) {
return method->name();
});
})
.def(
@ -976,7 +975,9 @@ void initJitScriptBindings(PyObject* module) {
.def(
"propagate_and_assign_input_and_output_shapes",
&Method::propagate_and_assign_input_and_output_shapes)
.def("initial_ivalues",[](Method& m) {
.def(
"initial_ivalues",
[](Method& m) {
std::vector<at::Tensor> tensors;
for (auto& t : m.initial_ivalues()) {
tensors.push_back(t->toTensor());
@ -996,16 +997,16 @@ void initJitScriptBindings(PyObject* module) {
&Method::debugDisableAutodiffSubgraphInlining)
.def("schema", &Method::getSchema)
.def("pretty_print_schema", &Method::pretty_print_schema)
.def("python_print", [](Method& m) {
.def(
"python_print",
[](Method& m) {
std::ostringstream oss;
std::vector<at::Tensor> constants;
std::vector<ClassTypePtr> classes;
PythonPrint(oss, m, constants, classes, true);
return std::make_pair(oss.str(), std::move(constants));
})
.def_property_readonly(
"code",
[](Method& self) {
.def_property_readonly("code", [](Method& self) {
std::ostringstream ss;
std::vector<at::Tensor> tensors;
std::vector<ClassTypePtr> classes;
@ -1127,9 +1128,10 @@ void initJitScriptBindings(PyObject* module) {
py::arg("checks_file"),
py::arg("graph"));
m.def("_logging_set_logger", [](logging::LoggerBase* logger) {
return logging::setLogger(logger);
}, py::return_value_policy::reference);
m.def(
"_logging_set_logger",
[](logging::LoggerBase* logger) { return logging::setLogger(logger); },
py::return_value_policy::reference);
py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
m, "LoggerBase");
py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
@ -1148,7 +1150,6 @@ void initJitScriptBindings(PyObject* module) {
logging::LoggerBase,
std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
.def(py::init<>());
}
} // namespace script
} // namespace jit

View File

@ -122,13 +122,13 @@ void Module::to_impl(
const c10::optional<at::ScalarType>& dtype,
bool non_blocking) {
// First call `to()` on every child module.
for (auto& child : modules) {
child->module->to_impl(device, dtype, non_blocking);
for (auto& child : get_modules()) {
child.module->to_impl(device, dtype, non_blocking);
}
// Then convert every of our parameters.
for (auto& parameter : parameters) {
for (auto& parameter : get_parameters()) {
// Need to access the `at::Tensor` as a `Variable` here.
autograd::Variable variable = parameter.value().slot()->toTensor();
autograd::Variable variable = parameter.slot()->toTensor();
at::Tensor data = variable.data();
// Use the data's original device or dtype if not supplied here.
auto new_data = data.to(

View File

@ -388,12 +388,7 @@ struct NamedIValue {
struct Module {
TH_DISALLOW_COPY_AND_ASSIGN(Module);
Module()
: modules("Module"),
parameters("Parameter"),
attributes("Attributes"),
methods("Method"),
optimize(true) {}
Module() : optimize(true) {}
// note this doesn't change the flags of existing methods just ones
// added afterward.
@ -410,12 +405,16 @@ struct Module {
}
void register_buffer(const std::string& name, autograd::Variable v) {
if (auto b = attributes.find(name)) {
if (auto b = find_attribute(name)) {
AT_ASSERT(b->type()->isSubtypeOf(TensorType::get()));
*b->slot() = v;
return;
}
attributes.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
insert(
name,
attributes_,
EntityType::ATTRIBUTE,
NamedIValue(name, TensorType::get(), std::move(v)));
}
void register_parameter(
const std::string& name,
@ -425,22 +424,30 @@ struct Module {
register_buffer(name, std::move(v));
return;
}
if (auto p = parameters.find(name)) {
if (auto p = find_parameter(name)) {
*p->slot() = v;
return;
}
parameters.insert(name, NamedIValue(name, TensorType::get(), std::move(v)));
insert(
name,
parameters_,
EntityType::PARAMETER,
NamedIValue(name, TensorType::get(), std::move(v)));
}
void register_attribute(
const std::string& name,
const TypePtr type,
IValue ivalue) {
attributes.insert(name, NamedIValue(name, type, ivalue));
insert(
name,
attributes_,
EntityType::ATTRIBUTE,
NamedIValue(name, type, ivalue));
}
void register_module(
const std::string& name,
std::shared_ptr<Module> module) {
modules.insert(name, {name, std::move(module)});
insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
}
Method& create_method(
@ -455,7 +462,7 @@ struct Module {
std::move(graph),
std::move(member_inputs),
nullptr));
return *methods.insert(name, std::move(method));
return *insert(name, methods_, EntityType::METHOD, std::move(method));
}
Method& create_method(
@ -468,11 +475,11 @@ struct Module {
std::make_shared<Graph>(),
{},
std::move(creator)));
return *methods.insert(name, std::move(method));
return *insert(name, methods_, EntityType::METHOD, std::move(method));
}
Slot parameter_slot(const std::string& name) const {
return parameters[name].slot();
return parameters_[get_offset(name, EntityType::PARAMETER)].slot();
}
void set_parameter(const std::string& name, at::Tensor v) {
@ -482,69 +489,71 @@ struct Module {
autograd::Variable get_parameter(const std::string& name) const {
return autograd::as_variable_ref(parameter_slot(name)->toTensor());
}
autograd::Variable get_buffer(const std::string& name) const {
return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor());
}
IValue get_attribute(const std::string& name) const {
return *attributes.find(name)->slot();
return *attributes_[get_offset(name, EntityType::ATTRIBUTE)].slot();
}
autograd::Variable get_buffer(const std::string& name) const {
return autograd::as_variable_ref(get_attribute(name).toTensor());
}
// each module owns its methods. The reference returned here
// is guaranteed to stay valid until this module has been destroyed
Method& get_method(const std::string& name) const {
return *methods[name];
return *methods_[get_offset(name, EntityType::METHOD)];
}
std::shared_ptr<Module> get_module(const std::string& name) const {
return modules[name].module;
return modules_[get_offset(name, EntityType::MODULE)].module;
}
const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
return modules;
c10::ArrayRef<NamedModule> get_modules() const {
return modules_;
}
const torch::OrderedDict<std::string, NamedIValue>& get_parameters() const {
return parameters;
c10::ArrayRef<NamedIValue> get_parameters() const {
return parameters_;
}
const torch::OrderedDict<std::string, NamedIValue>& get_attributes() const {
return attributes;
c10::ArrayRef<NamedIValue> get_attributes() const {
return attributes_;
}
const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
const {
return methods;
c10::ArrayRef<std::unique_ptr<Method>> get_methods() const {
return methods_;
}
NamedIValue* find_parameter(const std::string& name) {
return parameters.find(name);
auto offset = find_offset(name, EntityType::PARAMETER);
return offset ? &parameters_[*offset] : nullptr;
}
NamedIValue* find_attribute(const std::string& name) {
return attributes.find(name);
auto offset = find_offset(name, EntityType::ATTRIBUTE);
return offset ? &attributes_[*offset] : nullptr;
}
NamedIValue* find_buffer(const std::string& name) {
auto b = attributes.find(name);
if (b && b->type()->isSubtypeOf(TensorType::get())) {
return b;
auto iv = find_attribute(name);
if (iv && iv->type()->isSubtypeOf(TensorType::get())) {
return iv;
}
return nullptr;
}
NamedModule* find_module(const std::string& name) {
return modules.find(name);
auto offset = find_offset(name, EntityType::MODULE);
return offset ? &modules_[*offset] : nullptr;
}
Method* find_method(const std::string& name) {
if (auto* pm = methods.find(name)) {
return pm->get();
}
return nullptr;
auto offset = find_offset(name, EntityType::METHOD);
return offset ? methods_[*offset].get() : nullptr;
}
void apply(std::function<void(Module&)> fn) {
for (auto& submod : get_modules()) {
submod.value().module->apply(fn);
submod.module->apply(fn);
}
fn(*this);
}
/// Enables "training" mode.
void train(bool on = true) {
for (auto& submod : get_modules()) {
submod->module->train(on);
submod.module->train(on);
}
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
}
@ -622,51 +631,139 @@ struct Module {
std::unordered_map<Slot, Slot>& parameter_remap,
std::vector<std::string> names = {}) const {
auto curr = module_lookup(names);
for (auto& kv : parameters) {
for (auto& param : get_parameters()) {
curr->register_parameter(
kv.key(),
kv.value().slot()->toTensor(),
param.name(),
param.slot()->toTensor(),
/*is_buffer=*/false);
parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
parameter_remap[param.slot()] = curr->parameter_slot(param.name());
}
for (auto& kv : attributes) {
if (!kv.value().type()->isSubtypeOf(TensorType::get())) {
for (auto& attr : get_attributes()) {
if (!attr.type()->isSubtypeOf(TensorType::get())) {
continue;
}
curr->register_buffer(kv.key(), kv.value().slot()->toTensor());
parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
curr->register_buffer(attr.name(), attr.slot()->toTensor());
parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
}
for (auto& kv : modules) {
names.push_back(kv.key());
for (auto& mod : get_modules()) {
names.push_back(mod.name);
// Submodules must be translated first, otherwise parameter_remap entries
// will not be filled in for methods of this module.
kv.value().module->copy_into(module_lookup, parameter_remap, names);
mod.module->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
for (auto& kv : methods) {
for (auto& method : get_methods()) {
std::vector<Slot> initial_ivalues;
for (auto& p : kv.value()->initial_ivalues()) {
for (auto& p : method->initial_ivalues()) {
initial_ivalues.push_back(parameter_remap.at(p));
}
curr->create_method(
kv.key(), kv.value()->graph()->copy(), initial_ivalues);
method->name(), method->graph()->copy(), initial_ivalues);
}
}
enum class EntityType { MODULE, PARAMETER, ATTRIBUTE, METHOD };
at::optional<EntityType> kind_of(const std::string& name) const {
auto it = dict_.find(name);
if (it == dict_.end())
return at::nullopt;
return it->second.type;
}
private:
void to_impl(
const c10::optional<at::Device>& device,
const c10::optional<at::ScalarType>& dtype,
bool non_blocking);
// Modules have a single namespace, but it is spread over 4 different concepts:
// parameters, attributes, methods, and sub-modules.
// We store individual lists of each concept, and a single map to
// unify the namespace and ensure fast lookup.
// Invariant: to ensure initial_ivalues of Methods stay valid,
// it is only legal to _add_ new modules and parameters;
// removing them would allow initial_ivalues to point to invalid parameters.
// No such restriction exists for methods.
torch::OrderedDict<std::string, NamedModule> modules;
torch::OrderedDict<std::string, NamedIValue> parameters;
torch::OrderedDict<std::string, NamedIValue> attributes;
torch::OrderedDict<std::string, std::unique_ptr<Method>> methods;
std::vector<NamedModule> modules_;
std::vector<NamedIValue> parameters_;
std::vector<NamedIValue> attributes_;
std::vector<std::unique_ptr<Method>> methods_;
static const char* toString(EntityType t) {
switch (t) {
case EntityType::MODULE:
return "module";
case EntityType::PARAMETER:
return "parameter";
case EntityType::ATTRIBUTE:
return "attribute";
case EntityType::METHOD:
return "method";
}
return nullptr;
}
struct Entry {
EntityType type;
size_t offset;
};
size_t get_offset(const std::string& name, EntityType expected_type) const {
auto it = dict_.find(name);
if (it == dict_.end()) {
AT_ERROR(toString(expected_type), " '", name, "' is not defined.");
}
if (it->second.type != expected_type) {
AT_ERROR(
"The field '",
name,
"' is a ",
toString(it->second.type),
" but this call is"
" trying to use it as a ",
toString(expected_type));
}
return it->second.offset;
}
at::optional<size_t> find_offset(
const std::string& name,
EntityType expected_type) const {
auto it = dict_.find(name);
if (it == dict_.end() || it->second.type != expected_type) {
return at::nullopt;
}
return it->second.offset;
}
template <typename T>
T& insert(
const std::string& name,
std::vector<T>& list,
EntityType type,
T value) {
auto it = dict_.find(name);
if (it != dict_.end()) {
if (type != it->second.type) {
AT_ERROR(
"attempting to add ",
toString(type),
" '",
name,
"' but it already exists as a ",
toString(it->second.type));
} else {
AT_ERROR(toString(type), " '", name, "' already defined.");
}
}
dict_[name] = Entry{type, list.size()};
list.emplace_back(std::move(value));
return list.back();
}
std::unordered_map<std::string, Entry> dict_;
bool optimize;
};

View File

@ -929,11 +929,10 @@ bool isHelperFunction(const std::string& method_name) {
}
void loadModule(const std::shared_ptr<script::Module>& module) {
for (const auto& method_ : module->get_methods()) {
if (isHelperFunction(method_.key()))
for (const auto& method : module->get_methods()) {
if (isHelperFunction(method->name()))
continue;
const auto& method = method_.value();
GradientPair pair;
pair.forward = method->graph();