mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is the result of applying the ruff `UP035` check: `Callable` is now imported from `collections.abc` instead of `typing`, and `TypeAlias` is also imported from `typing`. This PR is a follow-up to #163947. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054 Approved by: https://github.com/ezyang, https://github.com/Skylion007
32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
# NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
|
|
# IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
|
|
# AND SCRUB AWAY TORCH NOTIONS THERE.
|
|
import collections
|
|
import functools
|
|
from collections import OrderedDict
|
|
from collections.abc import Callable
|
|
from typing import TypeVar
|
|
from typing_extensions import ParamSpec
|
|
|
|
|
|
simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
|
|
|
|
_P = ParamSpec("_P")
|
|
_R = TypeVar("_R")
|
|
|
|
|
|
def count_label(label: str) -> None:
    """Add one to the tally stored for ``label`` in ``simple_call_counter``.

    A label that has not been seen before starts from zero, so after its
    first call its count is 1.
    """
    simple_call_counter[label] = simple_call_counter.get(label, 0) + 1
|
|
|
|
|
|
def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that each call increments its tally in ``simple_call_counter``.

    Calls are counted under ``fn.__qualname__``. Callables that have no
    ``__qualname__`` (e.g. ``functools.partial`` objects) are counted under
    the qualified name of their type instead of raising ``AttributeError``
    on every call.

    Args:
        fn: The callable to instrument.

    Returns:
        A wrapper with the same signature as ``fn`` (metadata copied with
        ``functools.wraps``). It records the call, then delegates to ``fn``.
    """
    # Resolve the label once at decoration time instead of on every call.
    label = getattr(fn, "__qualname__", None)
    if label is None:
        label = type(fn).__qualname__

    @functools.wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # The count is recorded before fn runs, so calls that raise are still tallied.
        # Reuse count_label so both entry points share one increment rule.
        count_label(label)
        return fn(*args, **kwargs)

    return wrapper
|