mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
remove allow-untyped-defs from torch/mps/event.py (#144092)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144092 Approved by: https://github.com/aorenste
This commit is contained in:
committed by
PyTorch MergeBot
parent
496fc90965
commit
28a74fe3aa
@ -1,4 +1,3 @@
|
|||||||
# mypy: allow-untyped-defs
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@ -13,33 +12,33 @@ class Event:
|
|||||||
(default: ``False``)
|
(default: ``False``)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, enable_timing=False):
|
def __init__(self, enable_timing: bool = False) -> None:
|
||||||
self.__eventId = torch._C._mps_acquireEvent(enable_timing)
|
self.__eventId = torch._C._mps_acquireEvent(enable_timing)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self) -> None:
|
||||||
# checks if torch._C is already destroyed
|
# checks if torch._C is already destroyed
|
||||||
if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
|
if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
|
||||||
torch._C._mps_releaseEvent(self.__eventId)
|
torch._C._mps_releaseEvent(self.__eventId)
|
||||||
|
|
||||||
def record(self):
|
def record(self) -> None:
|
||||||
r"""Records the event in the default stream."""
|
r"""Records the event in the default stream."""
|
||||||
torch._C._mps_recordEvent(self.__eventId)
|
torch._C._mps_recordEvent(self.__eventId)
|
||||||
|
|
||||||
def wait(self):
|
def wait(self) -> None:
|
||||||
r"""Makes all future work submitted to the default stream wait for this event."""
|
r"""Makes all future work submitted to the default stream wait for this event."""
|
||||||
torch._C._mps_waitForEvent(self.__eventId)
|
torch._C._mps_waitForEvent(self.__eventId)
|
||||||
|
|
||||||
def query(self):
|
def query(self) -> bool:
|
||||||
r"""Returns True if all work currently captured by event has completed."""
|
r"""Returns True if all work currently captured by event has completed."""
|
||||||
return torch._C._mps_queryEvent(self.__eventId)
|
return torch._C._mps_queryEvent(self.__eventId)
|
||||||
|
|
||||||
def synchronize(self):
|
def synchronize(self) -> None:
|
||||||
r"""Waits until the completion of all work currently captured in this event.
|
r"""Waits until the completion of all work currently captured in this event.
|
||||||
This prevents the CPU thread from proceeding until the event completes.
|
This prevents the CPU thread from proceeding until the event completes.
|
||||||
"""
|
"""
|
||||||
torch._C._mps_synchronizeEvent(self.__eventId)
|
torch._C._mps_synchronizeEvent(self.__eventId)
|
||||||
|
|
||||||
def elapsed_time(self, end_event):
|
def elapsed_time(self, end_event: "Event") -> float:
|
||||||
r"""Returns the time elapsed in milliseconds after the event was
|
r"""Returns the time elapsed in milliseconds after the event was
|
||||||
recorded and before the end_event was recorded.
|
recorded and before the end_event was recorded.
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user