[CUDA][CUDAGraph] Reduce capture overhead in CUDA Graph memory reuse (#162186)
Previous work #158352 delivered CUDAGraph memory footprint reduction with no replay-time impact, but capture time regressed (up to 20× slower) due to repeated full-graph traversals. See the previous benchmark results [here](https://github.com/pytorch/pytorch/pull/158352#issuecomment-3215947565).

This PR removes the capture/replay overhead while preserving the memory savings:

1. **Terminals as free markers.** We stop inserting empty nodes and instead record the current stream terminals as free markers. This avoids mutating the user's graph and keeps semantics unchanged.
2. **Incremental, cached reachability.** We add a **per-graph reuse context** that caches reverse-traversal state:
   * `graph_reuse_context[graph].visited[stream]` tracks nodes already seen from that stream's terminal frontier.
   * On each allocation during capture, we resume traversal from the latest terminals and only visit unseen nodes.
   * A block is freed when all of its recorded markers are in the visited set of its allocation stream, i.e., all markers are proven predecessors of future work.

See [the performance results here](https://docs.google.com/spreadsheets/d/e/2PACX-1vRPvdd9Xa8W87ixbiA0da_qvOhrUAjUpFz0G-_j-MsDnoeRyhEa4_ut_W3rqcg1VVZVFJ-gucwov-3b/pubhtml?gid=1468302443&single=true): we sweep synthetic multi-stream CUDA Graphs built by `capture_benchmark.py` (same as before, we generate random interleavings of alloc/free/join with given probabilities; see the [gist here](https://gist.github.com/eee4017/e2092d215b1d4bd46534148939af39e3)) and compare median capture/replay times and memory. On an NVIDIA H100 PCIe across 24 configs, the optimization preserves the reserved-memory reduction at ~24–98%, leaves allocated memory unchanged, and brings capture time back to baseline (0.96–1.04× vs. baseline) with replay time unchanged (0.97–1.11×).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162186
Approved by: https://github.com/eqy, https://github.com/ngimel
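To make the reuse test concrete, here is a minimal, self-contained sketch of the cached-reachability idea in plain C++. Integer node IDs stand in for `cudaGraphNode_t`, and the names `ReuseContext`, `update_visited`, and `block_is_reusable` are illustrative only, not the allocator's actual API:

```cpp
#include <cstdio>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Per-stream state: nodes already proven to be predecessors of the stream's
// terminal frontier. It is only ever extended, never re-traversed from scratch.
struct ReuseContext {
  std::unordered_set<int> visited;
};

// Resume a reverse traversal from the current terminals; nodes seen in an
// earlier pass are skipped, so each node is expanded at most once per capture.
void update_visited(const std::vector<int>& terminals,
                    const std::unordered_map<int, std::vector<int>>& preds,
                    std::unordered_set<int>& visited) {
  std::deque<int> work(terminals.begin(), terminals.end());
  while (!work.empty()) {
    int v = work.back();
    work.pop_back();
    if (!visited.insert(v).second) continue;  // already visited in an earlier pass
    auto it = preds.find(v);
    if (it == preds.end()) continue;
    for (int p : it->second) work.push_back(p);
  }
}

// A deferred block may be reused once every recorded free marker is in the
// visited set, i.e. every marker precedes all current (and hence future) work.
bool block_is_reusable(const std::vector<int>& markers,
                       const std::unordered_set<int>& visited) {
  for (int m : markers) {
    if (!visited.count(m)) return false;
  }
  return true;
}

int main() {
  // Tiny capture DAG: 0 -> 1 -> 3 and 0 -> 2 -> 3, where 3 is the joined terminal.
  std::unordered_map<int, std::vector<int>> preds{{1, {0}}, {2, {0}}, {3, {1, 2}}};
  ReuseContext ctx;
  update_visited(/*terminals=*/{3}, preds, ctx.visited);
  std::printf("markers {1, 2} reusable: %s\n",
              block_is_reusable({1, 2}, ctx.visited) ? "yes" : "no");
  return 0;
}
```

Because the visited set persists across calls, repeated frees during a long capture cost roughly the number of newly captured nodes rather than a full-graph traversal each time, which is where the capture-time regression came from.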
committed by PyTorch MergeBot
parent 1f1de20ba9
commit bec6541d84
@@ -1183,6 +1183,16 @@ class DeviceCachingAllocator {
  // ends.
  ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;

  // Incremental reverse-traversal state cached per graph.
  // We never re-traverse nodes we've already seen.
  struct GraphReuseContext {
    ska::flat_hash_map<cudaStream_t, ska::flat_hash_set<cudaGraphNode_t>>
        visited;
  };
  ska::flat_hash_map<MempoolId_t, CaptureId_t, MempoolIdHash>
      mempool_to_capture_id;
  ska::flat_hash_map<CaptureId_t, GraphReuseContext> graph_reuse_context;

  // outstanding cuda events
  ska::flat_hash_map<
      cuda::CUDAStream,
@@ -1638,44 +1648,70 @@ class DeviceCachingAllocator {
    return block;
  }

  // Insert "free marker" (empty nodes) into the CUDA graph for all streams that
  struct CaptureInfo {
    cudaGraph_t graph{};
    CaptureId_t capture_id{0};
    const cudaGraphNode_t* terminals{nullptr};
    size_t num_terminals{0};
    cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone};
  };

  inline CaptureInfo stream_get_capture_info(cudaStream_t stream) {
    CaptureInfo info{};
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
        stream,
        &info.status,
        &info.capture_id,
        &info.graph,
        &info.terminals,
        nullptr,
        &info.num_terminals));
#else
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
        stream,
        &info.status,
        &info.capture_id,
        &info.graph,
        &info.terminals,
        &info.num_terminals));
#endif
    TORCH_INTERNAL_ASSERT(
        info.status != cudaStreamCaptureStatusInvalidated,
        "Invalid stream capture status");

    return info;
  }

// Record "free marker" of the CUDA graph for all streams that
|
||||
// have used the block, including the allocation stream. These nodes mark the
|
||||
// last use of the block in the capture graph. Returns a vector of the
|
||||
// inserted nodes, or an empty vector if any stream is not capturing.
|
||||
std::vector<cudaGraphNode_t> insert_free_marker(Block* block) {
|
||||
std::vector<cudaGraphNode_t> empty_nodes;
|
||||
std::vector<cudaGraphNode_t> record_free_markers(Block* block) {
|
||||
// Is is possible to have the same marker recorded multiple times, so we use
|
||||
// a set to avoid duplicates
|
||||
ska::flat_hash_set<cudaGraphNode_t> markers;
|
||||
cudaGraph_t owning_graph = nullptr;
|
||||
|
||||
auto try_add_empty_node = [&](cudaStream_t stream) -> bool {
|
||||
cudaStreamCaptureStatus status{};
|
||||
cudaGraph_t graph{};
|
||||
const cudaGraphNode_t* deps = nullptr;
|
||||
size_t num_deps = 0;
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
|
||||
stream, &status, nullptr, &graph, &deps, nullptr, &num_deps));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
|
||||
stream, &status, nullptr, &graph, &deps, &num_deps));
|
||||
#endif
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
status != cudaStreamCaptureStatusInvalidated,
|
||||
"Invalid stream capture status");
|
||||
|
||||
if (status == cudaStreamCaptureStatusNone) {
|
||||
return false;
|
||||
auto try_record = [&](cudaStream_t s) -> bool {
|
||||
auto info = stream_get_capture_info(s);
|
||||
if (info.status == cudaStreamCaptureStatusNone) {
|
||||
return false; // not capturing on this stream -> must defer
|
||||
}
|
||||
|
||||
cudaGraphNode_t node{};
|
||||
C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps));
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
|
||||
stream, &node, nullptr, 1, cudaStreamSetCaptureDependencies));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
|
||||
stream, &node, 1, cudaStreamSetCaptureDependencies));
|
||||
#endif
|
||||
empty_nodes.push_back(node);
|
||||
if (owning_graph == nullptr) {
|
||||
owning_graph = info.graph;
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
info.graph == owning_graph,
|
||||
"All streams in the same capture should agree on the graph");
|
||||
|
||||
// Use current terminals as the free markers for the stream
|
||||
for (size_t i = 0; i < info.num_terminals; ++i) {
|
||||
auto terminal = info.terminals[i];
|
||||
markers.insert(terminal);
|
||||
}
|
||||
owning_graph = info.graph; // all streams in the same capture should agree
|
||||
return true;
|
||||
};
|
||||
|
||||
@@ -1683,81 +1719,34 @@ class DeviceCachingAllocator {
    // An empty vector indicates that the block should be deferred for freeing
    // until after capture.

    // Attempt to add an empty node for the allocation stream.
    if (!try_add_empty_node(block->stream)) {
    // Allocation stream
    if (!try_record(block->stream)) {
      return {};
    }
    // Attempt to add empty nodes for all streams that have used the block.
    // Any extra streams that used this block
    for (const auto& s : block->stream_uses) {
      if (!try_add_empty_node(s.stream())) {
      if (!try_record(s.stream())) {
        return {};
      }
    }
    return empty_nodes;
    return std::vector<cudaGraphNode_t>(markers.begin(), markers.end());
  }

  // Returns the current set of "terminal" nodes in the CUDA graph for a given
  // stream. These represent the current endpoints of the stream, and may
  // include additional nodes if the graph branches. Any new work captured will
  // be attached after one or more of these terminals.
  std::vector<cudaGraphNode_t> get_terminals(cudaStream_t stream) {
    std::vector<cudaGraphNode_t> result;

    cudaStreamCaptureStatus status{};
    cudaGraph_t graph{};
    const cudaGraphNode_t* dependencies = nullptr;
    size_t num_dependencies = 0;

#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
        stream,
        &status,
        nullptr,
        &graph,
        &dependencies,
        nullptr,
        &num_dependencies));
#else
    C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
        stream, &status, nullptr, &graph, &dependencies, &num_dependencies));
#endif

    TORCH_INTERNAL_ASSERT(
        status == cudaStreamCaptureStatusActive,
        "Invalid stream capture status");

    for (size_t i = 0; i < num_dependencies; i++) {
      auto node = dependencies[i];
      if (node != nullptr) {
        result.push_back(node);
      }
    }

    return result;
  }

  // Returns the set of "reusable" free markers (empty nodes) in the current
  // Returns the set of "reusable" free markers in the current
  // CUDA graph capture. A free marker is considered reusable if it is a
  // predecessor of every terminal node.
  // This ensures that all future captured work will occur after the free
  // marker, making it safe to reuse.
  ska::flat_hash_set<cudaGraphNode_t> get_reusable_empty_nodes(
      cudaStream_t stream) {
    auto terminals = get_terminals(stream);
    if (terminals.empty()) {
      // No terminal nodes found; nothing to free.
      return {};
    }

    auto get_dependencies = [](cudaGraphNode_t node,
                               cudaGraphNode_t* pDependencies,
                               size_t* pNumDependencies) -> void {
  void update_visited(
      const CaptureInfo& info,
      ska::flat_hash_set<cudaGraphNode_t>& visited) {
    // This is the versioned cudaGraphNodeGetDependencies helper function.
    auto node_get_dependencies =
        [](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void {
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
      C10_CUDA_CHECK(cudaGraphNodeGetDependencies(
          node, pDependencies, nullptr, pNumDependencies));
      C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count));
#else
      C10_CUDA_CHECK(
          cudaGraphNodeGetDependencies(node, pDependencies, pNumDependencies));
      C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count));
#endif
    };

@@ -1765,62 +1754,43 @@ class DeviceCachingAllocator {
    auto get_parents =
        [&](cudaGraphNode_t node) -> std::vector<cudaGraphNode_t> {
      size_t count = 0;
      get_dependencies(node, nullptr, &count);

      node_get_dependencies(node, nullptr, &count);
      std::vector<cudaGraphNode_t> out(count);
      if (count) {
        get_dependencies(node, out.data(), &count);
        node_get_dependencies(node, out.data(), &count);
        out.resize(count);
      }
      return out;
    };

    // Helper to determine if a node is an empty node (used as a free marker).
    auto is_empty_node = [](cudaGraphNode_t n) -> bool {
      cudaGraphNodeType type{};
      C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type));
      return type == cudaGraphNodeTypeEmpty;
    };

    // For each terminal node, perform a reverse DFS to count, for each empty
    // node, how many terminals it can reach (i.e., for how many terminals it is
    // a predecessor). An empty node is reusable if it is a predecessor of all
    // terminal nodes.
    ska::flat_hash_map<cudaGraphNode_t, size_t> num_terminals_reachable;

    for (auto terminal : terminals) {
      ska::flat_hash_set<cudaGraphNode_t> visited;
      ska::flat_hash_set<cudaGraphNode_t> empty_nodes;

      std::function<void(cudaGraphNode_t)> reverse_dfs =
          [&](cudaGraphNode_t node) {
            if (!visited.insert(node).second)
              return;

            if (is_empty_node(node)) {
              num_terminals_reachable[node]++;
              empty_nodes.insert(node);
            }
            auto parents = get_parents(node);
            for (auto p : parents) {
              reverse_dfs(p);
            }
          };

      reverse_dfs(terminal);
    // For each terminal node, perform a reverse DFS to count, for each free
    // marker, how many terminals it can reach (i.e., for how many terminals it
    // is a predecessor). A free marker is reusable if it is a predecessor of
    // all terminal nodes.
    std::deque<cudaGraphNode_t> dfs;
    for (size_t i = 0; i < info.num_terminals; ++i) {
      dfs.push_back(info.terminals[i]);
    }

    ska::flat_hash_set<cudaGraphNode_t> reusable_empty_nodes;
    for (auto [node, count] : num_terminals_reachable) {
      if (count == terminals.size()) {
        reusable_empty_nodes.insert(node);
    while (!dfs.empty()) {
      auto v = dfs.back();
      dfs.pop_back();

      if (visited.count(v)) {
        continue;
      }
      visited.insert(v);

      auto parents = get_parents(v);
      for (auto p : parents) {
        dfs.push_back(p);
      }
    }

    return reusable_empty_nodes;
  }

  // A block is considered reusable during CUDA graph capture if every free
  // marker (empty node) associated with the block is a predecessor of every
  // marker associated with the block is a predecessor of every
  // terminal node.
  //
  // This ensures that any new operation added to the graph will be attached
@@ -1829,36 +1799,52 @@ class DeviceCachingAllocator {
  // on every stream, so the block's previous lifetime ends before any new
  // lifetime begins. This check relies solely on the DAG topology and does not
  // require event queries, making it safe to use during capture.
  //
  // This function iterates over all deferred blocks, determines if their empty
  // nodes are reusable according to the above criteria, and frees the block if
  // so.
  void free_safe_blocks_in_capture(
      const std::shared_ptr<GatheredContext>& context,
      cudaStream_t stream) {
    auto reusable_empty_nodes = get_reusable_empty_nodes(stream);
    auto info = stream_get_capture_info(stream);

    // If there are no reusable empty nodes (e.g., not currently capturing),
    // there is nothing to do.
    if (reusable_empty_nodes.empty()) {
    if (info.status == cudaStreamCaptureStatusNone || info.num_terminals == 0) {
      return;
    }
    if (graph_reuse_context.find(info.capture_id) ==
        graph_reuse_context.end()) {
      bool found = false;
      for (auto& entry : captures_underway) {
        if (entry.second(stream)) {
          auto graph_pool = graph_pools.find(entry.first);
          TORCH_INTERNAL_ASSERT(
              graph_pool != graph_pools.end(),
              "Could not find graph pool for capture.");
          auto mempool_id = graph_pool->first;
          graph_reuse_context[info.capture_id] = GraphReuseContext{};
          mempool_to_capture_id[mempool_id] = info.capture_id;
          found = true;
          break;
        }
      }
      TORCH_INTERNAL_ASSERT(
          found, "Could not find memory pool id for capture.");
    }
    auto& graph_context = graph_reuse_context[info.capture_id];
    auto& visited = graph_context.visited[stream];
    update_visited(info, visited);

    std::vector<Block*> blocks_to_erase;

    for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
      // Skip this block if it has no empty nodes, as we defer its freeing until
    for (auto& [block, markers] : deferred_blocks) {
      // Skip this block if it has no markers, as we defer its freeing until
      // after graph capture. Also skip if the block was not allocated on the
      // current stream; such blocks will be freed when
      // free_safe_blocks_in_capture is attempted on that stream.
      if (inserted_empty_nodes.empty() || block->stream != stream) {
      if (markers.empty() || block->stream != stream) {
        continue;
      }

      bool is_reusable = true;

      for (const auto& node : inserted_empty_nodes) {
        if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
      for (auto m : markers) {
        if (!visited.count(m)) {
          is_reusable = false;
          break;
        }
@@ -1919,11 +1905,11 @@ class DeviceCachingAllocator {
    if (!block->stream_uses.empty()) {
      if (C10_UNLIKELY(!captures_underway.empty())) {
        if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
          // insert_free_marker returns a vector of free markers,
          // record_free_markers returns a vector of free markers,
          // or an empty vector if any associated stream is not currently
          // capturing. The empty vector means that we will defer the free until
          // capture is finished.
          deferred_blocks.emplace(block, insert_free_marker(block));
          deferred_blocks.emplace(block, record_free_markers(block));
        } else {
          // If graph_capture_record_stream_reuse is not enabled, always defer
          // the free until capture is finished.
@@ -2511,6 +2497,21 @@ class DeviceCachingAllocator {
  // Called by CUDAGraph::capture_end
  void endAllocateToPool(MempoolId_t mempool_id) {
    std::lock_guard<std::recursive_mutex> lock(mutex);

    if (CUDAAllocatorConfig::graph_capture_record_stream_reuse() &&
        !graph_reuse_context.empty()) {
      auto capture_id = mempool_to_capture_id[mempool_id];
      auto graph_context = graph_reuse_context[capture_id];
      for (auto& [stream, _] : graph_context.visited) {
        TORCH_INTERNAL_ASSERT(
            stream_get_capture_info(stream).status ==
                cudaStreamCaptureStatusNone,
            "This stream should not be capturing when the capture is ended");
      }
      graph_reuse_context.erase(capture_id);
      mempool_to_capture_id.erase(mempool_id);
    }

    for (auto it = captures_underway.begin(); it != captures_underway.end();
         ++it) {
      if (it->first == mempool_id) {
@@ -613,8 +613,7 @@ Available options:
  CUDA Graph capture by using the graph topology (instead of CUDA events) to determine
  when a freed block is safe to reuse. This can reduce peak memory during long captures that free
  and reallocate buffers across multiple streams, especially when the capture DAG frequently
  reaches joined frontiers. Note: Enabling this option can significantly increase the time spent
  capturing the graph.
  reaches joined frontiers.

.. note::
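For context on the documentation hunk above: options in this list are passed through the `PYTORCH_CUDA_ALLOC_CONF` environment variable. The exact key name is not visible in this hunk; assuming it mirrors the `CUDAAllocatorConfig::graph_capture_record_stream_reuse()` accessor used in the allocator diff and takes a boolean value like other allocator options, enabling it from a host program might look like this sketch:

```cpp
#include <cstdlib>

int main() {
  // Assumption: the config key matches the C++ accessor name and is parsed as
  // a boolean, like other PYTORCH_CUDA_ALLOC_CONF options. It must be set
  // before the CUDA caching allocator is initialized in the process.
  setenv("PYTORCH_CUDA_ALLOC_CONF",
         "graph_capture_record_stream_reuse:True", /*overwrite=*/1);
  return 0;
}
```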