mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Previous work #158352 delivered CUDAGraph memory footprint reduction with no replay-time impact, but capture time regressed (up to 20× slower) due to repeated full-graph traversals. See previous benchmark results [here](https://github.com/pytorch/pytorch/pull/158352#issuecomment-3215947565) This PR removes capture/reply overhead while preserving the memory savings: 1. **Terminals as free markers** We stop inserting empty nodes and instead record the current stream terminals as free markers. This avoids mutating the user’s graph and keeps semantics unchanged. 2. **Incremental, cached reachability** We add a **per-graph reuse context** that caches reverse-traversal state: * `graph_reuse_context[graph].visited[stream]` tracks nodes already seen from that stream’s terminal frontier. * On each allocation during capture, we resume traversal from the latest terminals and only visit unseen nodes. * A block is freed when all its recorded markers are in the visited set of its allocation stream—i.e., all markers are proven predecessors of future work. See [the performance results here](https://docs.google.com/spreadsheets/d/e/2PACX-1vRPvdd9Xa8W87ixbiA0da_qvOhrUAjUpFz0G-_j-MsDnoeRyhEa4_ut_W3rqcg1VVZVFJ-gucwov-3b/pubhtml?gid=1468302443&single=true), we sweep synthetic multi-stream CUDA Graphs built by `capture_benchmark.py` (same as before, we generate random interleaving of alloc/free/join with given probabilities, see [gist here](https://gist.github.com/eee4017/e2092d215b1d4bd46534148939af39e3)), and we compare median capture/replay times and memory. On an NVIDIA H100 PCIe across 24 configs, the optimization preserves reserved memory reduction at ~24–98%, leaves allocated memory unchanged, and brings capture time back to baseline (range 0.96–1.04× vs. baseline) with replay time unchanged (range 0.97–1.11×). Pull Request resolved: https://github.com/pytorch/pytorch/pull/162186 Approved by: https://github.com/eqy, https://github.com/ngimel