From fd20889d0b7670dbb77f09aba61f1a49f2757781 Mon Sep 17 00:00:00 2001 From: Bob Ren Date: Sat, 27 Sep 2025 23:00:50 +0000 Subject: [PATCH] Add type annotations to MPS profiler utilities (#163486) ## Summary - drop the local mypy allow-untyped-defs escape hatch in the MPS profiler helpers - annotate the context managers and bool helpers so they type-check cleanly ## Testing - python -m mypy torch/mps/profiler.py --config-file mypy-strict.ini ------ https://chatgpt.com/codex/tasks/task_e_68d0ce4df2e483268d06673b65ef7745 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163486 Approved by: https://github.com/Skylion007 --- torch/mps/profiler.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py index eebeea9a02a4..daaaf729ff2e 100644 --- a/torch/mps/profiler.py +++ b/torch/mps/profiler.py @@ -1,5 +1,6 @@ -# mypy: allow-untyped-defs import contextlib +from collections.abc import Iterator +from typing import Literal import torch @@ -14,7 +15,10 @@ __all__ = [ ] -def start(mode: str = "interval", wait_until_completed: bool = False) -> None: +ProfilerMode = Literal["interval", "event", "interval,event"] + + +def start(mode: ProfilerMode = "interval", wait_until_completed: bool = False) -> None: r"""Start OS Signpost tracing from MPS backend. The generated OS Signposts could be recorded and viewed in @@ -35,16 +39,20 @@ def start(mode: str = "interval", wait_until_completed: bool = False) -> None: https://developer.apple.com/documentation/os/logging/recording_performance_data """ mode_normalized = mode.lower().replace(" ", "") - torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed) + torch._C._mps_profilerStartTrace( # type: ignore[attr-defined] + mode_normalized, wait_until_completed + ) -def stop(): +def stop() -> None: r"""Stops generating OS Signpost tracing from MPS backend.""" - torch._C._mps_profilerStopTrace() + torch._C._mps_profilerStopTrace() # type: ignore[attr-defined] @contextlib.contextmanager -def profile(mode: str = "interval", wait_until_completed: bool = False): +def profile( + mode: ProfilerMode = "interval", wait_until_completed: bool = False +) -> Iterator[None]: r"""Context Manager to enabling generating OS Signpost tracing from MPS backend. Args: @@ -72,16 +80,16 @@ def is_metal_capture_enabled() -> bool: """Checks if `metal_capture` context manager is usable To enable metal capture, set MTL_CAPTURE_ENABLED envvar """ - return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined] + return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined, no-any-return] def is_capturing_metal() -> bool: """Checks if metal capture is in progress""" - return torch._C._mps_isCapturing() # type: ignore[attr-defined] + return torch._C._mps_isCapturing() # type: ignore[attr-defined, no-any-return] @contextlib.contextmanager -def metal_capture(fname: str): +def metal_capture(fname: str) -> Iterator[None]: """Context manager that enables capturing of Metal calls into gputrace""" try: torch._C._mps_startCapture(fname) # type: ignore[attr-defined]