Simplify retains grad hook implementation (#92604)

How the old retains_grad hooks were implemented:
- retains_grad hooks are stored on the autograd_meta, as entries in a vector
- upon registration, a wrapper hook CppFunctionTensorPreHook is created to wrap that vector, and that wrapper hook is then registered to the grad_fn by appending it to the grad_fn's vector of retains_grad hooks
- upon an in-place operation, we set the old grad_fn's retains_grad hook entry to nullptr, so even though the old grad_fn still references the vector, that vector now contains a single nullptr; for the new grad_fn, we create a new wrapper hook around the vector on autograd_meta (which stores the single retains_grad hook); see the sketch after this list
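
To make the in-place dance above concrete, here is a minimal standalone sketch. All names below are illustrative stand-ins (the real wrapper derives from FunctionPreHook and operates on gradient lists, omitted here), and the handoff follows one plausible reading of the steps above:

    #include <functional>
    #include <memory>
    #include <vector>

    using hook_fn = std::function<void()>;  // stand-in for the real grad callback

    // Wrapper registered on the grad_fn; shares the hook vector with autograd_meta.
    struct CppFunctionTensorPreHookSketch {
      std::shared_ptr<std::vector<hook_fn>> hooks;
      void operator()() const {
        for (const auto& h : *hooks) {
          if (h) h();  // skip entries that an in-place op nulled out
        }
      }
    };

    int main() {
      // Registration: the retains_grad hook is an entry in a vector owned by
      // autograd_meta; the grad_fn gets a wrapper around that vector.
      auto meta_hooks = std::make_shared<std::vector<hook_fn>>();
      meta_hooks->push_back([] { /* stash the incoming gradient on the tensor */ });
      CppFunctionTensorPreHookSketch old_grad_fn_hook{meta_hooks};

      // In-place: null out the entry (the old grad_fn keeps referencing a vector
      // that now holds a single nullptr), then wrap a fresh vector holding the
      // hook for the new grad_fn.
      (*meta_hooks)[0] = nullptr;
      meta_hooks = std::make_shared<std::vector<hook_fn>>();
      meta_hooks->push_back([] { /* stash the incoming gradient on the tensor */ });
      CppFunctionTensorPreHookSketch new_grad_fn_hook{meta_hooks};

      old_grad_fn_hook();  // no-op: its vector contains only nullptr
      new_grad_fn_hook();  // fires the retains_grad hook
    }

The wrapper already attached to the old grad_fn keeps referencing the vector, so the entry can only be neutralized, not removed; the new implementation instead moves the hook, leaving no dead entry behind.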

The new retains_grad hook implementation:
- we store the std::function by value, and it now lives on the grad_fn rather than on the autograd_meta
- a single grad_fn can have multiple outputs, so it can potentially hold multiple retains_grad hooks; we therefore use an unordered_map keyed by output index (previously a vector)
- on an in-place operation, we remove the hook from the old grad_fn and put it on the new grad_fn (a small implication of this change is that we now need access to both the old and the new grad_fn, but this isn't a problem); see the sketch after this list
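
Below is a minimal self-contained sketch of the new per-output storage and the in-place migration. NodeSketch mirrors the add/pop API from the diff further down; move_retains_grad_hook and all type names here are hypothetical stand-ins, not the actual rebase code:

    #include <memory>
    #include <unordered_map>
    #include <utility>

    // Stand-in for torch::autograd::FunctionPreHook; only what the sketch needs.
    struct FunctionPreHook {
      virtual ~FunctionPreHook() = default;
    };

    struct NodeSketch {
      // One retains_grad hook per output, keyed by output index.
      std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;

      void add_retains_grad_hook(std::unique_ptr<FunctionPreHook>&& pre_hook,
                                 int output_idx) {
        retains_grad_hooks_[output_idx] = std::move(pre_hook);
      }

      std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) {
        auto ret = std::move(retains_grad_hooks_[output_idx]);
        retains_grad_hooks_.erase(output_idx);
        return ret;
      }
    };

    // Hypothetical helper showing the in-place migration: the hook moves from
    // the old grad_fn to the new one, so no nullptr entry is left behind.
    void move_retains_grad_hook(NodeSketch& old_fn, NodeSketch& new_fn,
                                int output_idx) {
      new_fn.add_retains_grad_hook(old_fn.pop_retains_grad_hook(output_idx),
                                   output_idx);
    }

    int main() {
      NodeSketch old_fn, new_fn;
      old_fn.add_retains_grad_hook(std::make_unique<FunctionPreHook>(), /*output_idx=*/0);
      move_retains_grad_hook(old_fn, new_fn, /*output_idx=*/0);
    }

Keying by output index is what lets a multi-output grad_fn hold several retains_grad hooks at once and hand exactly one of them off during an in-place rebase.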

Other details:
- CppFunctionTensorPreHook took a shared_ptr to a vector of std::function. In the new implementation, we add a new wrapper hook, CppFunctionSingleTensorPreHook, which takes a single std::function; a rough sketch of the contrast follows below
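
The shapes of the two wrappers, as simplified stand-ins (the real hooks derive from FunctionPreHook and operate on gradient lists):

    #include <functional>
    #include <memory>
    #include <vector>

    using hook_fn = std::function<void()>;  // simplified grad callback

    // Old wrapper: holds a shared_ptr to a vector of std::function, shared
    // with autograd_meta.
    struct CppFunctionTensorPreHookSketch {
      std::shared_ptr<std::vector<hook_fn>> hooks_;
    };

    // New wrapper: owns exactly one std::function by value; registered per
    // output index on the grad_fn.
    struct CppFunctionSingleTensorPreHookSketch {
      hook_fn hook_;
    };

Storing the std::function by value removes the shared vector, which the old design needed only so that autograd_meta could neutralize the hook that the grad_fn's wrapper would run.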

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92604
Approved by: https://github.com/albanD
Author: soulitzer
Date: 2023-01-20 22:16:54 -05:00
Committed by: PyTorch MergeBot
Commit: a112814a7f (parent: 71b1051230)
8 changed files with 81 additions and 70 deletions


@@ -490,8 +490,16 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     tensor_pre_hooks_.push_back(std::move(pre_hook));
   }
-  void add_retains_grad_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
-    retains_grad_hooks_.push_back(std::move(pre_hook));
+  void add_retains_grad_hook(
+      std::unique_ptr<FunctionPreHook>&& pre_hook,
+      int output_idx) {
+    retains_grad_hooks_[output_idx] = std::move(pre_hook);
+  }
+  std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) {
+    auto ret = std::move(retains_grad_hooks_[output_idx]);
+    retains_grad_hooks_.erase(output_idx);
+    return ret;
   }
   const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks()
@@ -508,7 +516,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     return tensor_pre_hooks_;
   }
-  std::vector<std::unique_ptr<FunctionPreHook>>& retains_grad_hooks() noexcept {
+  std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
+  retains_grad_hooks() noexcept {
     return retains_grad_hooks_;
   }
@@ -636,7 +645,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
-  std::vector<std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
+  std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)