Add type annotations to MPS profiler utilities (#163486)

## Summary
- drop the local mypy allow-untyped-defs escape hatch in the MPS profiler helpers
- annotate the context managers and bool helpers so they type-check cleanly

## Testing
- python -m mypy torch/mps/profiler.py --config-file mypy-strict.ini

------
https://chatgpt.com/codex/tasks/task_e_68d0ce4df2e483268d06673b65ef7745
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163486
Approved by: https://github.com/Skylion007
This commit is contained in:
Bob Ren
2025-09-27 23:00:50 +00:00
committed by PyTorch MergeBot
parent 2ce2e48a05
commit fd20889d0b

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
from collections.abc import Iterator
from typing import Literal
import torch
@ -14,7 +15,10 @@ __all__ = [
]
def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
ProfilerMode = Literal["interval", "event", "interval,event"]
def start(mode: ProfilerMode = "interval", wait_until_completed: bool = False) -> None:
r"""Start OS Signpost tracing from MPS backend.
The generated OS Signposts could be recorded and viewed in
@ -35,16 +39,20 @@ def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
https://developer.apple.com/documentation/os/logging/recording_performance_data
"""
mode_normalized = mode.lower().replace(" ", "")
torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed)
torch._C._mps_profilerStartTrace( # type: ignore[attr-defined]
mode_normalized, wait_until_completed
)
def stop():
def stop() -> None:
r"""Stops generating OS Signpost tracing from MPS backend."""
torch._C._mps_profilerStopTrace()
torch._C._mps_profilerStopTrace() # type: ignore[attr-defined]
@contextlib.contextmanager
def profile(mode: str = "interval", wait_until_completed: bool = False):
def profile(
mode: ProfilerMode = "interval", wait_until_completed: bool = False
) -> Iterator[None]:
r"""Context Manager to enabling generating OS Signpost tracing from MPS backend.
Args:
@ -72,16 +80,16 @@ def is_metal_capture_enabled() -> bool:
"""Checks if `metal_capture` context manager is usable
To enable metal capture, set MTL_CAPTURE_ENABLED envvar
"""
return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined]
return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined, no-any-return]
def is_capturing_metal() -> bool:
"""Checks if metal capture is in progress"""
return torch._C._mps_isCapturing() # type: ignore[attr-defined]
return torch._C._mps_isCapturing() # type: ignore[attr-defined, no-any-return]
@contextlib.contextmanager
def metal_capture(fname: str):
def metal_capture(fname: str) -> Iterator[None]:
"""Context manager that enables capturing of Metal calls into gputrace"""
try:
torch._C._mps_startCapture(fname) # type: ignore[attr-defined]