mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[MPS] Expose MPSProfiler::start/stopCapture
to Python (#144561)
I.e. when the `MTL_CAPTURE_ENABLED` environment variable is set to 1, one should be able to wrap the code with `torch.mps.profiler.metal_capture` to generate a gputrace for shaders invoked inside the context manager. For example, the code below: ```python import torch import os def foo(x): return x[:,::2].sin() + x[:, 1::2].cos() if __name__ == "__main__": os.environ["MTL_CAPTURE_ENABLED"] = "1" x = torch.rand(32, 1024, device="mps") with torch.mps.profiler.metal_capture("compiled_shader"): torch.compile(foo)(x) ``` should capture the execution of a `torch.compile` generated shader <img width="734" alt="image" src="https://github.com/user-attachments/assets/718ff64e-103b-4b11-b66c-c89cfc770b5d" /> Pull Request resolved: https://github.com/pytorch/pytorch/pull/144561 Approved by: https://github.com/manuelcandales ghstack dependencies: #144559, #144560
This commit is contained in:
committed by
PyTorch MergeBot
parent
c7dbee5106
commit
92ddb3d3d3
1
.github/workflows/_mac-test-mps.yml
vendored
1
.github/workflows/_mac-test-mps.yml
vendored
@ -152,6 +152,7 @@ jobs:
|
||||
set -e
|
||||
|
||||
${CONDA_RUN} python3 test/run_test.py --mps --verbose
|
||||
MTL_CAPTURE_ENABLED=1 ${CONDA_RUN} python3 test/test_mps.py --verbose -k test_metal_capture
|
||||
|
||||
- name: Print remaining test logs
|
||||
shell: bash
|
||||
|
@ -29,6 +29,10 @@ MPS Profiler
|
||||
profiler.stop
|
||||
profiler.profile
|
||||
|
||||
profiler.is_capturing_metal
|
||||
profiler.is_metal_capture_enabled
|
||||
profiler.metal_capture
|
||||
|
||||
MPS Event
|
||||
------------
|
||||
.. autosummary::
|
||||
|
@ -6,6 +6,7 @@ import math
|
||||
import random
|
||||
import unittest
|
||||
import warnings
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
@ -12696,6 +12697,24 @@ class TestMetalLibrary(TestCaseMPS):
|
||||
# Passing no tensors asserts
|
||||
self.assertRaises(RuntimeError, lambda: lib.full(12))
|
||||
|
||||
@unittest.skipIf(not torch.mps.profiler.is_metal_capture_enabled(), "Set MTL_CAPTURE_ENABLED and try again")
def test_metal_capture(self):
    """Run a trivial custom shader under ``metal_capture`` and check the trace on disk.

    The test expects a directory named ``0000-<capture_name>.gputrace`` to exist
    after the context manager exits and to contain more than three entries.
    """
    lib = torch.mps._compile_shader("kernel void full(device float* x, uint idx [[thread_position_in_grid]]) { x[idx] = 1.0; }")
    mps_tensor = torch.rand(32, device="mps")
    # Five random digits make the capture name (and the expected directory) vary between runs.
    capture_name = f"lib_full{''.join(random.choices('0123456789', k=5))}"
    capture_dirname = f"0000-{capture_name}.gputrace"
    if os.path.exists(capture_dirname):
        shutil.rmtree(capture_dirname)
    # Registered up front so the trace directory is removed even when one of
    # the assertions below fails; ignore_errors covers the "never created" case.
    self.addCleanup(shutil.rmtree, capture_dirname, ignore_errors=True)
    with torch.mps.profiler.metal_capture(capture_name):
        self.assertTrue(torch.mps.profiler.is_capturing_metal())
        lib.full(mps_tensor)
        # The shader writes 1.0 into each of the 32 elements.
        self.assertEqual(mps_tensor.sum().item(), 32.0)
    self.assertTrue(os.path.exists(capture_dirname), f"Capture file {capture_dirname} has not been generated")
    capture_listdir = os.listdir(capture_dirname)
    self.assertGreater(len(capture_listdir), 3,
                       f"Capture file {capture_dirname} contains only metadata, i.e. {capture_listdir}")
|
||||
|
||||
|
||||
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
|
||||
# This requires mps to be properly registered in the device generic test framework which is not the
|
||||
|
@ -1834,6 +1834,10 @@ def _mps_waitForEvent(event_id: _int) -> None: ...
|
||||
def _mps_synchronizeEvent(event_id: _int) -> None: ...
|
||||
def _mps_queryEvent(event_id: _int) -> _bool: ...
|
||||
def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ...
|
||||
def _mps_isCaptureEnabled() -> _bool: ...
|
||||
def _mps_isCapturing() -> _bool: ...
|
||||
def _mps_startCapture(name: str) -> None: ...
|
||||
def _mps_stopCapture() -> None: ...
|
||||
|
||||
|
||||
# Defined in torch/csrc/cuda/Module.cpp
|
||||
|
@ -17,6 +17,7 @@
|
||||
#endif
|
||||
|
||||
#ifdef USE_MPS
|
||||
#include <ATen/mps/MPSProfiler.h>
|
||||
#include <ATen/native/mps/MetalShaderLibrary.h>
|
||||
#endif
|
||||
|
||||
@ -504,6 +505,16 @@ void initModule(PyObject* module) {
|
||||
m.def("_mps_compileShader", [](const std::string& source) {
|
||||
return std::make_shared<DynamicMetalShaderLibrary>(source);
|
||||
});
|
||||
m.def("_mps_isCaptureEnabled", []() {
|
||||
return at::mps::getMPSProfiler().isCaptureEnabled();
|
||||
});
|
||||
m.def("_mps_isCapturing", []() {
|
||||
return at::mps::getMPSProfiler().isCapturing();
|
||||
});
|
||||
m.def("_mps_startCapture", [](const std::string& fileName) {
|
||||
at::mps::getMPSProfiler().startCapture(fileName);
|
||||
});
|
||||
m.def("_mps_stopCapture", []() { at::mps::getMPSProfiler().stopCapture(); });
|
||||
}
|
||||
#endif /* USE_MPS */
|
||||
|
||||
|
@ -4,7 +4,14 @@ import contextlib
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = ["start", "stop", "profile"]
|
||||
__all__ = [
|
||||
"start",
|
||||
"stop",
|
||||
"profile",
|
||||
"metal_capture",
|
||||
"is_metal_capture_enabled",
|
||||
"is_capturing_metal",
|
||||
]
|
||||
|
||||
|
||||
def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
|
||||
@ -59,3 +66,27 @@ def profile(mode: str = "interval", wait_until_completed: bool = False):
|
||||
yield
|
||||
finally:
|
||||
stop()
|
||||
|
||||
|
||||
def is_metal_capture_enabled() -> bool:
    """Report whether the `metal_capture` context manager can be used.

    Metal capture is enabled by setting the MTL_CAPTURE_ENABLED
    environment variable.
    """
    capture_enabled = torch._C._mps_isCaptureEnabled()
    return capture_enabled
|
||||
|
||||
|
||||
def is_capturing_metal() -> bool:
    """Return whether a Metal capture is currently in progress."""
    capturing_now = torch._C._mps_isCapturing()
    return capturing_now
|
||||
|
||||
|
||||
@contextlib.contextmanager
def metal_capture(fname: str):
    """Context manager that enables capturing of Metal calls into a gputrace.

    Args:
        fname: Name passed to the profiler when the capture is started.

    The capture is always stopped when the block exits. Pending MPS work is
    synchronized before stopping only when the block finishes without raising.
    """
    # Start outside the ``try``: if starting fails there is nothing to stop,
    # and the original error must not be masked by a failing stopCapture.
    torch._C._mps_startCapture(fname)
    try:
        yield
        # Drain all the work that was enqueued during the context call
        torch.mps.synchronize()
    finally:
        torch._C._mps_stopCapture()
|
||||
|
Reference in New Issue
Block a user