diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 1a15495e5bf6..60342a33d7c4 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1183,6 +1183,16 @@ class DeviceCachingAllocator { // ends. ska::flat_hash_map> deferred_blocks; + // Incremental reverse-traversal state cached per graph. + // We never re-traverse nodes we've already seen + struct GraphReuseContext { + ska::flat_hash_map> + visited; + }; + ska::flat_hash_map + mempool_to_capture_id; + ska::flat_hash_map graph_reuse_context; + // outstanding cuda events ska::flat_hash_map< cuda::CUDAStream, @@ -1638,44 +1648,70 @@ class DeviceCachingAllocator { return block; } - // Insert "free marker" (empty nodes) into the CUDA graph for all streams that + struct CaptureInfo { + cudaGraph_t graph{}; + CaptureId_t capture_id{0}; + const cudaGraphNode_t* terminals{nullptr}; + size_t num_terminals{0}; + cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone}; + }; + + inline CaptureInfo stream_get_capture_info(cudaStream_t stream) { + CaptureInfo info{}; +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000) + C10_CUDA_CHECK(cudaStreamGetCaptureInfo( + stream, + &info.status, + &info.capture_id, + &info.graph, + &info.terminals, + nullptr, + &info.num_terminals)); +#else + C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2( + stream, + &info.status, + &info.capture_id, + &info.graph, + &info.terminals, + &info.num_terminals)); +#endif + TORCH_INTERNAL_ASSERT( + info.status != cudaStreamCaptureStatusInvalidated, + "Invalid stream capture status"); + + return info; + } + + // Record "free marker" of the CUDA graph for all streams that // have used the block, including the allocation stream. These nodes mark the // last use of the block in the capture graph. Returns a vector of the // inserted nodes, or an empty vector if any stream is not capturing. 
- std::vector insert_free_marker(Block* block) { - std::vector empty_nodes; + std::vector record_free_markers(Block* block) { + // It is possible to have the same marker recorded multiple times, so we use + // a set to avoid duplicates + ska::flat_hash_set markers; + cudaGraph_t owning_graph = nullptr; - auto try_add_empty_node = [&](cudaStream_t stream) -> bool { - cudaStreamCaptureStatus status{}; - cudaGraph_t graph{}; - const cudaGraphNode_t* deps = nullptr; - size_t num_deps = 0; -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000) - C10_CUDA_CHECK(cudaStreamGetCaptureInfo( - stream, &status, nullptr, &graph, &deps, nullptr, &num_deps)); -#else - C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2( - stream, &status, nullptr, &graph, &deps, &num_deps)); -#endif - - TORCH_INTERNAL_ASSERT( - status != cudaStreamCaptureStatusInvalidated, - "Invalid stream capture status"); - - if (status == cudaStreamCaptureStatusNone) { - return false; + auto try_record = [&](cudaStream_t s) -> bool { + auto info = stream_get_capture_info(s); + if (info.status == cudaStreamCaptureStatusNone) { + return false; // not capturing on this stream -> must defer } - cudaGraphNode_t node{}; - C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps)); -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000) - C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies( - stream, &node, nullptr, 1, cudaStreamSetCaptureDependencies)); -#else - C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies( - stream, &node, 1, cudaStreamSetCaptureDependencies)); -#endif - empty_nodes.push_back(node); + if (owning_graph == nullptr) { + owning_graph = info.graph; + } + TORCH_INTERNAL_ASSERT( + info.graph == owning_graph, + "All streams in the same capture should agree on the graph"); + + // Use current terminals as the free markers for the stream + for (size_t i = 0; i < info.num_terminals; ++i) { + auto terminal = info.terminals[i]; + markers.insert(terminal); + } + owning_graph = info.graph; // all streams in the
same capture should agree return true; }; @@ -1683,81 +1719,34 @@ class DeviceCachingAllocator { // An empty vector indicates that the block should be deferred for freeing // until after capture. - // Attempt to add an empty node for the allocation stream. - if (!try_add_empty_node(block->stream)) { + // Allocation stream + if (!try_record(block->stream)) { return {}; } - // Attempt to add empty nodes for all streams that have used the block. + // Any extra streams that used this block for (const auto& s : block->stream_uses) { - if (!try_add_empty_node(s.stream())) { + if (!try_record(s.stream())) { return {}; } } - return empty_nodes; + return std::vector(markers.begin(), markers.end()); } - // Returns the current set of "terminal" nodes in the CUDA graph for a given - // stream. These represent the current endpoints of the stream, and may - // include additional nodes if the graph branches. Any new work captured will - // be attached after one or more of these terminals. - std::vector get_terminals(cudaStream_t stream) { - std::vector result; - - cudaStreamCaptureStatus status{}; - cudaGraph_t graph{}; - const cudaGraphNode_t* dependencies = nullptr; - size_t num_dependencies = 0; - -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000) - C10_CUDA_CHECK(cudaStreamGetCaptureInfo( - stream, - &status, - nullptr, - &graph, - &dependencies, - nullptr, - &num_dependencies)); -#else - C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2( - stream, &status, nullptr, &graph, &dependencies, &num_dependencies)); -#endif - - TORCH_INTERNAL_ASSERT( - status == cudaStreamCaptureStatusActive, - "Invalid stream capture status"); - - for (size_t i = 0; i < num_dependencies; i++) { - auto node = dependencies[i]; - if (node != nullptr) { - result.push_back(node); - } - } - - return result; - } - - // Returns the set of "reusable" free markers (empty nodes) in the current + // Returns the set of "reusable" free markers in the current // CUDA graph capture. 
A free marker is considered reusable if it is a // predecessor of every terminal node. // This ensures that all future captured work will occur after the free // marker, making it safe to reuse. - ska::flat_hash_set get_reusable_empty_nodes( - cudaStream_t stream) { - auto terminals = get_terminals(stream); - if (terminals.empty()) { - // No terminal nodes found; nothing to free. - return {}; - } - - auto get_dependencies = [](cudaGraphNode_t node, - cudaGraphNode_t* pDependencies, - size_t* pNumDependencies) -> void { + void update_visited( + const CaptureInfo& info, + ska::flat_hash_set& visited) { + // This is the versioned cudaGraphNodeGetDependencies helper function. + auto node_get_dependencies = + [](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void { #if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000) - C10_CUDA_CHECK(cudaGraphNodeGetDependencies( - node, pDependencies, nullptr, pNumDependencies)); + C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count)); #else - C10_CUDA_CHECK( - cudaGraphNodeGetDependencies(node, pDependencies, pNumDependencies)); + C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count)); #endif }; @@ -1765,62 +1754,43 @@ class DeviceCachingAllocator { auto get_parents = [&](cudaGraphNode_t node) -> std::vector { size_t count = 0; - get_dependencies(node, nullptr, &count); + + node_get_dependencies(node, nullptr, &count); std::vector out(count); if (count) { - get_dependencies(node, out.data(), &count); + node_get_dependencies(node, out.data(), &count); out.resize(count); } return out; }; - // Helper to determine if a node is an empty node (used as a free marker). 
- auto is_empty_node = [](cudaGraphNode_t n) -> bool { - cudaGraphNodeType type{}; - C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type)); - return type == cudaGraphNodeTypeEmpty; - }; - - // For each terminal node, perform a reverse DFS to count, for each empty - // node, how many terminals it can reach (i.e., for how many terminals it is - // a predecessor). An empty node is reusable if it is a predecessor of all - // terminal nodes. - ska::flat_hash_map num_terminals_reachable; - - for (auto terminal : terminals) { - ska::flat_hash_set visited; - ska::flat_hash_set empty_nodes; - - std::function reverse_dfs = - [&](cudaGraphNode_t node) { - if (!visited.insert(node).second) - return; - - if (is_empty_node(node)) { - num_terminals_reachable[node]++; - empty_nodes.insert(node); - } - auto parents = get_parents(node); - for (auto p : parents) { - reverse_dfs(p); - } - }; - - reverse_dfs(terminal); + // For each terminal node, perform a reverse DFS to count, for each free + // marker, how many terminals it can reach (i.e., for how many terminals it + // is a predecessor). A free marker is reusable if it is a predecessor of + // all terminal nodes. + std::deque dfs; + for (size_t i = 0; i < info.num_terminals; ++i) { + dfs.push_back(info.terminals[i]); } - ska::flat_hash_set reusable_empty_nodes; - for (auto [node, count] : num_terminals_reachable) { - if (count == terminals.size()) { - reusable_empty_nodes.insert(node); + while (!dfs.empty()) { + auto v = dfs.back(); + dfs.pop_back(); + + if (visited.count(v)) { + continue; + } + visited.insert(v); + + auto parents = get_parents(v); + for (auto p : parents) { + dfs.push_back(p); } } - - return reusable_empty_nodes; } // A block is considered reusable during CUDA graph capture if every free - // marker (empty node) associated with the block is a predecessor of every + // marker associated with the block is a predecessor of every // terminal node. 
// // This ensures that any new operation added to the graph will be attached @@ -1829,36 +1799,52 @@ class DeviceCachingAllocator { // on every stream, so the block's previous lifetime ends before any new // lifetime begins. This check relies solely on the DAG topology and does not // require event queries, making it safe to use during capture. - // - // This function iterates over all deferred blocks, determines if their empty - // nodes are reusable according to the above criteria, and frees the block if - // so. void free_safe_blocks_in_capture( const std::shared_ptr& context, cudaStream_t stream) { - auto reusable_empty_nodes = get_reusable_empty_nodes(stream); + auto info = stream_get_capture_info(stream); // If there are no reusable empty nodes (e.g., not currently capturing), // there is nothing to do. - if (reusable_empty_nodes.empty()) { + if (info.status == cudaStreamCaptureStatusNone || info.num_terminals == 0) { return; } + if (graph_reuse_context.find(info.capture_id) == + graph_reuse_context.end()) { + bool found = false; + for (auto& entry : captures_underway) { + if (entry.second(stream)) { + auto graph_pool = graph_pools.find(entry.first); + TORCH_INTERNAL_ASSERT( + graph_pool != graph_pools.end(), + "Could not find graph pool for capture."); + auto mempool_id = graph_pool->first; + graph_reuse_context[info.capture_id] = GraphReuseContext{}; + mempool_to_capture_id[mempool_id] = info.capture_id; + found = true; + break; + } + } + TORCH_INTERNAL_ASSERT( + found, "Could not find memory pool id for capture."); + } + auto& graph_context = graph_reuse_context[info.capture_id]; + auto& visited = graph_context.visited[stream]; + update_visited(info, visited); std::vector blocks_to_erase; - - for (auto& [block, inserted_empty_nodes] : deferred_blocks) { - // Skip this block if it has no empty nodes, as we defer its freeing until + for (auto& [block, markers] : deferred_blocks) { + // Skip this block if it has no markers, as we defer its freeing until // 
after graph capture. Also skip if the block was not allocated on the // current stream; such blocks will be freed when // free_safe_blocks_in_capture is attempted on that stream. - if (inserted_empty_nodes.empty() || block->stream != stream) { + if (markers.empty() || block->stream != stream) { continue; } bool is_reusable = true; - - for (const auto& node : inserted_empty_nodes) { - if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) { + for (auto m : markers) { + if (!visited.count(m)) { is_reusable = false; break; } @@ -1919,11 +1905,11 @@ class DeviceCachingAllocator { if (!block->stream_uses.empty()) { if (C10_UNLIKELY(!captures_underway.empty())) { if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) { - // insert_free_marker returns a vector of free markers, + // record_free_markers returns a vector of free markers, // or an empty vector if any associated stream is not currently // capturing. The empty vector means that we will defer the free until // capture is finished. - deferred_blocks.emplace(block, insert_free_marker(block)); + deferred_blocks.emplace(block, record_free_markers(block)); } else { // If graph_capture_record_stream_reuse is not enabled, always defer // the free until capture is finished. 
@@ -2511,6 +2497,21 @@ class DeviceCachingAllocator { // Called by CUDAGraph::capture_end void endAllocateToPool(MempoolId_t mempool_id) { std::lock_guard lock(mutex); + + if (CUDAAllocatorConfig::graph_capture_record_stream_reuse() && + !graph_reuse_context.empty()) { + auto capture_id = mempool_to_capture_id[mempool_id]; + auto graph_context = graph_reuse_context[capture_id]; + for (auto& [stream, _] : graph_context.visited) { + TORCH_INTERNAL_ASSERT( + stream_get_capture_info(stream).status == + cudaStreamCaptureStatusNone, + "This stream should not be capturing when the capture is ended"); + } + graph_reuse_context.erase(capture_id); + mempool_to_capture_id.erase(mempool_id); + } + for (auto it = captures_underway.begin(); it != captures_underway.end(); ++it) { if (it->first == mempool_id) { diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 8981ac1bf6ed..86908185e996 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -613,8 +613,7 @@ Available options: CUDA Graph capture by using the graph topology (instead of CUDA events) to determine when a freed block is safe to reuse. This can reduce peak memory during long captures that free and reallocate buffers across multiple streams, especially when the capture DAG frequently - reaches joined frontiers. Note: Enabling this option can significantly increase the time spent - capturing the graph. + reaches joined frontiers. .. note::