Simplify retains grad hook implementation (#92604)

How the old retains_grad hooks were implemented:
- retains_grad hooks are stored on the autograd_meta, as entries in a vector
- upon registration, a wrapper hook CppFunctionTensorPreHook is created to wrap that vector, and that wrapper hook is then registered to the grad_fn, i.e. appended to the grad_fn's vector of retains_grad hooks
- upon an in-place operation, we set the entry in that vector to nullptr for the old grad_fn, so even though the old grad_fn still references the vector, the vector now contains a single nullptr; for the new grad_fn, we create a fresh vector on autograd_meta (holding the single retains_grad hook) and wrap it in a new wrapper hook (a sketch of this old scheme follows below)
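
For illustration, a minimal self-contained sketch of the old scheme. The type and member names echo the real ones, but the definitions and the helper functions (register_retains_grad, on_inplace) are simplified stand-ins; the real CppFunctionTensorPreHook derives from FunctionPreHook and implements operator(). Treat this as a picture of the ownership structure, not the actual implementation:

```cpp
#include <functional>
#include <memory>
#include <vector>

// Simplified stand-ins for at::TensorBase and the autograd hook types.
struct Tensor {};
using hook_fn = std::function<Tensor(const Tensor&)>;
using hooks_list = std::vector<hook_fn>;

// Old wrapper hook: shares ownership of the hook vector with autograd_meta.
struct CppFunctionTensorPreHook {
  std::shared_ptr<hooks_list> hooks_;  // same vector the autograd_meta points to
  int value_idx_;
};

struct AutogradMeta {
  std::shared_ptr<hooks_list> retains_grad_hooks_list_;
};

struct Node {  // i.e. a grad_fn
  std::vector<std::unique_ptr<CppFunctionTensorPreHook>> retains_grad_hooks_;
};

// Registration: store the hook in a vector on autograd_meta, wrap that vector,
// and append the wrapper to the grad_fn.
void register_retains_grad(AutogradMeta& meta, Node& grad_fn, hook_fn hook, int output_nr) {
  meta.retains_grad_hooks_list_ = std::make_shared<hooks_list>();
  meta.retains_grad_hooks_list_->push_back(std::move(hook));
  grad_fn.retains_grad_hooks_.push_back(std::unique_ptr<CppFunctionTensorPreHook>(
      new CppFunctionTensorPreHook{meta.retains_grad_hooks_list_, output_nr}));
}

// In-place: the old grad_fn keeps its wrapper, but the shared vector it sees is
// left holding a single nullptr; a fresh vector (with the real hook) replaces it
// on autograd_meta and is wrapped again for the new grad_fn.
void on_inplace(AutogradMeta& meta, Node& new_grad_fn, int output_nr) {
  auto new_list = std::make_shared<hooks_list>();
  new_list->push_back(std::move((*meta.retains_grad_hooks_list_)[0]));
  (*meta.retains_grad_hooks_list_)[0] = nullptr;
  meta.retains_grad_hooks_list_ = new_list;
  new_grad_fn.retains_grad_hooks_.push_back(std::unique_ptr<CppFunctionTensorPreHook>(
      new CppFunctionTensorPreHook{new_list, output_nr}));
}
```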

The new retains_grad hook implementation:
- we store the std::function by value, and we store it on the grad_fn rather than on the autograd_meta
- a single grad_fn can have multiple outputs, so it can potentially hold multiple retains_grad hooks; we therefore use an unordered_map keyed by output index (previously a vector)
- on in-place, we remove the hook from the old grad_fn and move it to the new grad_fn (a small implication of this change is that we now need access to both the old and the new grad_fn, which isn't a problem; see the sketch below)
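
A corresponding sketch of the new scheme, using the same simplified stand-in types; move_retains_grad_on_inplace is an illustrative helper, not a real function. One simplification to note: in the actual change the map value is a std::unique_ptr<FunctionPreHook> (a CppFunctionSingleTensorPreHook wrapping the std::function), whereas here the std::function is stored directly:

```cpp
#include <functional>
#include <unordered_map>
#include <utility>

struct Tensor {};
using hook_fn = std::function<Tensor(const Tensor&)>;

// New scheme (simplified): the grad_fn owns its retains_grad hooks directly,
// keyed by output index, since one grad_fn can have several outputs.
struct Node {
  std::unordered_map<int, hook_fn> retains_grad_hooks_;

  void add_retains_grad_hook(hook_fn hook, int output_idx) {
    retains_grad_hooks_[output_idx] = std::move(hook);
  }

  hook_fn pop_retains_grad_hook(int output_idx) {
    auto hook = std::move(retains_grad_hooks_[output_idx]);
    retains_grad_hooks_.erase(output_idx);
    return hook;
  }
};

// On an in-place op the hook simply moves from the old grad_fn to the new one,
// which is why both grad_fns must now be available at the call site.
void move_retains_grad_on_inplace(Node& old_fn, Node& new_fn, int output_nr) {
  new_fn.add_retains_grad_hook(old_fn.pop_retains_grad_hook(output_nr), output_nr);
}
```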

Other details:
- CppFunctionTensorPreHook took a shared_ptr to a vector of std::function. In the new implementation, we add a new wrapper hook, CppFunctionSingleTensorPreHook, which takes a single std::function by value.
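
For context, the std::function being wrapped is the retain_grad_hook lambda registered by Tensor::retain_grad(); the user-visible behavior is unchanged by this PR. A minimal libtorch usage example (assumes a working libtorch build; not part of this change):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::rand({2, 2}, torch::requires_grad());
  auto b = a * 2;   // non-leaf tensor: its grad is normally not retained
  b.retain_grad();  // registers the retains_grad hook on b's grad_fn
  b.sum().backward();
  std::cout << b.grad() << "\n";  // populated because of the retained-grad hook
  return 0;
}
```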

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92604
Approved by: https://github.com/albanD
soulitzer
2023-01-20 22:16:54 -05:00
committed by PyTorch MergeBot
parent 71b1051230
commit a112814a7f
8 changed files with 81 additions and 70 deletions

View File

@@ -48,5 +48,22 @@ variable_list CppFunctionTensorPreHook::operator()(
return results;
}
// NOLINTNEXTLINE(modernize-pass-by-value)
CppFunctionSingleTensorPreHook::CppFunctionSingleTensorPreHook(
std::function<at::TensorBase(const at::TensorBase&)> hook,
int value_idx)
: hook_(hook), value_idx_(value_idx) {}
variable_list CppFunctionSingleTensorPreHook::operator()(
const variable_list& values) {
auto value = values[value_idx_];
auto res = hook_(value);
TORCH_INTERNAL_ASSERT(
!res.defined(),
"CppFunctionSingleTensorPreHook currently only supports hooks that don't return");
variable_list results(values);
return results;
}
} // namespace autograd
} // namespace torch

View File

@@ -19,5 +19,15 @@ struct CppFunctionTensorPreHook : public FunctionPreHook {
int value_idx_;
};
struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
CppFunctionSingleTensorPreHook(
std::function<at::TensorBase(const at::TensorBase&)> hook,
int value_idx);
variable_list operator()(const variable_list& values) override;
std::function<at::TensorBase(const at::TensorBase&)> hook_;
int value_idx_;
};
} // namespace autograd
} // namespace torch

View File

@@ -757,8 +757,8 @@ static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) {
for (const auto& hook : fn.tensor_pre_hooks()) {
inputs = (*hook)(inputs);
}
for (const auto& hook : fn.retains_grad_hooks()) {
inputs = (*hook)(inputs);
for (const auto& pair : fn.retains_grad_hooks()) {
inputs = (*pair.second)(inputs);
}
return inputs;
}

View File

@@ -490,8 +490,16 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
tensor_pre_hooks_.push_back(std::move(pre_hook));
}
void add_retains_grad_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
retains_grad_hooks_.push_back(std::move(pre_hook));
void add_retains_grad_hook(
std::unique_ptr<FunctionPreHook>&& pre_hook,
int output_idx) {
retains_grad_hooks_[output_idx] = std::move(pre_hook);
}
std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) {
auto ret = std::move(retains_grad_hooks_[output_idx]);
retains_grad_hooks_.erase(output_idx);
return ret;
}
const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks()
@@ -508,7 +516,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
return tensor_pre_hooks_;
}
std::vector<std::unique_ptr<FunctionPreHook>>& retains_grad_hooks() noexcept {
std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
retains_grad_hooks() noexcept {
return retains_grad_hooks_;
}
@@ -636,7 +645,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)

View File

@@ -86,8 +86,9 @@ int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
// In theory this shouldn't be necessary, because retains_grad_hooks should
// not contain any PyFunctionTensorPreHooks. The alternative is to have a
// check that actually guarantees this.
for (const auto& hook : fn.retains_grad_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
for (const auto& pair : fn.retains_grad_hooks()) {
if (auto pyhook =
dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
Py_VISIT(pyhook->dict);
}
}

View File

@@ -219,8 +219,9 @@ static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
}
}
// See NOTE [retains_grad_hook PyObject traversal]
for (const auto& hook : cdata->retains_grad_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
for (const auto& pair : cdata->retains_grad_hooks()) {
if (auto pyhook =
dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
Py_VISIT(pyhook->dict);
}
}

View File

@@ -158,42 +158,41 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
void update_tensor_hooks_on_new_gradfn(
const at::TensorBase& self,
const std::shared_ptr<torch::autograd::Node>& old_fn,
const std::shared_ptr<torch::autograd::Node>& new_fn) {
// This function is called whenever the grad_fn of the tensor is
// changed. We assume here that new_fn does not yet have hooks of
// its own
// its own.
//
// This function does two things:
const auto& meta = impl::get_autograd_meta(self);
TORCH_INTERNAL_ASSERT(meta);
TORCH_INTERNAL_ASSERT(new_fn);
// (1) reset the list when grad_fn is updated, so new hooks don't
// get erroneously registered to the old grad_fn.
// Note that the old cpp_hooks_list_ is still kept alive by the
// old grad_fn so hooks registered to the older version of the tensor
// will continue to be active.
// (2) If there is a retains_grad hook registered, move that from the
// old cpp_hooks_list_ to the new one
const auto& meta = impl::get_autograd_meta(self);
TORCH_INTERNAL_ASSERT(meta);
TORCH_INTERNAL_ASSERT(new_fn);
meta->cpp_hooks_list_ = nullptr;
const c10::impl::PyInterpreter* interp =
self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
if (interp) {
(*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
}
// (2) If there is a retains_grad hook registered, move that from the
// old cpp_hooks_list_ to the new one
if (self.retains_grad()) {
auto new_list = std::make_shared<hooks_list>();
new_list->push_back(std::move((*meta->retains_grad_hooks_list_)[0]));
(*meta->retains_grad_hooks_list_)[0] = nullptr;
meta->retains_grad_hooks_list_ = new_list;
std::unique_ptr<FunctionPreHook> hook_ptr =
std::make_unique<CppFunctionTensorPreHook>(
meta->retains_grad_hooks_list_, self.output_nr());
new_fn->add_retains_grad_hook(std::move(hook_ptr));
TORCH_INTERNAL_ASSERT(old_fn);
auto out = old_fn->pop_retains_grad_hook(self.output_nr());
TORCH_INTERNAL_ASSERT(out != nullptr);
new_fn->add_retains_grad_hook(std::move(out), self.output_nr());
}
}
void rebase_history(const Variable& self, Edge gradient_edge) {
TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
const auto& meta = impl::get_autograd_meta(self);
auto old_fn = meta != nullptr ? meta->grad_fn_ : nullptr;
auto diff_view_meta = get_view_autograd_meta(self);
if (diff_view_meta && diff_view_meta->has_bw_view()) {
// See NOTE [ View + Inplace detection ]
@@ -221,21 +220,11 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
set_gradient_edge(self, std::move(gradient_edge));
// Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
self, self.grad_fn());
self, old_fn, self.grad_fn());
}
void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
const auto& fn = self.grad_fn();
if (is_retains_grad_hook) {
std::shared_ptr<hooks_list>& list =
materialize_autograd_meta(self)->retains_grad_hooks_list_;
// NOLINTNEXTLINE(modernize-make-shared)
list.reset(new hooks_list());
std::unique_ptr<FunctionPreHook> hook_ptr{
new CppFunctionTensorPreHook(list, self.output_nr())};
TORCH_INTERNAL_ASSERT(fn, "Expect grad_fn to be defined for retains_grad");
fn->add_retains_grad_hook(std::move(hook_ptr));
} else {
std::shared_ptr<hooks_list>& list =
materialize_autograd_meta(self)->cpp_hooks_list_;
// NOLINTNEXTLINE(modernize-make-shared)
@@ -251,7 +240,6 @@ void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
fn->add_tensor_pre_hook(std::move(hook_ptr));
}
}
}
void set_grad_accumulator(
const Variable& self,
@@ -529,24 +517,6 @@ int64_t VariableHooks::_version(const at::TensorBase& self) const {
return self.unsafeGetTensorImpl()->version_counter().current_version();
}
unsigned register_retains_grad_hook(
const at::TensorBase& self,
std::function<at::TensorBase(const at::TensorBase&)> hook) {
TORCH_CHECK(
self.requires_grad(),
"cannot retain grad on a variable that "
"doesn't require gradient");
// NB: materialize_autograd_meta unnecessary due to requires grad check
auto& list =
torch::autograd::impl::get_autograd_meta(self)->retains_grad_hooks_list_;
if (!list) {
torch::autograd::impl::create_cpp_hook(self, /*is_retains_grad_hook=*/true);
}
unsigned idx = list->size();
list->push_back(hook);
return idx;
}
void VariableHooks::retain_grad(const at::TensorBase& self) const {
TORCH_CHECK(
self.requires_grad(),
@@ -583,7 +553,10 @@ void VariableHooks::retain_grad(const at::TensorBase& self) const {
return at::TensorBase{};
};
register_retains_grad_hook(self, retain_grad_hook);
const auto& fn = self.grad_fn();
std::unique_ptr<FunctionPreHook> hook_ptr{new CppFunctionSingleTensorPreHook(
std::move(retain_grad_hook), self.output_nr())};
fn->add_retains_grad_hook(std::move(hook_ptr), self.output_nr());
impl::get_autograd_meta(self)->retains_grad_ = true;
}
@@ -674,6 +647,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
return diff_view_meta->grad_fn_;
}
auto current_version = self._version();
auto old_fn = diff_view_meta->grad_fn_;
if (diff_view_meta->get_attr_version() != current_version) {
// This is an indirect rebase_history due to another view or the base
// being modified inplace
@@ -735,7 +709,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
diff_view_meta->set_attr_version(current_version);
torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
self, diff_view_meta->grad_fn_);
self, old_fn, diff_view_meta->grad_fn_);
}
return diff_view_meta->grad_fn_;
}

View File

@@ -229,7 +229,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
// each other, so using both is not defined behavior.
std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
std::shared_ptr<hooks_list> cpp_hooks_list_;
std::shared_ptr<hooks_list> retains_grad_hooks_list_;
// Only meaningful on leaf variables (must be false otherwise)
bool requires_grad_;