pytorch/torch/csrc/cuda/Graph.cpp
Daniel Galvez cf94cadbee [CUDAGraph] Add getter for cuda graph exec (#161294)
This is far simpler than #155164 since we never destroy the cudaGraphExec_t.

The request comes from TRT-LLM specifically. The motivation is that some power users would like to mutate specific kernel parameters via APIs like `cudaGraphExec*SetParams` after a cuda graph has been instantiated. For example, a common request has been the ability to change the sequence length of attention kernels after having captured a graph for the largest possible sequence length. It turns out that the host overhead you eliminate via cuda graphs in LLM inference comes at the cost of increased computation time when you size your kernels for the maximum possible sequence length (which I believe is done in both TRT-LLM and vLLM). Attention is the most problematic kernel because its computation time is quadratic in the sequence length, rather than linear.

This can work if your attention kernel handles arbitrary shapes (this is not the case for all attention implementations! Many of them specialize with templates), and you have a persistent kernel that allocates only as many blocks as you have SMs (so you don't have to figure out how many blocks to allocate for a specific sequence length). Using a conditional SWITCH node is a better generic approach to this problem, but that requires more infrastructure work.

Note that this requires knowing the exact location, within your kernel's parameter buffer, of the value to mutate. It won't work with arbitrary stream capture code whose kernels you don't know beforehand, so I expect this code path to be rarely used. A sketch of the intended usage follows.
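
To make that concrete, here is a minimal sketch (not part of this PR) of patching one kernel argument through the CUDA runtime API. It assumes the graph contains exactly one kernel node and that `arg_index` identifies a scalar `int` parameter such as a sequence length; the helper name `patch_int_param` is hypothetical. The raw handles are the integers returned by `raw_cuda_graph()` (which needs the graph constructed with `keep_graph=True` so the `cudaGraph_t` survives instantiation) and the new `raw_cuda_graph_exec()`.

```
#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

// Hypothetical helper: patch the arg_index-th argument (assumed to be a
// scalar int, e.g. a sequence length) of the sole kernel node in an
// instantiated graph. raw_graph / raw_graph_exec are the uintptr_t values
// returned by raw_cuda_graph() / raw_cuda_graph_exec().
cudaError_t patch_int_param(
    uintptr_t raw_graph,
    uintptr_t raw_graph_exec,
    size_t arg_index,
    int new_value) {
  auto graph = reinterpret_cast<cudaGraph_t>(raw_graph);
  auto graph_exec = reinterpret_cast<cudaGraphExec_t>(raw_graph_exec);

  // Enumerate the nodes of the original (non-exec) graph.
  size_t num_nodes = 0;
  cudaError_t err = cudaGraphGetNodes(graph, nullptr, &num_nodes);
  if (err != cudaSuccess) {
    return err;
  }
  if (num_nodes == 0) {
    return cudaErrorInvalidValue;
  }
  std::vector<cudaGraphNode_t> nodes(num_nodes);
  err = cudaGraphGetNodes(graph, nodes.data(), &num_nodes);
  if (err != cudaSuccess) {
    return err;
  }

  // This sketch assumes the one interesting node is a kernel node.
  cudaGraphNodeType type{};
  err = cudaGraphNodeGetType(nodes[0], &type);
  if (err != cudaSuccess || type != cudaGraphNodeTypeKernel) {
    return cudaErrorInvalidValue;
  }

  // Read the node's current launch parameters; kernelParams[i] points at
  // the i-th argument's value. NB: CUDA documents this storage as owned by
  // the node, so a real implementation would keep its own argument buffers.
  cudaKernelNodeParams params{};
  err = cudaGraphKernelNodeGetParams(nodes[0], &params);
  if (err != cudaSuccess) {
    return err;
  }
  *static_cast<int*>(params.kernelParams[arg_index]) = new_value;

  // Push the mutated parameters into the already-instantiated exec graph;
  // the next replay launches with the new value.
  return cudaGraphExecKernelNodeSetParams(graph_exec, nodes[0], &params);
}
```

This is exactly the "you must know your kernel" caveat above: both the node index and `arg_index` are assumptions about a specific captured graph, not something this getter can discover for you.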

Testing:

```
pytest -s -k raw_graph_exec test/test_cuda.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161294
Approved by: https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison, https://github.com/eqy
2025-08-25 20:57:37 +00:00


#include <torch/csrc/python_headers.h>
#include <pybind11/chrono.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
// Cargo culted partially from csrc/distributed/c10d/init.cpp
// and partially from csrc/cuda/Stream.cpp.
// THCPStream_init is also declared at global scope.
// Because THCPGraph_init is forward-declared in the only consumer
// (csrc/Module.cpp), I don't think we need a Graph.h.
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
void THCPGraph_init(PyObject* module) {
// Pybind11 patch notes say "py::module_" is more up-to-date syntax,
// but CI linter and some builds prefer "module".
auto torch_C_m = py::handle(module).cast<py::module>();
torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle);
shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
.def(py::init<bool>(), py::arg("keep_graph") = false)
.def(
"capture_begin",
[](::at::cuda::CUDAGraph& self,
std::optional<c10::cuda::MempoolId_t> pool_opt,
const std::string& capture_error_mode) {
cudaStreamCaptureMode capture_mode{};
c10::cuda::MempoolId_t pool = pool_opt.has_value()
? pool_opt.value()
: c10::cuda::MempoolId_t{0, 0};
if (capture_error_mode == "global") {
capture_mode = cudaStreamCaptureModeGlobal;
} else if (capture_error_mode == "thread_local") {
capture_mode = cudaStreamCaptureModeThreadLocal;
} else if (capture_error_mode == "relaxed") {
capture_mode = cudaStreamCaptureModeRelaxed;
} else {
TORCH_CHECK(
false,
"Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
capture_error_mode);
}
return self.capture_begin(pool, capture_mode);
},
py::arg("pool"),
py::arg("capture_error_mode"),
py::call_guard<py::gil_scoped_release>())
.def(
"capture_end",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
.def(
"instantiate",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::instantiate))
.def(
"register_generator_state",
[](::at::cuda::CUDAGraph& self, py::handle raw_generator) {
auto generator = THPGenerator_Unwrap(raw_generator.ptr());
// We've unwrapped Python object to C++ object,
// so we could release GIL before calling into C++
py::gil_scoped_release release;
return self.register_generator_state(generator);
},
py::arg("generator"))
.def(
"replay",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))
.def(
"reset",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset))
.def(
"pool",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool))
.def(
"enable_debug_mode",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::enable_debug_mode))
.def(
"debug_dump",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::debug_dump),
py::arg("debug_path"))
.def(
"raw_cuda_graph",
[](::at::cuda::CUDAGraph& self) {
cudaGraph_t graph = self.raw_cuda_graph();
            // We return a raw int here, since otherwise pybind11 will
            // try to return the underlying struct that cudaGraph_t
            // points to, which is opaque and therefore causes a
            // compile error.
return reinterpret_cast<uintptr_t>(graph);
},
py::call_guard<py::gil_scoped_release>())
.def(
"raw_cuda_graph_exec",
[](::at::cuda::CUDAGraph& self) {
cudaGraphExec_t graph_exec = self.raw_cuda_graph_exec();
            // We return a raw int here, since otherwise pybind11 will
            // try to return the underlying struct that cudaGraphExec_t
            // points to, which is opaque and therefore causes a
            // compile error.
return reinterpret_cast<uintptr_t>(graph_exec);
},
py::call_guard<py::gil_scoped_release>());
}
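
For reference, a usage sketch (not part of the file above) of the same flow driven directly from C++, using only the methods bound here. The captured work and stream setup are elided, and constructing with `keep_graph=true` is assumed so the `cudaGraph_t` remains available after `instantiate()`.

```
#include <ATen/cuda/CUDAGraph.h>
#include <cuda_runtime_api.h>

void capture_and_get_raw_handles() {
  // keep_graph=true keeps the underlying cudaGraph_t alive after
  // instantiation, so raw_cuda_graph() remains valid.
  at::cuda::CUDAGraph graph(/*keep_graph=*/true);

  // Capture must happen on a non-default stream; stream setup elided.
  graph.capture_begin(
      /*pool=*/{0, 0}, /*capture_mode=*/cudaStreamCaptureModeGlobal);
  // ... enqueue the work to be captured on the current stream ...
  graph.capture_end();
  graph.instantiate();

  // The handles that the bindings above hand to Python as raw integers.
  cudaGraph_t raw_graph = graph.raw_cuda_graph();
  cudaGraphExec_t raw_exec = graph.raw_cuda_graph_exec();
  (void)raw_graph;
  (void)raw_exec;

  // ... mutate parameters via cudaGraphExec*SetParams, then graph.replay() ...
}
```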