Add tensor post accumulate grad hook API (#107063)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107063
Approved by: https://github.com/albanD, https://github.com/soulitzer
This commit is contained in:
Jane Xu
2023-08-21 15:17:14 -07:00
committed by PyTorch MergeBot
parent 3828cd4b79
commit 6e71ad0509
14 changed files with 378 additions and 8 deletions

View File

@ -523,6 +523,12 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
return tensor_pre_hooks_;
}
virtual std::unique_ptr<PostAccumulateGradHook>&
tensor_post_acc_grad_hooks() noexcept {
static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
return empty;
}
std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
retains_grad_hooks() noexcept {
return retains_grad_hooks_;