mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768 Approved by: https://github.com/jansel
31 lines
645 B
Python
31 lines
645 B
Python
# mypy: allow-untyped-defs
|
|
import contextlib
|
|
from typing import Callable, List, TYPE_CHECKING
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
import torch
|
|
|
|
# Global registry of intermediate hooks, executed in the order they're
# registered (entries are appended on registration and iterated front to back).
# Each entry is a callable taking (str, torch.Tensor) and returning None.
# "torch.Tensor" is a string forward reference because torch is only imported
# under TYPE_CHECKING in this module.
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
|
|
|
|
|
|
@contextlib.contextmanager
def intermediate_hook(fn):
    """Register ``fn`` as an intermediate hook for the duration of a ``with`` block.

    On entry ``fn`` is appended to ``INTERMEDIATE_HOOKS``, so it runs after any
    hooks registered earlier. On exit the registration is removed again, even
    if the body of the ``with`` block raises.

    Args:
        fn: Callable that ``run_intermediate_hooks`` invokes as ``fn(name, val)``.
    """
    INTERMEDIATE_HOOKS.append(fn)
    try:
        yield
    finally:
        # Remove this exact registration instead of blindly popping the last
        # entry. If contexts are exited out of LIFO order (interleaved
        # generators/async tasks, manual __exit__ calls), ``pop()`` would drop
        # another context's hook and leave ``fn`` installed. Scanning from the
        # end makes the strictly nested case behave exactly like ``pop()``,
        # including when the same ``fn`` is registered more than once.
        # If ``fn`` is no longer present (e.g. the list was replaced or
        # cleared), nothing is removed.
        hooks = INTERMEDIATE_HOOKS
        for i in range(len(hooks) - 1, -1, -1):
            if hooks[i] is fn:
                del hooks[i]
                break
|
|
|
|
|
|
def run_intermediate_hooks(name, val):
    """Call every registered intermediate hook as ``hook(name, val)``.

    Hooks are called in registration order. While they run, the module-level
    ``INTERMEDIATE_HOOKS`` is rebound to a fresh empty list. As a result, a
    nested call to this function made from inside a hook sees that empty
    registry instead of re-running the same hooks. The original list object is
    put back afterwards, even if a hook raises.
    """
    global INTERMEDIATE_HOOKS
    # Detach the current registry and install an empty one in a single step.
    active, INTERMEDIATE_HOOKS = INTERMEDIATE_HOOKS, []
    try:
        for callback in active:
            callback(name, val)
    finally:
        INTERMEDIATE_HOOKS = active
|