remove allow-untyped-defs from torch/mps/event.py (#144092)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144092
Approved by: https://github.com/aorenste
This commit is contained in:
bobrenjc93
2025-01-02 11:21:28 -08:00
committed by PyTorch MergeBot
parent 496fc90965
commit 28a74fe3aa

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
import torch import torch
@ -13,33 +12,33 @@ class Event:
(default: ``False``) (default: ``False``)
""" """
def __init__(self, enable_timing=False): def __init__(self, enable_timing: bool = False) -> None:
self.__eventId = torch._C._mps_acquireEvent(enable_timing) self.__eventId = torch._C._mps_acquireEvent(enable_timing)
def __del__(self): def __del__(self) -> None:
# checks if torch._C is already destroyed # checks if torch._C is already destroyed
if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0: if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
torch._C._mps_releaseEvent(self.__eventId) torch._C._mps_releaseEvent(self.__eventId)
def record(self): def record(self) -> None:
r"""Records the event in the default stream.""" r"""Records the event in the default stream."""
torch._C._mps_recordEvent(self.__eventId) torch._C._mps_recordEvent(self.__eventId)
def wait(self): def wait(self) -> None:
r"""Makes all future work submitted to the default stream wait for this event.""" r"""Makes all future work submitted to the default stream wait for this event."""
torch._C._mps_waitForEvent(self.__eventId) torch._C._mps_waitForEvent(self.__eventId)
def query(self): def query(self) -> bool:
r"""Returns True if all work currently captured by event has completed.""" r"""Returns True if all work currently captured by event has completed."""
return torch._C._mps_queryEvent(self.__eventId) return torch._C._mps_queryEvent(self.__eventId)
def synchronize(self): def synchronize(self) -> None:
r"""Waits until the completion of all work currently captured in this event. r"""Waits until the completion of all work currently captured in this event.
This prevents the CPU thread from proceeding until the event completes. This prevents the CPU thread from proceeding until the event completes.
""" """
torch._C._mps_synchronizeEvent(self.__eventId) torch._C._mps_synchronizeEvent(self.__eventId)
def elapsed_time(self, end_event): def elapsed_time(self, end_event: "Event") -> float:
r"""Returns the time elapsed in milliseconds after the event was r"""Returns the time elapsed in milliseconds after the event was
recorded and before the end_event was recorded. recorded and before the end_event was recorded.
""" """