diff --git a/torch/mps/event.py b/torch/mps/event.py index d619c027480c..3f597c66a41f 100644 --- a/torch/mps/event.py +++ b/torch/mps/event.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import torch @@ -13,33 +12,33 @@ class Event: (default: ``False``) """ - def __init__(self, enable_timing=False): + def __init__(self, enable_timing: bool = False) -> None: self.__eventId = torch._C._mps_acquireEvent(enable_timing) - def __del__(self): + def __del__(self) -> None: # checks if torch._C is already destroyed if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0: torch._C._mps_releaseEvent(self.__eventId) - def record(self): + def record(self) -> None: r"""Records the event in the default stream.""" torch._C._mps_recordEvent(self.__eventId) - def wait(self): + def wait(self) -> None: r"""Makes all future work submitted to the default stream wait for this event.""" torch._C._mps_waitForEvent(self.__eventId) - def query(self): + def query(self) -> bool: r"""Returns True if all work currently captured by event has completed.""" return torch._C._mps_queryEvent(self.__eventId) - def synchronize(self): + def synchronize(self) -> None: r"""Waits until the completion of all work currently captured in this event. This prevents the CPU thread from proceeding until the event completes. """ torch._C._mps_synchronizeEvent(self.__eventId) - def elapsed_time(self, end_event): + def elapsed_time(self, end_event: "Event") -> float: r"""Returns the time elapsed in milliseconds after the event was recorded and before the end_event was recorded. """