Compare commits

...

1 Commits

Author SHA1 Message Date
a610f7b83e Fix CUDA graph memory leak when record_stream used 2025-06-10 23:56:25 -04:00
2 changed files with 33 additions and 0 deletions

View File

@ -2162,11 +2162,15 @@ class DeviceCachingAllocator {
// Called by CUDAGraph::capture_end
void endAllocateToPool(MempoolId_t mempool_id) {
auto context = maybeGatherContext(RecordContext::STATE);
std::lock_guard<std::recursive_mutex> lock(mutex);
for (auto it = captures_underway.begin(); it != captures_underway.end();
++it) {
if (it->first == mempool_id) {
captures_underway.erase(it);
if (captures_underway.empty()) {
process_events(context);
}
return;
}
}

View File

@ -3012,6 +3012,35 @@ exit(2)
# Dummy allocation triggers process_events, which hopefully processes b's end-of-life event successfully.
torch.zeros((3,), device="cuda")
@unittest.skipIf(
    not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
def test_graph_record_stream_no_leak(self):
    """Regression test: a tensor freed during graph capture after
    ``record_stream`` must not stay counted as allocated once capture ends.

    A buffer is allocated on ``s_main`` while capturing, consumed on
    ``s_other``, marked with ``record_stream(s_other)``, and then dropped
    before ``capture_end``. After capture finishes, ``memory_allocated()``
    is expected to be back within 1 MiB of its pre-test value.
    """
    # Start from a clean baseline so the delta below reflects only this test.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    base_mem = torch.cuda.memory_allocated()

    g = torch.cuda.CUDAGraph()
    s_main = torch.cuda.Stream()
    s_other = torch.cuda.Stream()

    with torch.cuda.stream(s_main):
        g.capture_begin()
        # 16M elements; 64 MiB assuming the default dtype is float32.
        x = torch.empty(64 * 1024 * 1024 // 4, device="cuda")
        s_other.wait_stream(s_main)
        with torch.cuda.stream(s_other):
            y = x + 1
        # x was allocated on s_main but is read on s_other, so the
        # cross-stream use must be recorded on x. Recording s_other on y
        # would be a use on y's own allocation stream, which (per the
        # record_stream contract) needs no special handling and would not
        # exercise the deferred-free path this test targets.
        x.record_stream(s_other)
        del y
        s_main.wait_stream(s_other)
        del x
        # capture_end is issued while s_main (the capture stream) is current.
        g.capture_end()

    torch.cuda.synchronize()
    # NOTE(review): no extra allocation is made here on purpose -- the
    # memory is expected to be released by the end of capture itself.
    self.assertLessEqual(
        torch.cuda.memory_allocated() - base_mem, 1 * 1024 * 1024
    )
@skipIfRocm
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"