mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add type annotations to MPS profiler utilities (#163486)
## Summary - drop the local mypy allow-untyped-defs escape hatch in the MPS profiler helpers - annotate the context managers and bool helpers so they type-check cleanly ## Testing - python -m mypy torch/mps/profiler.py --config-file mypy-strict.ini ------ https://chatgpt.com/codex/tasks/task_e_68d0ce4df2e483268d06673b65ef7745 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163486 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
2ce2e48a05
commit
fd20889d0b
@ -1,5 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
from collections.abc import Iterator
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
@ -14,7 +15,10 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
|
||||
ProfilerMode = Literal["interval", "event", "interval,event"]
|
||||
|
||||
|
||||
def start(mode: ProfilerMode = "interval", wait_until_completed: bool = False) -> None:
|
||||
r"""Start OS Signpost tracing from MPS backend.
|
||||
|
||||
The generated OS Signposts could be recorded and viewed in
|
||||
@ -35,16 +39,20 @@ def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
|
||||
https://developer.apple.com/documentation/os/logging/recording_performance_data
|
||||
"""
|
||||
mode_normalized = mode.lower().replace(" ", "")
|
||||
torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed)
|
||||
torch._C._mps_profilerStartTrace( # type: ignore[attr-defined]
|
||||
mode_normalized, wait_until_completed
|
||||
)
|
||||
|
||||
|
||||
def stop():
|
||||
def stop() -> None:
|
||||
r"""Stops generating OS Signpost tracing from MPS backend."""
|
||||
torch._C._mps_profilerStopTrace()
|
||||
torch._C._mps_profilerStopTrace() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def profile(mode: str = "interval", wait_until_completed: bool = False):
|
||||
def profile(
|
||||
mode: ProfilerMode = "interval", wait_until_completed: bool = False
|
||||
) -> Iterator[None]:
|
||||
r"""Context Manager to enabling generating OS Signpost tracing from MPS backend.
|
||||
|
||||
Args:
|
||||
@ -72,16 +80,16 @@ def is_metal_capture_enabled() -> bool:
|
||||
"""Checks if `metal_capture` context manager is usable
|
||||
To enable metal capture, set MTL_CAPTURE_ENABLED envvar
|
||||
"""
|
||||
return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined]
|
||||
return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined, no-any-return]
|
||||
|
||||
|
||||
def is_capturing_metal() -> bool:
|
||||
"""Checks if metal capture is in progress"""
|
||||
return torch._C._mps_isCapturing() # type: ignore[attr-defined]
|
||||
return torch._C._mps_isCapturing() # type: ignore[attr-defined, no-any-return]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def metal_capture(fname: str):
|
||||
def metal_capture(fname: str) -> Iterator[None]:
|
||||
"""Context manager that enables capturing of Metal calls into gputrace"""
|
||||
try:
|
||||
torch._C._mps_startCapture(fname) # type: ignore[attr-defined]
|
||||
|
Reference in New Issue
Block a user