## Introduction

During CUDA Graph capture, the CUDA caching allocator currently defers reclaiming blocks until capture ends. This is because CUDA forbids querying events recorded during capture (the operations are not actually executed at capture time), so the allocator cannot use its normal event-based logic. However, capture records a DAG of work (we call it the **capturing graph**). We can use the capturing graph to determine when a block's old lifetime is fully before future work, and safely reuse the block within the same capture.

This PR adds an experimental flag `graph_capture_record_stream_reuse: True|False` (default: `False`). When enabled, the allocator inserts lightweight free markers and uses capture ordering to decide whether a freed block is safe to reuse during capture. If the proof cannot be established, we fall back to the existing post-capture path.

## Terms

* **Free marker**: A capture-legal no-op (created with `cudaGraphAddEmptyNode`) inserted after the last captured use of the block on each stream that used it.
* **Terminal**: The set of the latest operations of a stream (or of the capturing graph). Any newly captured op on that stream will attach after all nodes in this set. For a stream currently capturing, it is the set of nodes returned in `dependencies_out` by `cudaStreamGetCaptureInfo`.

## When can we reuse a block during capture?

### Strong Rule (Graph-Wide Safety)

This rule provides a universal guarantee that a block is safe for reuse by any stream in the graph.

> A block is safe to reuse if every free marker is a predecessor of every terminal of all active streams in the graph.

Why it's safe: this rule establishes a strict global ordering. Since any new operation on any stream must be appended after that stream's terminals, the condition guarantees that the block's new lifetime begins only after its old lifetime has completely ended everywhere. This prevents lifetime overlaps when the graph is replayed, ensuring correctness.

### Per-stream Rule (A Practical Optimization)

The strong rule, while safe, is often unnecessarily restrictive. The `DeviceCachingAllocator` introduces a crucial constraint that allows a simpler check.

In `DeviceCachingAllocator`, `get_free_block` only returns blocks whose `block->stream == p.stream()`; in other words, we never reuse a block on a stream different from its allocation stream. This means we don't need to verify safety across the entire graph; we only need to confirm that the block is safe to reuse from the perspective of its own allocation stream.

> Reuse a block for allocations on stream S if every free marker is a predecessor of every node in the terminal set of S.

In short, a block is considered **reusable** on stream S as long as all markers marking it "free" are guaranteed to complete before any new work that might need it on stream S begins.

## Implementation

* On `free(block)` during capture
  * For each stream in `block->stream_uses` and the allocation stream, insert a free marker (empty node) and make it that stream's tail.
  * If we cannot place markers for all such streams (for example, a stream is not capturing), defer to the post-capture path.
  * Otherwise, store the marker handles and keep the block in the capture-private structures.
* On `allocate(stream)` during capture (attempt per-stream reclaim)
  * Query the allocation stream S's terminal via `cudaStreamGetCaptureInfo`.
  * For each deferred block, check whether it is allocated on this stream and whether each of its free markers is a predecessor of the terminal (see the sketch after this list).
  * If yes, hand the block to S for immediate reuse within the same capture.
  * If no, keep it deferred; it will be reconsidered as capture progresses and S's terminal advances.
* On capture end
  * Any still-deferred blocks follow the existing post-capture reclamation (event insertion/polling). External behavior remains unchanged if we cannot prove safety during capture.
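For illustration, the sketch below shows how the two capture-time steps above can be expressed with public CUDA Graph APIs: placing a free marker as a stream's new capture tail, and testing the per-stream rule by walking the capturing graph backwards from the stream's terminal set. This is a minimal standalone sketch, not the PR's actual allocator code; the helper names `insertFreeMarker`, `predecessorsOf`, and `allMarkersPrecedeTerminals` are hypothetical, and error handling is reduced to early returns (which correspond to falling back to the post-capture path).

```cpp
#include <cuda_runtime.h>

#include <queue>
#include <unordered_set>
#include <vector>

// Insert a free marker (empty node) that depends on everything captured so
// far on `stream`, and make it the stream's new tail so that any later
// captured work on this stream is ordered after it.
// Returns nullptr if the stream is not capturing (caller defers the block).
cudaGraphNode_t insertFreeMarker(cudaStream_t stream) {
  cudaStreamCaptureStatus status{};
  cudaGraph_t graph{};
  const cudaGraphNode_t* deps = nullptr;
  size_t numDeps = 0;
  if (cudaStreamGetCaptureInfo(stream, &status, /*id_out=*/nullptr, &graph,
                               &deps, &numDeps) != cudaSuccess ||
      status != cudaStreamCaptureStatusActive) {
    return nullptr;
  }
  cudaGraphNode_t marker{};
  if (cudaGraphAddEmptyNode(&marker, graph, deps, numDeps) != cudaSuccess) {
    return nullptr;
  }
  // Replace the stream's capture dependencies with the marker alone: the
  // marker already depends on the old tail, so ordering is preserved and the
  // marker becomes the single tail of this stream.
  if (cudaStreamUpdateCaptureDependencies(stream, &marker, 1,
                                          cudaStreamSetCaptureDependencies) !=
      cudaSuccess) {
    return nullptr;
  }
  return marker;
}

// All nodes that `node` transitively depends on, including `node` itself.
std::unordered_set<cudaGraphNode_t> predecessorsOf(cudaGraphNode_t node) {
  std::unordered_set<cudaGraphNode_t> preds{node};
  std::queue<cudaGraphNode_t> work;
  work.push(node);
  while (!work.empty()) {
    cudaGraphNode_t cur = work.front();
    work.pop();
    size_t n = 0;
    cudaGraphNodeGetDependencies(cur, nullptr, &n);  // query the count
    if (n == 0) {
      continue;
    }
    std::vector<cudaGraphNode_t> parents(n);
    cudaGraphNodeGetDependencies(cur, parents.data(), &n);
    for (cudaGraphNode_t p : parents) {
      if (preds.insert(p).second) {
        work.push(p);  // newly discovered predecessor, keep walking backwards
      }
    }
  }
  return preds;
}

// Per-stream rule: every free marker of the block must be a predecessor of
// (or identical to) every node in the terminal set of `stream`, i.e. the
// nodes returned in dependencies_out by cudaStreamGetCaptureInfo.
bool allMarkersPrecedeTerminals(const std::vector<cudaGraphNode_t>& markers,
                                cudaStream_t stream) {
  cudaStreamCaptureStatus status{};
  const cudaGraphNode_t* terminals = nullptr;
  size_t numTerminals = 0;
  if (cudaStreamGetCaptureInfo(stream, &status, /*id_out=*/nullptr,
                               /*graph_out=*/nullptr, &terminals,
                               &numTerminals) != cudaSuccess ||
      status != cudaStreamCaptureStatusActive) {
    return false;  // not capturing: leave the block to post-capture reclamation
  }
  for (size_t i = 0; i < numTerminals; ++i) {
    const auto preds = predecessorsOf(terminals[i]);
    for (cudaGraphNode_t marker : markers) {
      if (preds.find(marker) == preds.end()) {
        return false;  // this marker is not yet ordered before this terminal
      }
    }
  }
  return true;  // the block is reusable for allocations on `stream`
}
```

A marker that is itself still a terminal of the stream counts as satisfied here, since any new work captured on that stream attaches after all of its terminals. A block that fails the check is simply kept deferred; it may pass later in the same capture once the stream's terminal set advances.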
## Examples (2 streams)

<img width="641" height="801" alt="pytorch-remove-cudagraph-defer-reclaiming (6)" src="https://github.com/user-attachments/assets/41adc835-d448-483b-99ba-b4341cb7d2a2" />

* Case 0 — Unsafe
  The two frees are not ordered with respect to each other. For stream 1, the other stream's free marker does not precede stream 1's terminal, so the per-stream condition fails.
  Counterexample intuition for the unsafe setups: imagine `f2(x)` runs for a long time. If `DeviceCachingAllocator` reused block `x` on a stream whose terminal is not ordered after the free markers, the new lifetime could overlap the old one on replay, risking use-after-free or data corruption. The per-stream rule prevents exactly this.
* Case 1 — Reusable on stream 1
  Stream 1's terminal comes after both frees, so every free marker precedes stream 1's terminal. The block is reusable for allocations on stream 1.
* Case 2 — Not reusable on stream 2, but this cannot occur in `DeviceCachingAllocator`
  This depicts reusing the block on stream 2 while stream 1's free is not yet ordered before stream 2's terminal. Although the block is not safe to reuse on stream 2, `DeviceCachingAllocator` will never choose that block for stream 2 anyway: `get_free_block` rejects blocks whose `stream != p.stream()`, so this case is unreachable.
* Case 3 — Safe (strong rule holds)
  Here the terminal nodes of all streams come after the block's free markers, satisfying the strong rule. This guarantees the block is safe for reuse by any stream in the capturing graph. However, since `DeviceCachingAllocator` only reuses a block on its original allocation stream, verifying this strong condition is unnecessary; we only need the per-stream rule for the specific stream requesting the block.
* Case 4 — Freeing after a join
  See the note below.

## Edge Case: Freeing after a join

Our current dependency tracking has a limitation when a block is freed after a stream join; see @galv's [comments here](https://github.com/pytorch/pytorch/pull/158352#pullrequestreview-3112565198). Case 4 shows the missed opportunity: because the block's last use is not explicitly marked, we cannot tell that it may have occurred much earlier, long before the join, so the block cannot be reused until after the join.

## Thanks

Thanks to @galv for his great idea around graph parsing and empty nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158352
Approved by: https://github.com/ngimel, https://github.com/eqy

Co-authored-by: Jeff Daily <jeff.daily@amd.com>