[CUDA][CUDAGraph] Reduce capture overhead in CUDA Graph memory reuse (#162186)

Previous work #158352 delivered a CUDAGraph memory-footprint reduction with no replay-time impact, but capture time regressed (up to 20× slower) due to repeated full-graph traversals. See the previous benchmark results [here](https://github.com/pytorch/pytorch/pull/158352#issuecomment-3215947565).

This PR removes the capture-time overhead while preserving the memory savings:

1. **Terminals as free markers**
   We stop inserting empty nodes and instead record the current stream terminals as free markers. This avoids mutating the user’s graph and keeps semantics unchanged.

2. **Incremental, cached reachability**
   We add a **per-graph reuse context** that caches reverse-traversal state:

   * `graph_reuse_context[graph].visited[stream]` tracks nodes already seen from that stream’s terminal frontier.
   * On each allocation during capture, we resume traversal from the latest terminals and only visit unseen nodes.
   * A block is freed when all its recorded markers are in the visited set of its allocation stream—i.e., all markers are proven predecessors of future work. A minimal sketch of this bookkeeping follows the list.
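
To make this concrete, here is a minimal, CUDA-free sketch of the two ideas above; it is not the PR's actual code. Integer node IDs stand in for `cudaGraphNode_t`, the toy `Dag` stands in for the capture graph that `cudaGraphNodeGetDependencies` exposes, and the helper names (`record_free_markers`, `update_visited`, `block_is_reusable`) are illustrative:

```cpp
#include <cstdint>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy capture DAG: node -> parent (dependency) nodes. Integer IDs stand in
// for cudaGraphNode_t; the real allocator walks the capture graph with
// cudaGraphNodeGetDependencies instead.
using Node = std::uint64_t;
using Dag = std::unordered_map<Node, std::vector<Node>>;

// Idea 1: "recording" a free does not mutate the graph. The block's markers
// are simply the current terminal nodes of every stream that used it.
std::vector<Node> record_free_markers(
    const std::vector<std::vector<Node>>& per_stream_terminals) {
  std::unordered_set<Node> markers;
  for (const auto& terminals : per_stream_terminals) {
    markers.insert(terminals.begin(), terminals.end());
  }
  return {markers.begin(), markers.end()};
}

// Idea 2: incrementally extend `visited` by walking backwards from the
// stream's latest terminals; nodes seen by an earlier call are never
// revisited.
void update_visited(
    const Dag& dag,
    const std::vector<Node>& terminals,
    std::unordered_set<Node>& visited) {
  std::deque<Node> stack(terminals.begin(), terminals.end());
  while (!stack.empty()) {
    Node v = stack.back();
    stack.pop_back();
    if (!visited.insert(v).second) {
      continue; // already explored in a previous traversal
    }
    auto it = dag.find(v);
    if (it == dag.end()) {
      continue; // root node, no dependencies
    }
    for (Node parent : it->second) {
      stack.push_back(parent);
    }
  }
}

// A deferred block becomes reusable once every recorded marker is a proven
// predecessor of the allocation stream's current terminals.
bool block_is_reusable(
    const std::vector<Node>& markers,
    const std::unordered_set<Node>& visited) {
  for (Node m : markers) {
    if (!visited.count(m)) {
      return false;
    }
  }
  return true;
}

int main() {
  // Diamond-shaped capture: 1 -> {2, 3} -> 4.
  Dag dag{{2, {1}}, {3, {1}}, {4, {2, 3}}};
  std::unordered_set<Node> visited; // per (graph, allocation stream)

  // Free recorded while the two streams sit at terminals {2} and {3}.
  auto markers = record_free_markers({{2}, {3}});

  update_visited(dag, /*terminals=*/{2}, visited);
  bool early = block_is_reusable(markers, visited); // false: node 3 not yet seen

  update_visited(dag, /*terminals=*/{4}, visited); // resume after the join node
  bool later = block_is_reusable(markers, visited); // true: both markers precede 4
  return (early || !later) ? 1 : 0;
}
```

In the actual change, the visited set lives in a per-graph `GraphReuseContext` keyed by stream (see the diff below), so each allocation during capture only traverses nodes added to the graph since the previous check on that stream.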

See [the performance results here](https://docs.google.com/spreadsheets/d/e/2PACX-1vRPvdd9Xa8W87ixbiA0da_qvOhrUAjUpFz0G-_j-MsDnoeRyhEa4_ut_W3rqcg1VVZVFJ-gucwov-3b/pubhtml?gid=1468302443&single=true). We sweep synthetic multi-stream CUDA Graphs built by `capture_benchmark.py` (same setup as before: random interleavings of alloc/free/join with given probabilities; see [this gist](https://gist.github.com/eee4017/e2092d215b1d4bd46534148939af39e3)) and compare median capture/replay times and memory. On an NVIDIA H100 PCIe across 24 configs, the optimization preserves the reserved-memory reduction (~24–98%), leaves allocated memory unchanged, and brings capture time back to baseline (0.96–1.04× vs. baseline), with replay time unchanged (0.97–1.11×).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162186
Approved by: https://github.com/eqy, https://github.com/ngimel
Author: Frank Lin
Date: 2025-09-30 22:28:42 +00:00
Committed by: PyTorch MergeBot
Parent: 1f1de20ba9
Commit: bec6541d84
2 changed files with 152 additions and 152 deletions

View File

@@ -1183,6 +1183,16 @@ class DeviceCachingAllocator {
// ends.
ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;
// Incremental reverse-traversal state cached per graph.
// We never re-traverse nodes we've already seen
struct GraphReuseContext {
ska::flat_hash_map<cudaStream_t, ska::flat_hash_set<cudaGraphNode_t>>
visited;
};
ska::flat_hash_map<MempoolId_t, CaptureId_t, MempoolIdHash>
mempool_to_capture_id;
ska::flat_hash_map<CaptureId_t, GraphReuseContext> graph_reuse_context;
// outstanding cuda events
ska::flat_hash_map<
cuda::CUDAStream,
@@ -1638,44 +1648,70 @@ class DeviceCachingAllocator {
return block;
}
// Insert "free marker" (empty nodes) into the CUDA graph for all streams that
struct CaptureInfo {
cudaGraph_t graph{};
CaptureId_t capture_id{0};
const cudaGraphNode_t* terminals{nullptr};
size_t num_terminals{0};
cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone};
};
inline CaptureInfo stream_get_capture_info(cudaStream_t stream) {
CaptureInfo info{};
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream,
&info.status,
&info.capture_id,
&info.graph,
&info.terminals,
nullptr,
&info.num_terminals));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream,
&info.status,
&info.capture_id,
&info.graph,
&info.terminals,
&info.num_terminals));
#endif
TORCH_INTERNAL_ASSERT(
info.status != cudaStreamCaptureStatusInvalidated,
"Invalid stream capture status");
return info;
}
// Record "free marker" of the CUDA graph for all streams that
// have used the block, including the allocation stream. These nodes mark the
// last use of the block in the capture graph. Returns a vector of the
// inserted nodes, or an empty vector if any stream is not capturing.
std::vector<cudaGraphNode_t> insert_free_marker(Block* block) {
std::vector<cudaGraphNode_t> empty_nodes;
std::vector<cudaGraphNode_t> record_free_markers(Block* block) {
// It is possible for the same marker to be recorded multiple times, so we
// use a set to avoid duplicates.
ska::flat_hash_set<cudaGraphNode_t> markers;
cudaGraph_t owning_graph = nullptr;
auto try_add_empty_node = [&](cudaStream_t stream) -> bool {
cudaStreamCaptureStatus status{};
cudaGraph_t graph{};
const cudaGraphNode_t* deps = nullptr;
size_t num_deps = 0;
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream, &status, nullptr, &graph, &deps, nullptr, &num_deps));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream, &status, nullptr, &graph, &deps, &num_deps));
#endif
TORCH_INTERNAL_ASSERT(
status != cudaStreamCaptureStatusInvalidated,
"Invalid stream capture status");
if (status == cudaStreamCaptureStatusNone) {
return false;
auto try_record = [&](cudaStream_t s) -> bool {
auto info = stream_get_capture_info(s);
if (info.status == cudaStreamCaptureStatusNone) {
return false; // not capturing on this stream -> must defer
}
cudaGraphNode_t node{};
C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps));
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
stream, &node, nullptr, 1, cudaStreamSetCaptureDependencies));
#else
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
stream, &node, 1, cudaStreamSetCaptureDependencies));
#endif
empty_nodes.push_back(node);
if (owning_graph == nullptr) {
owning_graph = info.graph;
}
TORCH_INTERNAL_ASSERT(
info.graph == owning_graph,
"All streams in the same capture should agree on the graph");
// Use current terminals as the free markers for the stream
for (size_t i = 0; i < info.num_terminals; ++i) {
auto terminal = info.terminals[i];
markers.insert(terminal);
}
owning_graph = info.graph; // all streams in the same capture should agree
return true;
};
@@ -1683,81 +1719,34 @@ class DeviceCachingAllocator {
// An empty vector indicates that the block should be deferred for freeing
// until after capture.
// Attempt to add an empty node for the allocation stream.
if (!try_add_empty_node(block->stream)) {
// Allocation stream
if (!try_record(block->stream)) {
return {};
}
// Attempt to add empty nodes for all streams that have used the block.
// Any extra streams that used this block
for (const auto& s : block->stream_uses) {
if (!try_add_empty_node(s.stream())) {
if (!try_record(s.stream())) {
return {};
}
}
return empty_nodes;
return std::vector<cudaGraphNode_t>(markers.begin(), markers.end());
}
// Returns the current set of "terminal" nodes in the CUDA graph for a given
// stream. These represent the current endpoints of the stream, and may
// include additional nodes if the graph branches. Any new work captured will
// be attached after one or more of these terminals.
std::vector<cudaGraphNode_t> get_terminals(cudaStream_t stream) {
std::vector<cudaGraphNode_t> result;
cudaStreamCaptureStatus status{};
cudaGraph_t graph{};
const cudaGraphNode_t* dependencies = nullptr;
size_t num_dependencies = 0;
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
stream,
&status,
nullptr,
&graph,
&dependencies,
nullptr,
&num_dependencies));
#else
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream, &status, nullptr, &graph, &dependencies, &num_dependencies));
#endif
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatusActive,
"Invalid stream capture status");
for (size_t i = 0; i < num_dependencies; i++) {
auto node = dependencies[i];
if (node != nullptr) {
result.push_back(node);
}
}
return result;
}
// Returns the set of "reusable" free markers (empty nodes) in the current
// Returns the set of "reusable" free markers in the current
// CUDA graph capture. A free marker is considered reusable if it is a
// predecessor of every terminal node.
// This ensures that all future captured work will occur after the free
// marker, making it safe to reuse.
ska::flat_hash_set<cudaGraphNode_t> get_reusable_empty_nodes(
cudaStream_t stream) {
auto terminals = get_terminals(stream);
if (terminals.empty()) {
// No terminal nodes found; nothing to free.
return {};
}
auto get_dependencies = [](cudaGraphNode_t node,
cudaGraphNode_t* pDependencies,
size_t* pNumDependencies) -> void {
void update_visited(
const CaptureInfo& info,
ska::flat_hash_set<cudaGraphNode_t>& visited) {
// This is the versioned cudaGraphNodeGetDependencies helper function.
auto node_get_dependencies =
[](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(
node, pDependencies, nullptr, pNumDependencies));
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count));
#else
C10_CUDA_CHECK(
cudaGraphNodeGetDependencies(node, pDependencies, pNumDependencies));
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count));
#endif
};
@@ -1765,62 +1754,43 @@ class DeviceCachingAllocator {
auto get_parents =
[&](cudaGraphNode_t node) -> std::vector<cudaGraphNode_t> {
size_t count = 0;
get_dependencies(node, nullptr, &count);
node_get_dependencies(node, nullptr, &count);
std::vector<cudaGraphNode_t> out(count);
if (count) {
get_dependencies(node, out.data(), &count);
node_get_dependencies(node, out.data(), &count);
out.resize(count);
}
return out;
};
// Helper to determine if a node is an empty node (used as a free marker).
auto is_empty_node = [](cudaGraphNode_t n) -> bool {
cudaGraphNodeType type{};
C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type));
return type == cudaGraphNodeTypeEmpty;
};
// For each terminal node, perform a reverse DFS to count, for each empty
// node, how many terminals it can reach (i.e., for how many terminals it is
// a predecessor). An empty node is reusable if it is a predecessor of all
// terminal nodes.
ska::flat_hash_map<cudaGraphNode_t, size_t> num_terminals_reachable;
for (auto terminal : terminals) {
ska::flat_hash_set<cudaGraphNode_t> visited;
ska::flat_hash_set<cudaGraphNode_t> empty_nodes;
std::function<void(cudaGraphNode_t)> reverse_dfs =
[&](cudaGraphNode_t node) {
if (!visited.insert(node).second)
return;
if (is_empty_node(node)) {
num_terminals_reachable[node]++;
empty_nodes.insert(node);
}
auto parents = get_parents(node);
for (auto p : parents) {
reverse_dfs(p);
}
};
reverse_dfs(terminal);
// For each terminal node, perform a reverse DFS to count, for each free
// marker, how many terminals it can reach (i.e., for how many terminals it
// is a predecessor). A free marker is reusable if it is a predecessor of
// all terminal nodes.
std::deque<cudaGraphNode_t> dfs;
for (size_t i = 0; i < info.num_terminals; ++i) {
dfs.push_back(info.terminals[i]);
}
ska::flat_hash_set<cudaGraphNode_t> reusable_empty_nodes;
for (auto [node, count] : num_terminals_reachable) {
if (count == terminals.size()) {
reusable_empty_nodes.insert(node);
while (!dfs.empty()) {
auto v = dfs.back();
dfs.pop_back();
if (visited.count(v)) {
continue;
}
visited.insert(v);
auto parents = get_parents(v);
for (auto p : parents) {
dfs.push_back(p);
}
}
return reusable_empty_nodes;
}
// A block is considered reusable during CUDA graph capture if every free
// marker (empty node) associated with the block is a predecessor of every
// marker associated with the block is a predecessor of every
// terminal node.
//
// This ensures that any new operation added to the graph will be attached
@@ -1829,36 +1799,52 @@ class DeviceCachingAllocator {
// on every stream, so the block's previous lifetime ends before any new
// lifetime begins. This check relies solely on the DAG topology and does not
// require event queries, making it safe to use during capture.
//
// This function iterates over all deferred blocks, determines if their
// markers are reusable according to the above criteria, and frees the block
// if so.
void free_safe_blocks_in_capture(
const std::shared_ptr<GatheredContext>& context,
cudaStream_t stream) {
auto reusable_empty_nodes = get_reusable_empty_nodes(stream);
auto info = stream_get_capture_info(stream);
// If there are no reusable empty nodes (e.g., not currently capturing),
// there is nothing to do.
if (reusable_empty_nodes.empty()) {
if (info.status == cudaStreamCaptureStatusNone || info.num_terminals == 0) {
return;
}
if (graph_reuse_context.find(info.capture_id) ==
graph_reuse_context.end()) {
bool found = false;
for (auto& entry : captures_underway) {
if (entry.second(stream)) {
auto graph_pool = graph_pools.find(entry.first);
TORCH_INTERNAL_ASSERT(
graph_pool != graph_pools.end(),
"Could not find graph pool for capture.");
auto mempool_id = graph_pool->first;
graph_reuse_context[info.capture_id] = GraphReuseContext{};
mempool_to_capture_id[mempool_id] = info.capture_id;
found = true;
break;
}
}
TORCH_INTERNAL_ASSERT(
found, "Could not find memory pool id for capture.");
}
auto& graph_context = graph_reuse_context[info.capture_id];
auto& visited = graph_context.visited[stream];
update_visited(info, visited);
std::vector<Block*> blocks_to_erase;
for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
// Skip this block if it has no empty nodes, as we defer its freeing until
for (auto& [block, markers] : deferred_blocks) {
// Skip this block if it has no markers, as we defer its freeing until
// after graph capture. Also skip if the block was not allocated on the
// current stream; such blocks will be freed when
// free_safe_blocks_in_capture is attempted on that stream.
if (inserted_empty_nodes.empty() || block->stream != stream) {
if (markers.empty() || block->stream != stream) {
continue;
}
bool is_reusable = true;
for (const auto& node : inserted_empty_nodes) {
if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
for (auto m : markers) {
if (!visited.count(m)) {
is_reusable = false;
break;
}
@@ -1919,11 +1905,11 @@ class DeviceCachingAllocator {
if (!block->stream_uses.empty()) {
if (C10_UNLIKELY(!captures_underway.empty())) {
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
// insert_free_marker returns a vector of free markers,
// record_free_markers returns a vector of free markers,
// or an empty vector if any associated stream is not currently
// capturing. The empty vector means that we will defer the free until
// capture is finished.
deferred_blocks.emplace(block, insert_free_marker(block));
deferred_blocks.emplace(block, record_free_markers(block));
} else {
// If graph_capture_record_stream_reuse is not enabled, always defer
// the free until capture is finished.
@@ -2511,6 +2497,21 @@ class DeviceCachingAllocator {
// Called by CUDAGraph::capture_end
void endAllocateToPool(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse() &&
!graph_reuse_context.empty()) {
auto capture_id = mempool_to_capture_id[mempool_id];
auto graph_context = graph_reuse_context[capture_id];
for (auto& [stream, _] : graph_context.visited) {
TORCH_INTERNAL_ASSERT(
stream_get_capture_info(stream).status ==
cudaStreamCaptureStatusNone,
"This stream should not be capturing when the capture is ended");
}
graph_reuse_context.erase(capture_id);
mempool_to_capture_id.erase(mempool_id);
}
for (auto it = captures_underway.begin(); it != captures_underway.end();
++it) {
if (it->first == mempool_id) {

View File

@@ -613,8 +613,7 @@ Available options:
CUDA Graph capture by using the graph topology (instead of CUDA events) to determine
when a freed block is safe to reuse. This can reduce peak memory during long captures that free
and reallocate buffers across multiple streams, especially when the capture DAG frequently
reaches joined frontiers. Note: Enabling this option can significantly increase the time spent
capturing the graph.
reaches joined frontiers.
.. note::