mirror of https://github.com/pytorch/pytorch.git
Simplify retains grad hook implementation (#92604)
How the old retains_grad hooks were implemented:
- retains_grad hooks are stored on the autograd_meta, as entries in a vector
- upon registration, a wrapper hook CppFunctionTensorPreHook is created to wrap that vector, and that wrapper hook is registered to the grad_fn, i.e., by appending it to a vector of retains_grad hooks on the grad_fn
- upon in-place, for the old grad_fn we set the retains_grad hook to nullptr, so that even though the old grad_fn still references the vector, the vector contains a single nullptr. For the new grad_fn, we create a new wrapper hook around the vector (storing the single retains_grad hook) on autograd_meta.

The new retains_grad hook implementation:
- we store the std::function by value, and we store it on the grad_fn rather than the autograd_meta
- a single grad_fn can have multiple outputs, so it can potentially hold multiple retains_grad hooks. We use an unordered_map (previously a vector).
- on in-place we remove the hook from the old grad_fn and put it in the new grad_fn (a small implication of this change is that we now need access to both the old grad_fn and the new grad_fn; this isn't a problem)

Other details:
- CppFunctionTensorPreHook took a shared_ptr to a vector of std::function. In the new implementation, we add a new wrapper hook, CppFunctionSingleTensorPreHook, which takes a single std::function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92604
Approved by: https://github.com/albanD
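As a concrete illustration of the scheme described above, here is a minimal, self-contained C++ sketch. NodeSketch and the bare FunctionPreHook base are illustrative stand-ins, not the actual PyTorch declarations; only the ownership and move-on-in-place behavior are shown.

// Minimal sketch of the new storage scheme; NodeSketch and FunctionPreHook
// are illustrative stand-ins, not the actual PyTorch classes.
#include <memory>
#include <unordered_map>
#include <utility>

struct FunctionPreHook {
  virtual ~FunctionPreHook() = default;
};

struct NodeSketch {
  // One retains_grad hook per output, keyed by output index.
  std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;

  void add_retains_grad_hook(
      std::unique_ptr<FunctionPreHook>&& hook,
      int output_idx) {
    retains_grad_hooks_[output_idx] = std::move(hook);
  }

  std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) {
    auto ret = std::move(retains_grad_hooks_[output_idx]);
    retains_grad_hooks_.erase(output_idx);
    return ret;
  }
};

// On an in-place op, the hook migrates from the old grad_fn to the new one;
// this is why both nodes must be in scope at that point.
void move_hook_on_inplace(NodeSketch& old_fn, NodeSketch& new_fn, int output_idx) {
  new_fn.add_retains_grad_hook(old_fn.pop_retains_grad_hook(output_idx), output_idx);
}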
committed by PyTorch MergeBot
parent 71b1051230
commit a112814a7f
@@ -490,8 +490,16 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     tensor_pre_hooks_.push_back(std::move(pre_hook));
   }
 
-  void add_retains_grad_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
-    retains_grad_hooks_.push_back(std::move(pre_hook));
+  void add_retains_grad_hook(
+      std::unique_ptr<FunctionPreHook>&& pre_hook,
+      int output_idx) {
+    retains_grad_hooks_[output_idx] = std::move(pre_hook);
   }
 
+  std::unique_ptr<FunctionPreHook> pop_retains_grad_hook(int output_idx) {
+    auto ret = std::move(retains_grad_hooks_[output_idx]);
+    retains_grad_hooks_.erase(output_idx);
+    return ret;
+  }
+
   const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks()
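A small aside on pop_retains_grad_hook as written in the hunk above: std::unordered_map::operator[] default-constructs a null unique_ptr when the key is absent, so popping an output index with no registered hook quietly returns nullptr, and the erase then removes the empty entry. Illustrated with the NodeSketch stand-in from the sketch earlier:

// Illustration only, reusing the NodeSketch stand-in defined earlier.
#include <cassert>

void demo_pop_missing_index() {
  NodeSketch fn;
  // No hook was registered for output 0: operator[] materializes a null
  // entry, the move yields nullptr, and erase() cleans the entry up again.
  auto hook = fn.pop_retains_grad_hook(/*output_idx=*/0);
  assert(hook == nullptr);
  assert(fn.retains_grad_hooks_.empty());
}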
@@ -508,7 +516,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
     return tensor_pre_hooks_;
   }
 
-  std::vector<std::unique_ptr<FunctionPreHook>>& retains_grad_hooks() noexcept {
+  std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
+  retains_grad_hooks() noexcept {
     return retains_grad_hooks_;
   }
 
@@ -636,7 +645,7 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
-  std::vector<std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
+  std::unordered_map<int, std::unique_ptr<FunctionPreHook>> retains_grad_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
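Finally, the "Other details" point about the wrapper hooks, sketched in simplified, hypothetical form. The real classes in torch/csrc/autograd operate on Variables and carry a value index; only the ownership difference between the old and new wrappers is shown here.

// Simplified, hypothetical shapes for the two wrapper hooks; not the
// actual PyTorch declarations.
#include <functional>
#include <memory>
#include <vector>

using hooks_list = std::vector<std::function<void()>>;

// Old wrapper: shares ownership of a whole vector of hooks with
// autograd_meta, so nulling an entry there on in-place disabled it here too.
struct CppFunctionTensorPreHookSketch {
  std::shared_ptr<hooks_list> hooks_;
  void operator()() const {
    for (const auto& hook : *hooks_) {
      if (hook) hook();  // entries may have been nulled out on in-place
    }
  }
};

// New wrapper: owns exactly one hook by value, stored directly on the grad_fn.
struct CppFunctionSingleTensorPreHookSketch {
  std::function<void()> hook_;
  void operator()() const {
    hook_();
  }
};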