mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add tensor post accumulate grad hook API (#107063)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107063 Approved by: https://github.com/albanD, https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
3828cd4b79
commit
6e71ad0509
@ -523,6 +523,12 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
|
||||
return tensor_pre_hooks_;
|
||||
}
|
||||
|
||||
virtual std::unique_ptr<PostAccumulateGradHook>&
|
||||
tensor_post_acc_grad_hooks() noexcept {
|
||||
static std::unique_ptr<PostAccumulateGradHook> empty = nullptr;
|
||||
return empty;
|
||||
}
|
||||
|
||||
std::unordered_map<int, std::unique_ptr<FunctionPreHook>>&
|
||||
retains_grad_hooks() noexcept {
|
||||
return retains_grad_hooks_;
|
||||
|
Reference in New Issue
Block a user