diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py index eebeea9a02a4..daaaf729ff2e 100644 --- a/torch/mps/profiler.py +++ b/torch/mps/profiler.py @@ -1,5 +1,6 @@ -# mypy: allow-untyped-defs import contextlib +from collections.abc import Iterator +from typing import Literal import torch @@ -14,7 +15,10 @@ __all__ = [ ] -def start(mode: str = "interval", wait_until_completed: bool = False) -> None: +ProfilerMode = Literal["interval", "event", "interval,event"] + + +def start(mode: ProfilerMode = "interval", wait_until_completed: bool = False) -> None: r"""Start OS Signpost tracing from MPS backend. The generated OS Signposts could be recorded and viewed in @@ -35,16 +39,20 @@ def start(mode: str = "interval", wait_until_completed: bool = False) -> None: https://developer.apple.com/documentation/os/logging/recording_performance_data """ mode_normalized = mode.lower().replace(" ", "") - torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed) + torch._C._mps_profilerStartTrace( # type: ignore[attr-defined] + mode_normalized, wait_until_completed + ) -def stop(): +def stop() -> None: r"""Stops generating OS Signpost tracing from MPS backend.""" - torch._C._mps_profilerStopTrace() + torch._C._mps_profilerStopTrace() # type: ignore[attr-defined] @contextlib.contextmanager -def profile(mode: str = "interval", wait_until_completed: bool = False): +def profile( + mode: ProfilerMode = "interval", wait_until_completed: bool = False +) -> Iterator[None]: r"""Context Manager to enabling generating OS Signpost tracing from MPS backend. Args: @@ -72,16 +80,16 @@ def is_metal_capture_enabled() -> bool: """Checks if `metal_capture` context manager is usable To enable metal capture, set MTL_CAPTURE_ENABLED envvar """ - return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined] + return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined, no-any-return] def is_capturing_metal() -> bool: """Checks if metal capture is in progress""" - return torch._C._mps_isCapturing() # type: ignore[attr-defined] + return torch._C._mps_isCapturing() # type: ignore[attr-defined, no-any-return] @contextlib.contextmanager -def metal_capture(fname: str): +def metal_capture(fname: str) -> Iterator[None]: """Context manager that enables capturing of Metal calls into gputrace""" try: torch._C._mps_startCapture(fname) # type: ignore[attr-defined]