Inductor annotations (#130429)
Add NVTX annotations around training phases and buffer computations.

RFC/discussion: https://dev-discuss.pytorch.org/t/rfc-performance-profiling-at-scale-with-details-nvtx-annotations/2224

Screenshot: https://github.com/pytorch/pytorch/assets/1175576/9ade139c-d393-473f-9b68-6c25da367dc4

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130429
Approved by: https://github.com/aorenste, https://github.com/eellison, https://github.com/albanD
Co-authored-by: Cedric GESTES <cedric.gestes@flex.ai>
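To try the feature end to end, the flow is roughly the sketch below: enable the new inductor flag (via the environment variable or the config field added in this PR), compile a model, and profile it under Nsight Systems. The model, optimizer, and profiling command here are illustrative assumptions, not part of the PR:

import os
os.environ["TORCHINDUCTOR_ANNOTATE_TRAINING"] = "1"  # must be set before torch._inductor.config is imported

import torch

# Toy setup (assumption): any module compiled with torch.compile will do.
model = torch.nn.Linear(128, 128).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
compiled = torch.compile(model)

x = torch.randn(32, 128, device="cuda")
loss = compiled(x).sum()   # forward graph is wrapped in an nvtx range named 'forward'
loss.backward()            # backward graph gets a 'backward' range
opt.step()

# Alternatively, patch torch._inductor.config.annotate_training = True before compiling.
# Inspect the ranges with, e.g., `nsys profile python train.py`.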
Committed by: PyTorch MergeBot
Parent: 24650c3caa
Commit: 539286a67b
test/inductor/test_inductor_annotations.py (new file, 41 lines)
@@ -0,0 +1,41 @@
# Owner(s): ["module: inductor"]
import torch
import torch._inductor.config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.triton_utils import requires_cuda


class InductorAnnotationTestCase(TestCase):
    def get_code(self):
        def f(a, b):
            return a + b, a * b

        a = torch.randn(5, device="cuda")
        b = torch.randn(5, device="cuda")
        f_comp = torch.compile(f)

        _, code = run_and_get_code(f_comp, a, b)
        return code[0]

    @requires_cuda
    def test_no_annotations(self):
        code = self.get_code()

        self.assertTrue("from torch.cuda import nvtx" not in code)
        self.assertTrue("training_annotation" not in code)

    @inductor_config.patch(annotate_training=True)
    @requires_cuda
    def test_training_annotation(self):
        code = self.get_code()

        self.assertTrue("from torch.cuda import nvtx" in code)
        self.assertEqual(
            code.count("training_annotation = nvtx._device_range_start('inference')"), 1
        )
        self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1)


if __name__ == "__main__":
    run_tests()
@@ -5,3 +5,5 @@ def rangePop() -> int: ...
def rangeStartA(message: str) -> int: ...
def rangeEnd(int) -> None: ...
def markA(message: str) -> None: ...
def deviceRangeStart(message: str, stream: int) -> object: ...
def deviceRangeEnd(range_handle: object, stream: int) -> None: ...
@@ -772,6 +772,8 @@ class PythonWrapperCodegen(CodeGen):
            )
        except (AttributeError, ImportError):
            pass
        if config.annotate_training:
            self.header.writeline("from torch.cuda import nvtx")

    def include_extra_header(self, header: str):
        pass
@@ -889,6 +891,11 @@ class PythonWrapperCodegen(CodeGen):
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline(V.graph.device_ops.synchronize())
            phase = V.graph.get_training_phase()
            if config.annotate_training:
                self.prefix.writeline(
                    f"training_annotation = nvtx._device_range_start('{phase}')"
                )
            if V.graph.graph_inputs:
                lhs = ", ".join(V.graph.graph_input_names)
                if len(V.graph.graph_input_names) == 1:
@@ -1175,6 +1182,10 @@ class PythonWrapperCodegen(CodeGen):
            if config.triton.autotune_at_compile_time:
                self.generate_and_run_autotune_block()

            if config.annotate_training:
                self.wrapper_call.writeline(
                    "nvtx._device_range_end(training_annotation)"
                )
            self.generate_return(output_refs)

        self.finalize_prefix()
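Taken together, the three wrapper changes above make the generated Python wrapper import nvtx in its header, open a device range named after the training phase at the top of the call, and close it just before returning. A rough sketch of the generated output (illustrative only; the buffer names and kernel body are placeholders, not the real codegen output):

from torch.cuda import nvtx

def call(args):
    training_annotation = nvtx._device_range_start('forward')
    # ... generated kernel launches and buffer computations ...
    nvtx._device_range_end(training_annotation)
    return (buf0, buf1)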
@@ -773,6 +773,10 @@ enable_linear_binary_folding = (
)


# Adds NVTX annotations around training phases
annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"


# config specific to codegen/cpp.py
class cpp:
    # set to torch.get_num_threads()
@@ -519,6 +519,13 @@ class GraphLowering(torch.fx.Interpreter):
        finally:
            self.current_device = prior

    def get_training_phase(self) -> str:
        if self.is_inference:
            return "inference"
        if self.is_backward:
            return "backward"
        return "forward"

    @staticmethod
    def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
        """
@@ -6,10 +6,41 @@
#else
#include <nvToolsExt.h>
#endif
#include <cuda_runtime.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::cuda::shared {

struct RangeHandle {
  nvtxRangeId_t id;
  const char* msg;
};

static void device_callback_range_end(void* userData) {
  RangeHandle* handle = ((RangeHandle*)userData);
  nvtxRangeEnd(handle->id);
  free((void*)handle->msg);
  free((void*)handle);
}

static void device_nvtxRangeEnd(void* handle, std::intptr_t stream) {
  cudaLaunchHostFunc((cudaStream_t)stream, device_callback_range_end, handle);
}

static void device_callback_range_start(void* userData) {
  RangeHandle* handle = ((RangeHandle*)userData);
  handle->id = nvtxRangeStartA(handle->msg);
}

static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) {
  RangeHandle* handle = (RangeHandle*)calloc(sizeof(RangeHandle), 1);
  handle->msg = strdup(msg);
  handle->id = 0;
  cudaLaunchHostFunc(
      (cudaStream_t)stream, device_callback_range_start, (void*)handle);
  return handle;
}

void initNvtxBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
@@ -23,6 +54,8 @@ void initNvtxBindings(PyObject* module) {
  nvtx.def("rangeStartA", nvtxRangeStartA);
  nvtx.def("rangeEnd", nvtxRangeEnd);
  nvtx.def("markA", nvtxMarkA);
  nvtx.def("deviceRangeStart", device_nvtxRangeStart);
  nvtx.def("deviceRangeEnd", device_nvtxRangeEnd);
}

} // namespace torch::cuda::shared
@@ -66,6 +66,38 @@ def range_end(range_id) -> None:
    _nvtx.rangeEnd(range_id)


def _device_range_start(msg: str, stream: int = 0) -> object:
    """
    Marks the start of a range with string message.
    It returns an opaque heap-allocated handle for this range
    to pass to the corresponding call to device_range_end().

    A key difference between this and range_start is that
    range_start marks the range right away, while _device_range_start
    marks the start of the range as soon as all the tasks on the
    CUDA stream are completed.

    Returns: An opaque heap-allocated handle that should be passed to _device_range_end().

    Args:
        msg (str): ASCII message to associate with the range.
        stream (int): CUDA stream id.
    """
    return _nvtx.deviceRangeStart(msg, stream)


def _device_range_end(range_handle: object, stream: int = 0) -> None:
    """
    Mark the end of a range for a given range_handle as soon as all the tasks
    on the CUDA stream are completed.

    Args:
        range_handle: a unique handle for the start range.
        stream (int): CUDA stream id.
    """
    _nvtx.deviceRangeEnd(range_handle, stream)


def mark(msg):
    """
    Describe an instantaneous event that occurred at some point.
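Because both calls are enqueued on the given CUDA stream as host functions, the range brackets whatever work is queued between them rather than wall-clock time on the host. A minimal usage sketch of these private helpers (assumes a CUDA build; the tensor, range name, and stream handling are illustrative):

import torch
from torch.cuda import nvtx

x = torch.randn(1024, device="cuda")
stream = torch.cuda.current_stream()

handle = nvtx._device_range_start("my_region", stream.cuda_stream)
y = x * x + x  # kernels queued on the stream fall inside the range
nvtx._device_range_end(handle, stream.cuda_stream)
torch.cuda.synchronize()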
@@ -5035,6 +5035,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
            "cudaLaunchCooperativeKernel",
            ("hipLaunchCooperativeKernel", CONV_EXEC, API_RUNTIME),
        ),
        ("cudaLaunchHostFunc", ("hipLaunchHostFunc", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)),
        (
            "cudaSetupArgument",
            ("hipSetupArgument", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED),
@@ -7965,6 +7966,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
        ("nvtxRangePop", ("roctxRangePop", CONV_OTHER, API_ROCTX)),
        ("nvtxRangeStartA", ("roctxRangeStartA", CONV_OTHER, API_ROCTX)),
        ("nvtxRangeEnd", ("roctxRangeStop", CONV_OTHER, API_ROCTX)),
        ("nvtxRangeId_t", ("int", CONV_OTHER, API_ROCTX)),
        ("nvmlReturn_t", ("rsmi_status_t", CONV_OTHER, API_ROCMSMI)),
        ("NVML_SUCCESS", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)),
        ("NVML_P2P_CAPS_INDEX_READ", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)),