[MPS] Expose MPSProfiler::start/stopCapture to Python (#144561)

I.e. when `MTL_CAPTURE_ENABLED` environment variable is set to 1, one should be able to invoke wrap the code with `torch.mps.profiler.capture_metal` to generate gputrace for shaders invoked inside the context manager.

For example, code below:
```python
import torch
import os

def foo(x):
   return x[:,::2].sin() + x[:, 1::2].cos()

if __name__ == "__main__":
    os.environ["MTL_CAPTURE_ENABLED"] = "1"
    x = torch.rand(32, 1024, device="mps")

    with torch.mps.profiler.metal_capture("compiled_shader"):
        torch.compile(foo)(x)
```
should capture the execution of a `torch.compile` generated shader
<img width="734" alt="image" src="https://github.com/user-attachments/assets/718ff64e-103b-4b11-b66c-c89cfc770b5d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144561
Approved by: https://github.com/manuelcandales
ghstack dependencies: #144559, #144560
This commit is contained in:
Nikita Shulga
2025-01-10 16:20:54 -08:00
committed by PyTorch MergeBot
parent c7dbee5106
commit 92ddb3d3d3
6 changed files with 71 additions and 1 deletions

View File

@ -152,6 +152,7 @@ jobs:
set -e
${CONDA_RUN} python3 test/run_test.py --mps --verbose
MTL_CAPTURE_ENABLED=1 ${CONDA_RUN} python3 test/test_mps.py --verbose -k test_metal_capture
- name: Print remaining test logs
shell: bash

View File

@ -29,6 +29,10 @@ MPS Profiler
profiler.stop
profiler.profile
profiler.is_capturing_metal
profiler.is_metal_capture_enabled
profiler.metal_capture
MPS Event
------------
.. autosummary::

View File

@ -6,6 +6,7 @@ import math
import random
import unittest
import warnings
import shutil
import subprocess
import tempfile
import os
@ -12696,6 +12697,24 @@ class TestMetalLibrary(TestCaseMPS):
# Passing no tensors asserts
self.assertRaises(RuntimeError, lambda: lib.full(12))
@unittest.skipIf(not torch.mps.profiler.is_metal_capture_enabled(), "Set MTL_CAPTURE_ENABLED and try again")
def test_metal_capture(self):
lib = torch.mps._compile_shader("kernel void full(device float* x, uint idx [[thread_position_in_grid]]) { x[idx] = 1.0; }")
mps_tensor = torch.rand(32, device="mps")
capture_name = f"lib_full{''.join(random.choice('0123456789') for i in range(5))}"
capture_dirname = f"0000-{capture_name}.gputrace"
if os.path.exists(capture_dirname):
shutil.rmtree(capture_dirname)
with torch.mps.profiler.metal_capture(capture_name):
self.assertTrue(torch.mps.profiler.is_capturing_metal())
lib.full(mps_tensor)
self.assertEqual(mps_tensor.sum().item(), 32.0)
self.assertTrue(os.path.exists(capture_dirname), f"Capture file {capture_dirname} has not been generated")
capture_listdir = os.listdir(capture_dirname)
shutil.rmtree(capture_dirname)
self.assertGreater(len(capture_listdir), 3,
f"Capture file {capture_dirname} contains only metadata, i.e. {capture_listdir}")
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the

View File

@ -1834,6 +1834,10 @@ def _mps_waitForEvent(event_id: _int) -> None: ...
def _mps_synchronizeEvent(event_id: _int) -> None: ...
def _mps_queryEvent(event_id: _int) -> _bool: ...
def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ...
def _mps_isCaptureEnabled() -> _bool: ...
def _mps_isCapturing() -> _bool: ...
def _mps_startCapture(name: str) -> None: ...
def _mps_stopCapture() -> None: ...
# Defined in torch/csrc/cuda/Module.cpp

View File

@ -17,6 +17,7 @@
#endif
#ifdef USE_MPS
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/MetalShaderLibrary.h>
#endif
@ -504,6 +505,16 @@ void initModule(PyObject* module) {
m.def("_mps_compileShader", [](const std::string& source) {
return std::make_shared<DynamicMetalShaderLibrary>(source);
});
m.def("_mps_isCaptureEnabled", []() {
return at::mps::getMPSProfiler().isCaptureEnabled();
});
m.def("_mps_isCapturing", []() {
return at::mps::getMPSProfiler().isCapturing();
});
m.def("_mps_startCapture", [](const std::string& fileName) {
at::mps::getMPSProfiler().startCapture(fileName);
});
m.def("_mps_stopCapture", []() { at::mps::getMPSProfiler().stopCapture(); });
}
#endif /* USE_MPS */

View File

@ -4,7 +4,14 @@ import contextlib
import torch
__all__ = ["start", "stop", "profile"]
__all__ = [
"start",
"stop",
"profile",
"metal_capture",
"is_metal_capture_enabled",
"is_capturing_metal",
]
def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
@ -59,3 +66,27 @@ def profile(mode: str = "interval", wait_until_completed: bool = False):
yield
finally:
stop()
def is_metal_capture_enabled() -> bool:
"""Checks if `metal_capture` context manager is usable
To enable metal capture, set MTL_CAPTURE_ENABLED envvar
"""
return torch._C._mps_isCaptureEnabled()
def is_capturing_metal() -> bool:
"""Cheks if metal capture is in progress"""
return torch._C._mps_isCapturing()
@contextlib.contextmanager
def metal_capture(fname: str):
"""Conext manager that enables capturing of Metal calls into gputrace"""
try:
torch._C._mps_startCapture(fname)
yield
# Drain all the work that were enqueued during the context call
torch.mps.synchronize()
finally:
torch._C._mps_stopCapture()