Inductor annotations (#130429)

Add NVTX annotations around training phases and buffer computations

RFC/discussion: https://dev-discuss.pytorch.org/t/rfc-performance-profiling-at-scale-with-details-nvtx-annotations/2224

<img width="2160" alt="Screenshot 2024-07-10 at 11 48 04" src="https://github.com/pytorch/pytorch/assets/1175576/9ade139c-d393-473f-9b68-6c25da367dc4">
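A minimal sketch of how the annotations are switched on (not part of the diff; assumes a CUDA build of PyTorch and uses only the `annotate_training` flag added below):

```python
import torch
import torch._inductor.config as inductor_config

# Equivalent to running with TORCHINDUCTOR_ANNOTATE_TRAINING=1 in the environment.
inductor_config.annotate_training = True

def f(a, b):
    return a + b, a * b

a = torch.randn(5, device="cuda")
b = torch.randn(5, device="cuda")
torch.compile(f)(a, b)  # the generated wrapper now brackets the graph call with NVTX device ranges
```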

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130429
Approved by: https://github.com/aorenste, https://github.com/eellison, https://github.com/albanD

Co-authored-by: Cedric GESTES <cedric.gestes@flex.ai>
Author: Alex Denisov
Date: 2024-12-10 08:53:39 +00:00
Committed by: PyTorch MergeBot
Parent: 24650c3caa
Commit: 539286a67b
8 changed files with 132 additions and 0 deletions


@@ -0,0 +1,41 @@
# Owner(s): ["module: inductor"]
import torch
import torch._inductor.config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.triton_utils import requires_cuda


class InductorAnnotationTestCase(TestCase):
    def get_code(self):
        def f(a, b):
            return a + b, a * b

        a = torch.randn(5, device="cuda")
        b = torch.randn(5, device="cuda")
        f_comp = torch.compile(f)
        _, code = run_and_get_code(f_comp, a, b)
        return code[0]

    @requires_cuda
    def test_no_annotations(self):
        code = self.get_code()
        self.assertTrue("from torch.cuda import nvtx" not in code)
        self.assertTrue("training_annotation" not in code)

    @inductor_config.patch(annotate_training=True)
    @requires_cuda
    def test_training_annotation(self):
        code = self.get_code()
        self.assertTrue("from torch.cuda import nvtx" in code)
        self.assertEqual(
            code.count("training_annotation = nvtx._device_range_start('inference')"), 1
        )
        self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1)


if __name__ == "__main__":
    run_tests()


@@ -5,3 +5,5 @@ def rangePop() -> int: ...
def rangeStartA(message: str) -> int: ...
def rangeEnd(int) -> None: ...
def markA(message: str) -> None: ...
def deviceRangeStart(message: str, stream: int) -> object: ...
def deviceRangeEnd(range_handle: object, stream: int) -> None: ...


@@ -772,6 +772,8 @@ class PythonWrapperCodegen(CodeGen):
            )
        except (AttributeError, ImportError):
            pass

        if config.annotate_training:
            self.header.writeline("from torch.cuda import nvtx")

    def include_extra_header(self, header: str):
        pass
@@ -889,6 +891,11 @@
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline(V.graph.device_ops.synchronize())
            phase = V.graph.get_training_phase()
            if config.annotate_training:
                self.prefix.writeline(
                    f"training_annotation = nvtx._device_range_start('{phase}')"
                )
            if V.graph.graph_inputs:
                lhs = ", ".join(V.graph.graph_input_names)
                if len(V.graph.graph_input_names) == 1:
@@ -1175,6 +1182,10 @@
        if config.triton.autotune_at_compile_time:
            self.generate_and_run_autotune_block()

        if config.annotate_training:
            self.wrapper_call.writeline(
                "nvtx._device_range_end(training_annotation)"
            )
        self.generate_return(output_refs)

        self.finalize_prefix()
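For orientation, roughly the shape the generated wrapper takes once `annotate_training` is on. This is a hand-written sketch, not actual Inductor output; `out0`/`out1` stand in for the generated kernel calls:

```python
from torch.cuda import nvtx  # header line emitted by write_header when config.annotate_training is set


def call(args):
    a, b = args
    # prefix: the phase string comes from V.graph.get_training_phase()
    training_annotation = nvtx._device_range_start('inference')
    out0 = a + b  # stand-ins for the generated Triton/C++ kernels
    out1 = a * b
    # emitted just before the wrapper's return
    nvtx._device_range_end(training_annotation)
    return (out0, out1)
```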


@@ -773,6 +773,10 @@ enable_linear_binary_folding = (
)

# Adds NVTX annotations around training phases
annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"

# config specific to codegen/cpp.py
class cpp:
    # set to torch.get_num_threads()
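The flag can also be enabled through the environment; a sketch, noting that the variable is read when `torch._inductor.config` is first imported, so it must be set beforehand:

```python
import os

# Must be set before torch._inductor.config is imported, since the flag is
# evaluated at module import time via os.environ.get(...).
os.environ["TORCHINDUCTOR_ANNOTATE_TRAINING"] = "1"

import torch  # noqa: E402

compiled = torch.compile(lambda a, b: (a + b, a * b))
```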


@@ -519,6 +519,13 @@ class GraphLowering(torch.fx.Interpreter):
        finally:
            self.current_device = prior

    def get_training_phase(self) -> str:
        if self.is_inference:
            return "inference"
        if self.is_backward:
            return "backward"
        return "forward"

    @staticmethod
    def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
        """


@@ -6,10 +6,41 @@
#else
#include <nvToolsExt.h>
#endif

#include <cuda_runtime.h>

#include <torch/csrc/utils/pybind.h>

namespace torch::cuda::shared {

struct RangeHandle {
  nvtxRangeId_t id;
  const char* msg;
};

static void device_callback_range_end(void* userData) {
  RangeHandle* handle = ((RangeHandle*)userData);
  nvtxRangeEnd(handle->id);
  free((void*)handle->msg);
  free((void*)handle);
}

static void device_nvtxRangeEnd(void* handle, std::intptr_t stream) {
  cudaLaunchHostFunc((cudaStream_t)stream, device_callback_range_end, handle);
}

static void device_callback_range_start(void* userData) {
  RangeHandle* handle = ((RangeHandle*)userData);
  handle->id = nvtxRangeStartA(handle->msg);
}

static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) {
  RangeHandle* handle = (RangeHandle*)calloc(sizeof(RangeHandle), 1);
  handle->msg = strdup(msg);
  handle->id = 0;
  cudaLaunchHostFunc(
      (cudaStream_t)stream, device_callback_range_start, (void*)handle);
  return handle;
}

void initNvtxBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

@@ -23,6 +54,8 @@ void initNvtxBindings(PyObject* module) {
  nvtx.def("rangeStartA", nvtxRangeStartA);
  nvtx.def("rangeEnd", nvtxRangeEnd);
  nvtx.def("markA", nvtxMarkA);
  nvtx.def("deviceRangeStart", device_nvtxRangeStart);
  nvtx.def("deviceRangeEnd", device_nvtxRangeEnd);
}

} // namespace torch::cuda::shared


@@ -66,6 +66,38 @@ def range_end(range_id) -> None:
    _nvtx.rangeEnd(range_id)


def _device_range_start(msg: str, stream: int = 0) -> object:
    """
    Mark the start of a range with a string message.

    It returns an opaque heap-allocated handle for this range,
    to be passed to the corresponding call to _device_range_end().

    A key difference from range_start is that range_start marks the range
    right away, while _device_range_start marks it only once all preceding
    tasks on the CUDA stream have completed.

    Returns: An opaque heap-allocated handle that should be passed to _device_range_end().

    Args:
        msg (str): ASCII message to associate with the range.
        stream (int): CUDA stream id.
    """
    return _nvtx.deviceRangeStart(msg, stream)


def _device_range_end(range_handle: object, stream: int = 0) -> None:
    """
    Mark the end of a range for the given range_handle once all preceding tasks
    on the CUDA stream have completed.

    Args:
        range_handle: a unique handle for the start range.
        stream (int): CUDA stream id.
    """
    _nvtx.deviceRangeEnd(range_handle, stream)


def mark(msg):
    """
    Describe an instantaneous event that occurred at some point.

@@ -5035,6 +5035,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
            "cudaLaunchCooperativeKernel",
            ("hipLaunchCooperativeKernel", CONV_EXEC, API_RUNTIME),
        ),
        ("cudaLaunchHostFunc", ("hipLaunchHostFunc", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)),
        (
            "cudaSetupArgument",
            ("hipSetupArgument", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED),
@@ -7965,6 +7966,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
        ("nvtxRangePop", ("roctxRangePop", CONV_OTHER, API_ROCTX)),
        ("nvtxRangeStartA", ("roctxRangeStartA", CONV_OTHER, API_ROCTX)),
        ("nvtxRangeEnd", ("roctxRangeStop", CONV_OTHER, API_ROCTX)),
        ("nvtxRangeId_t", ("int", CONV_OTHER, API_ROCTX)),
        ("nvmlReturn_t", ("rsmi_status_t", CONV_OTHER, API_ROCMSMI)),
        ("NVML_SUCCESS", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)),
        ("NVML_P2P_CAPS_INDEX_READ", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)),