Add type annotations to MPS profiler utilities (#163486)

## Summary
- drop the local mypy allow-untyped-defs escape hatch in the MPS profiler helpers
- annotate the context managers and bool helpers so they type-check cleanly

## Testing
- python -m mypy torch/mps/profiler.py --config-file mypy-strict.ini

------
https://chatgpt.com/codex/tasks/task_e_68d0ce4df2e483268d06673b65ef7745
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163486
Approved by: https://github.com/Skylion007
This commit is contained in:
Bob Ren
2025-09-27 23:00:50 +00:00
committed by PyTorch MergeBot
parent 2ce2e48a05
commit fd20889d0b

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import contextlib import contextlib
from collections.abc import Iterator
from typing import Literal
import torch import torch
@ -14,7 +15,10 @@ __all__ = [
] ]
def start(mode: str = "interval", wait_until_completed: bool = False) -> None: ProfilerMode = Literal["interval", "event", "interval,event"]
def start(mode: ProfilerMode = "interval", wait_until_completed: bool = False) -> None:
r"""Start OS Signpost tracing from MPS backend. r"""Start OS Signpost tracing from MPS backend.
The generated OS Signposts could be recorded and viewed in The generated OS Signposts could be recorded and viewed in
@ -35,16 +39,20 @@ def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
https://developer.apple.com/documentation/os/logging/recording_performance_data https://developer.apple.com/documentation/os/logging/recording_performance_data
""" """
mode_normalized = mode.lower().replace(" ", "") mode_normalized = mode.lower().replace(" ", "")
torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed) torch._C._mps_profilerStartTrace( # type: ignore[attr-defined]
mode_normalized, wait_until_completed
)
def stop(): def stop() -> None:
r"""Stops generating OS Signpost tracing from MPS backend.""" r"""Stops generating OS Signpost tracing from MPS backend."""
torch._C._mps_profilerStopTrace() torch._C._mps_profilerStopTrace() # type: ignore[attr-defined]
@contextlib.contextmanager @contextlib.contextmanager
def profile(mode: str = "interval", wait_until_completed: bool = False): def profile(
mode: ProfilerMode = "interval", wait_until_completed: bool = False
) -> Iterator[None]:
r"""Context Manager to enabling generating OS Signpost tracing from MPS backend. r"""Context Manager to enabling generating OS Signpost tracing from MPS backend.
Args: Args:
@ -72,16 +80,16 @@ def is_metal_capture_enabled() -> bool:
"""Checks if `metal_capture` context manager is usable """Checks if `metal_capture` context manager is usable
To enable metal capture, set MTL_CAPTURE_ENABLED envvar To enable metal capture, set MTL_CAPTURE_ENABLED envvar
""" """
return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined] return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined, no-any-return]
def is_capturing_metal() -> bool: def is_capturing_metal() -> bool:
"""Checks if metal capture is in progress""" """Checks if metal capture is in progress"""
return torch._C._mps_isCapturing() # type: ignore[attr-defined] return torch._C._mps_isCapturing() # type: ignore[attr-defined, no-any-return]
@contextlib.contextmanager @contextlib.contextmanager
def metal_capture(fname: str): def metal_capture(fname: str) -> Iterator[None]:
"""Context manager that enables capturing of Metal calls into gputrace""" """Context manager that enables capturing of Metal calls into gputrace"""
try: try:
torch._C._mps_startCapture(fname) # type: ignore[attr-defined] torch._C._mps_startCapture(fname) # type: ignore[attr-defined]