add WaitCounter type interface and get rid of type errors (#146175)

Summary: as titled

Differential Revision: D68960123

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146175
Approved by: https://github.com/andriigrynenko, https://github.com/Skylion007
This commit is contained in:
Burak Turk
2025-02-01 23:24:52 +00:00
committed by PyTorch MergeBot
parent 3a67c0e48d
commit d89c7ea401
2 changed files with 16 additions and 3 deletions

View File

@ -2,7 +2,8 @@
import datetime
from enum import Enum
from typing import Callable
from types import TracebackType
from typing import Callable, Optional, Type
class Aggregation(Enum):
VALUE = ...
@ -42,3 +43,16 @@ class EventHandlerHandle: ...
def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ...
def unregister_event_handler(handle: EventHandlerHandle) -> None: ...
class _WaitCounterTracker:
def __enter__(self) -> None: ...
def __exit__(
self,
exec_type: Optional[Type[BaseException]] = None,
exec_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None: ...
class _WaitCounter:
def __init__(self, key: str) -> None: ...
def guard(self) -> _WaitCounterTracker: ...

View File

@ -1,13 +1,12 @@
from typing import TYPE_CHECKING
from torch._C._monitor import * # noqa: F403
from torch._C._monitor import _WaitCounter # type: ignore[attr-defined]
from torch._C._monitor import _WaitCounter, _WaitCounterTracker
if TYPE_CHECKING:
from torch.utils.tensorboard import SummaryWriter
STAT_EVENT = "torch.monitor.Stat"