Revert "[CUDA] Reuse blocks with record_stream during CUDA Graph capture in the CUDACachingAllocator (#158352)"

This reverts commit 190c391a28845a14df26abb228d26aa813efb20c.

Reverted https://github.com/pytorch/pytorch/pull/158352 on behalf of https://github.com/atalman due to Broke cuda 13.0 nightly builds https://github.com/pytorch/pytorch/actions/runs/17382188549/job/49341981474 ([comment](https://github.com/pytorch/pytorch/pull/158352#issuecomment-3242871629))
PyTorch MergeBot
2025-09-01 16:27:02 +00:00
parent fefee08164
commit 63a9c23fe9
9 changed files with 10 additions and 436 deletions
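For context, the reverted PR exposed this behavior as a CUDACachingAllocator setting named graph_capture_record_stream_reuse, disabled by default. A minimal usage sketch based on the tests and docs removed in this diff (the runtime toggle is exactly what the removed tests call; configuring it through PYTORCH_CUDA_ALLOC_CONF is an assumption inferred from the allocator-options documentation also touched below):

import torch

# Sketch only: enable the (now reverted) capture-time block reuse option,
# as the removed tests in this diff did.
torch.cuda.memory._set_allocator_settings("graph_capture_record_stream_reuse:True")

# ... capture CUDA graphs; blocks freed during capture may be reclaimed ...

# The feature defaulted to False; the removed tests restore that default.
torch.cuda.memory._set_allocator_settings("graph_capture_record_stream_reuse:False")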

View File

@@ -25,7 +25,6 @@ CUDAAllocatorConfig::CUDAAllocatorConfig()
#endif
m_release_lock_on_cudamalloc(false),
m_pinned_use_cuda_host_register(false),
m_graph_capture_record_stream_reuse(false),
m_pinned_use_background_threads(false) {
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}
@@ -374,9 +373,6 @@ void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
} else if (config_item_view == "pinned_use_background_threads") {
i = parsePinnedUseBackgroundThreads(config, i);
used_native_specific_option = true;
} else if (config_item_view == "graph_capture_record_stream_reuse") {
i = parseGraphCaptureRecordStreamReuse(config, i);
used_native_specific_option = true;
} else {
TORCH_CHECK(
false, "Unrecognized CachingAllocator option: ", config_item_view);
@@ -410,23 +406,6 @@ size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
return i;
}
size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for graph_capture_record_stream_reuse");
m_graph_capture_record_stream_reuse = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting graph_capture_record_stream_reuse value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
const std::vector<std::string>& config,
size_t i) {
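For illustration: allocator settings are a comma-separated list of option:value pairs, and the removed parseGraphCaptureRecordStreamReuse consumed a ':' token followed by a literal True or False. A rough Python model of that tokenization (illustrative only; the real parsing is the C++ parseArgs above):

def parse_allocator_settings(settings: str) -> dict:
    # Model of the "option:value,option:value" format handled in parseArgs.
    options = {}
    for item in settings.split(","):
        if not item:
            continue
        key, _, value = item.partition(":")
        if key == "graph_capture_record_stream_reuse":
            # Mirrors the TORCH_CHECK in the removed parser: only True/False accepted.
            if value not in ("True", "False"):
                raise ValueError(
                    "Expected a single True/False argument for "
                    "graph_capture_record_stream_reuse"
                )
            options[key] = value == "True"
        else:
            options[key] = value
    return options

# parse_allocator_settings("graph_capture_record_stream_reuse:True")
# -> {"graph_capture_record_stream_reuse": True}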

View File

@@ -53,10 +53,6 @@ class C10_CUDA_API CUDAAllocatorConfig {
return instance().m_release_lock_on_cudamalloc;
}
static bool graph_capture_record_stream_reuse() {
return instance().m_graph_capture_record_stream_reuse;
}
/** Pinned memory allocator settings */
static bool pinned_use_cuda_host_register() {
return instance().m_pinned_use_cuda_host_register;
@@ -146,9 +142,6 @@ class C10_CUDA_API CUDAAllocatorConfig {
size_t parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i);
size_t parseGraphCaptureRecordStreamReuse(
const std::vector<std::string>& config,
size_t i);
std::atomic<size_t> m_max_split_size;
std::atomic<size_t> m_max_non_split_rounding_size;
@@ -160,7 +153,6 @@
m_expandable_segments_handle_type;
std::atomic<bool> m_release_lock_on_cudamalloc;
std::atomic<bool> m_pinned_use_cuda_host_register;
std::atomic<bool> m_graph_capture_record_stream_reuse;
std::atomic<bool> m_pinned_use_background_threads;
std::string m_last_allocator_settings;
std::mutex m_last_allocator_settings_mutex;

View File

@@ -1167,13 +1167,8 @@ class DeviceCachingAllocator {
// tracks which pools we can use as a last resort before ooming
ska::flat_hash_set<MempoolId_t, MempoolIdHash> use_on_oom_pools;
// Map of blocks whose freeing is deferred until after CUDA graph capture.
// - Key: Block* to be freed.
// - Value: List of "empty nodes" inserted as free markers during capture.
// If the vector is empty, the block must always be deferred until capture
// ends.
ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;
// See free() for this thing's purpose
std::vector<Block*> needs_events_deferred_until_no_capture;
// outstanding cuda events
ska::flat_hash_map<
cuda::CUDAStream,
@@ -1334,11 +1329,6 @@ class DeviceCachingAllocator {
// capture. Cross-stream memory use is uncommon, so the deferral's
// effect on memory use during capture should be small.
process_events(context);
} else {
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
// We check if there is some block that is safe to reuse on this stream
free_safe_blocks_in_capture(context, stream);
}
}
size_t size = round_size(orig_size);
auto& pool = get_pool(size, stream);
@@ -1629,220 +1619,6 @@
return block;
}
// Insert "free marker" (empty nodes) into the CUDA graph for all streams that
// have used the block, including the allocation stream. These nodes mark the
// last use of the block in the capture graph. Returns a vector of the
// inserted nodes, or an empty vector if any stream is not capturing.
std::vector<cudaGraphNode_t> insert_free_marker(Block* block) {
std::vector<cudaGraphNode_t> empty_nodes;
auto try_add_empty_node = [&](cudaStream_t stream) -> bool {
cudaStreamCaptureStatus status{};
cudaGraph_t graph{};
const cudaGraphNode_t* deps = nullptr;
size_t num_deps = 0;
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream, &status, nullptr, &graph, &deps, &num_deps));
TORCH_INTERNAL_ASSERT(
status != cudaStreamCaptureStatusInvalidated,
"Invalid stream capture status");
if (status == cudaStreamCaptureStatusNone) {
return false;
}
cudaGraphNode_t node{};
C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps));
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
stream, &node, 1, cudaStreamSetCaptureDependencies));
empty_nodes.push_back(node);
return true;
};
// If any stream is not currently capturing, return an empty node vector.
// An empty vector indicates that the block should be deferred for freeing
// until after capture.
// Attempt to add an empty node for the allocation stream.
if (!try_add_empty_node(block->stream)) {
return {};
}
// Attempt to add empty nodes for all streams that have used the block.
for (const auto& s : block->stream_uses) {
if (!try_add_empty_node(s.stream())) {
return {};
}
}
return empty_nodes;
}
// Returns the current set of "terminal" nodes in the CUDA graph for a given
// stream. These represent the current endpoints of the stream, and may
// include additional nodes if the graph branches. Any new work captured will
// be attached after one or more of these terminals.
std::vector<cudaGraphNode_t> get_terminals(cudaStream_t stream) {
std::vector<cudaGraphNode_t> result;
cudaStreamCaptureStatus status{};
cudaGraph_t graph{};
const cudaGraphNode_t* dependencies = nullptr;
size_t num_dependencies = 0;
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
stream,
&status,
/*id=*/nullptr,
&graph,
&dependencies,
&num_dependencies));
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatusActive,
"Invalid stream capture status");
for (size_t i = 0; i < num_dependencies; i++) {
auto node = dependencies[i];
if (node != nullptr) {
result.push_back(node);
}
}
return result;
}
// Returns the set of "reusable" free markers (empty nodes) in the current
// CUDA graph capture. A free marker is considered reusable if it is a
// predecessor of every terminal node.
// This ensures that all future captured work will occur after the free
// marker, making it safe to reuse.
ska::flat_hash_set<cudaGraphNode_t> get_reusable_empty_nodes(
cudaStream_t stream) {
auto terminals = get_terminals(stream);
if (terminals.empty()) {
// No terminal nodes found; nothing to free.
return {};
}
// Helper to retrieve all parent nodes (dependencies) of a given node.
auto get_parents = [](cudaGraphNode_t n) -> std::vector<cudaGraphNode_t> {
size_t count = 0;
C10_CUDA_CHECK(
cudaGraphNodeGetDependencies(n, /*pDependencies=*/nullptr, &count));
std::vector<cudaGraphNode_t> out(count);
if (count) {
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, out.data(), &count));
out.resize(count);
}
return out;
};
// Helper to determine if a node is an empty node (used as a free marker).
auto is_empty_node = [](cudaGraphNode_t n) -> bool {
cudaGraphNodeType type{};
C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type));
return type == cudaGraphNodeTypeEmpty;
};
// For each terminal node, perform a reverse DFS to count, for each empty
// node, how many terminals it can reach (i.e., for how many terminals it is
// a predecessor). An empty node is reusable if it is a predecessor of all
// terminal nodes.
ska::flat_hash_map<cudaGraphNode_t, size_t> num_terminals_reachable;
for (auto terminal : terminals) {
ska::flat_hash_set<cudaGraphNode_t> visited;
ska::flat_hash_set<cudaGraphNode_t> empty_nodes;
std::function<void(cudaGraphNode_t)> reverse_dfs =
[&](cudaGraphNode_t node) {
if (!visited.insert(node).second)
return;
if (is_empty_node(node)) {
num_terminals_reachable[node]++;
empty_nodes.insert(node);
}
auto parents = get_parents(node);
for (auto p : parents) {
reverse_dfs(p);
}
};
reverse_dfs(terminal);
}
ska::flat_hash_set<cudaGraphNode_t> reusable_empty_nodes;
for (auto [node, count] : num_terminals_reachable) {
if (count == terminals.size()) {
reusable_empty_nodes.insert(node);
}
}
return reusable_empty_nodes;
}
// A block is considered reusable during CUDA graph capture if every free
// marker (empty node) associated with the block is a predecessor of every
// terminal node.
//
// This ensures that any new operation added to the graph will be attached
// after all terminal nodes, which themselves are after all free markers. As a
// result, all future work is guaranteed to occur after the block's last use
// on every stream, so the block's previous lifetime ends before any new
// lifetime begins. This check relies solely on the DAG topology and does not
// require event queries, making it safe to use during capture.
//
// This function iterates over all deferred blocks, determines if their empty
// nodes are reusable according to the above criteria, and frees the block if
// so.
void free_safe_blocks_in_capture(
const std::shared_ptr<GatheredContext>& context,
cudaStream_t stream) {
auto reusable_empty_nodes = get_reusable_empty_nodes(stream);
// If there are no reusable empty nodes (e.g., not currently capturing),
// there is nothing to do.
if (reusable_empty_nodes.empty()) {
return;
}
std::vector<Block*> blocks_to_erase;
for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
// Skip this block if it has no empty nodes, as we defer its freeing until
// after graph capture. Also skip if the block was not allocated on the
// current stream; such blocks will be freed when
// free_safe_blocks_in_capture is attempted on that stream.
if (inserted_empty_nodes.empty() || block->stream != stream) {
continue;
}
bool is_reusable = true;
for (const auto& node : inserted_empty_nodes) {
if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
is_reusable = false;
break;
}
}
if (is_reusable) {
// Clear stream uses since the graph ensures proper synchronization.
// No need to insert events.
block->stream_uses.clear();
free_block(block, context);
blocks_to_erase.push_back(block);
}
}
// Remove blocks that were freed from the deferred_blocks map.
for (auto* block : blocks_to_erase) {
deferred_blocks.erase(block);
}
}
void free(Block* block) {
std::shared_ptr<GatheredContext> context =
maybeGatherContext(RecordContext::ALL);
@@ -1878,22 +1654,14 @@ class DeviceCachingAllocator {
if (block->size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_allocations.decrease(1);
// If the block has been used on more than one stream, handle accordingly.
if (!block->stream_uses.empty()) {
if (C10_UNLIKELY(!captures_underway.empty())) {
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
// insert_free_marker returns a vector of free markers,
// or an empty vector if any associated stream is not currently
// capturing. The empty vector means that we will defer the free until
// capture is finished.
deferred_blocks.emplace(block, insert_free_marker(block));
} else {
// If graph_capture_record_stream_reuse is not enabled, always defer
// the free until capture is finished.
deferred_blocks.emplace(block, std::vector<cudaGraphNode_t>{});
}
// It's forbidden to cudaEventQuery an event recorded during CUDA graph
// capture. We conservatively defer recording end-of-life events until
// the next call to process_events() (which won't happen until no
// captures are underway)
needs_events_deferred_until_no_capture.push_back(block);
} else {
// If not in a capture, insert events for the block.
insert_events(block);
}
} else {
@@ -3519,8 +3287,8 @@ class DeviceCachingAllocator {
void insert_events_deferred_until_no_capture(
const std::shared_ptr<GatheredContext>& context) {
if (C10_UNLIKELY(!deferred_blocks.empty())) {
for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) {
for (auto* block : needs_events_deferred_until_no_capture) {
TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
// only streams recorded before cudagraph will be used to insert events
// since we know all streams recorded during cudagraph must have
@@ -3532,7 +3300,7 @@ class DeviceCachingAllocator {
free_block(block, context);
}
}
deferred_blocks.clear();
needs_events_deferred_until_no_capture.clear();
}
}
@@ -3963,8 +3731,6 @@ class NativeCachingAllocator : public CUDAAllocator {
md.pinned_use_host_register =
CUDAAllocatorConfig::pinned_use_cuda_host_register();
md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
md.graph_capture_record_stream_reuse =
CUDAAllocatorConfig::graph_capture_record_stream_reuse();
md.roundup_power2_divisions =
CUDAAllocatorConfig::roundup_power2_divisions();
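The core of the reverted change is the topology test in get_reusable_empty_nodes / free_safe_blocks_in_capture above: a block freed during capture may be reused only once every free marker recorded for it is a predecessor of every terminal node of the capture DAG. A self-contained Python sketch of that check on a plain DAG (dicts stand in for cudaGraph nodes; this models the idea, not the CUDA API):

def reusable_markers(parents, terminals, markers):
    # parents: node -> list of parent nodes (reverse edges of the capture DAG)
    # terminals: current endpoints of the capture; markers: candidate free markers
    # A marker is reusable iff every terminal can reach it walking backwards,
    # i.e. the marker is a predecessor of every terminal node.
    reach_count = {m: 0 for m in markers}
    for terminal in terminals:
        visited = set()
        stack = [terminal]
        while stack:  # iterative reverse DFS from this terminal
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            if node in reach_count:
                reach_count[node] += 1
            stack.extend(parents.get(node, ()))
    return {m for m, count in reach_count.items() if count == len(terminals)}

# Marker "f" sits behind both terminals, so its block is safe to reuse;
# "g" is only behind t2, so its block stays deferred until after capture.
parents = {"t1": ["f"], "t2": ["f", "g"], "f": [], "g": []}
print(reusable_markers(parents, ["t1", "t2"], {"f", "g"}))  # {'f'}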

View File

@@ -163,7 +163,6 @@ struct AllocatorConfigInfo {
bool expandable_segments;
bool release_lock_on_malloc;
bool pinned_use_host_register;
bool graph_capture_record_stream_reuse;
std::string last_allocator_settings;
std::vector<size_t> roundup_power2_divisions;
};

View File

@@ -608,14 +608,6 @@ Available options:
for processing events. This avoids any slow path associated with querying/processing of
events in the fast allocation path. This feature is disabled by default.
* ``graph_capture_record_stream_reuse`` (experimental, default: `False`)
If set to `True`, the CUDA caching allocator will attempt to reclaim device memory during
CUDA Graph capture by using the graph topology (instead of CUDA events) to determine
when a freed block is safe to reuse. This can reduce peak memory during long captures that free
and reallocate buffers across multiple streams, especially when the capture DAG frequently
reaches joined frontiers. Note: Enabling this option can significantly increase the time spent
capturing the graph.
.. note::
Some stats reported by the

View File

@@ -5613,149 +5613,6 @@ class TestMemPool(TestCase):
        s = p.snapshot()
        self.assertEqual(len(s), 1, "Expected to have a single segment")
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_capture_reclaim_2_streams(self):
        torch.cuda.memory._set_allocator_settings(
            "graph_capture_record_stream_reuse:True"
        )
        torch.cuda.empty_cache()
        s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
        g = torch.cuda.CUDAGraph(keep_graph=True)
        torch.cuda.synchronize()
        with torch.cuda.stream(s1):
            g.capture_begin()
            # A sink node allocated up-front so it doesn't steal data1's block later.
            sink1 = torch.empty(8, device="cuda")
            # Source tensor on s1; this block is the reuse candidate.
            data1 = torch.empty(8, device="cuda")
            data1_ptr = data1.data_ptr()
            # Fork: do real work on s2 that READS data1 and writes to its own buffer.
            s2.wait_stream(s1)
            with torch.cuda.stream(s2):
                buf2 = torch.empty_like(data1)
                torch.add(data1, 2.0, out=buf2)
                data1.record_stream(s2)
            del data1
            # BEFORE JOIN: must NOT reuse
            data2 = torch.empty(8, device="cuda")
            data2_ptr = data2.data_ptr()
            # Join s2 -> s1 and add a sink node on s1.
            s1.wait_stream(s2)
            sink1.fill_(1.0)
            # AFTER JOIN: now reuse is allowed
            data3 = torch.empty(8, device="cuda")
            data3_ptr = data3.data_ptr()
            g.capture_end()
        torch.cuda.synchronize()
        # No reuse before join; reuse after join.
        self.assertNotEqual(data1_ptr, data2_ptr)
        self.assertEqual(data1_ptr, data3_ptr)
        torch.cuda.memory._set_allocator_settings(
            "graph_capture_record_stream_reuse:False"
        )
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_capture_reclaim_4_streams(self):
        torch.cuda.memory._set_allocator_settings(
            "graph_capture_record_stream_reuse:True"
        )
        torch.cuda.empty_cache()
        s1, s2, s3, s4 = (
            torch.cuda.Stream(),
            torch.cuda.Stream(),
            torch.cuda.Stream(),
            torch.cuda.Stream(),
        )
        g = torch.cuda.CUDAGraph(keep_graph=True)
        torch.cuda.synchronize()
        with torch.cuda.stream(s1):
            g.capture_begin()
            # Source tensor allocated on s1. This block is the candidate for reuse.
            data1 = torch.ones(8, device="cuda")
            data1_ptr = data1.data_ptr()
            sink1 = torch.empty_like(data1)
            sink3 = torch.empty_like(data1)
            s2.wait_stream(s1)
            with torch.cuda.stream(s2):
                buf2 = torch.empty_like(data1)
                torch.add(data1, 2.0, out=buf2)
                data1.record_stream(s2)
            s3.wait_stream(s1)
            with torch.cuda.stream(s3):
                buf3 = torch.empty_like(data1)
                torch.add(data1, 3.0, out=buf3)
                data1.record_stream(s3)
            s4.wait_stream(s1)
            with torch.cuda.stream(s4):
                buf4 = torch.empty_like(data1)
                torch.add(data1, 4.0, out=buf4)
                data1.record_stream(s4)
            # Free data1 inside capture; allocator may reuse later when it's safe.
            del data1
            # PARTIAL JOINS: should NOT allow reuse yet
            # Join s2 -> s1 and add a sink node on s1.
            s1.wait_stream(s2)
            sink1.fill_(1.0)
            # Join s4 -> s3 and add a sink node on s3.
            s3.wait_stream(s4)
            with torch.cuda.stream(s3):
                sink3.fill_(3.0)
                sink3.record_stream(s3)
            # At this point, s1 and s3 subgraphs are NOT yet joined together.
            # Allocating data2 here must NOT reuse data1's block.
            data2 = torch.empty(8, device="cuda")
            data2_ptr = data2.data_ptr()
            # FINAL JOIN: now reuse is allowed
            # Join s3 -> s1 and add a sink node on s1.
            s1.wait_stream(s3)
            sink1.add_(sink3)
            # Now allocator should safely reuse data1's block.
            data3 = torch.empty(8, device="cuda")
            data3_ptr = data3.data_ptr()
            g.capture_end()
        torch.cuda.synchronize()
        # No reuse before full join; reuse after full join.
        self.assertNotEqual(data1_ptr, data2_ptr)
        self.assertEqual(data1_ptr, data3_ptr)
        torch.cuda.memory._set_allocator_settings(
            "graph_capture_record_stream_reuse:False"
        )
    @skipIfRocm(msg="expandable_segments mode is not supported on ROCm")
    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode")
    def test_mempool_expandable(self):

View File

@@ -908,8 +908,6 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
py::str release_lock_on_malloc_s = "release_lock_on_cudamalloc";
py::str pinned_use_host_register_s = "pinned_use_cuda_host_register";
py::str roundup_power2_divisions_s = "roundup_power2_divisions";
py::str graph_capture_record_stream_reuse_s =
"graph_capture_record_stream_reuse";
allocator_settings[last_allocator_settings_s] =
snapshot.config_metadata.last_allocator_settings;
@@ -925,8 +923,6 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
snapshot.config_metadata.release_lock_on_malloc;
allocator_settings[pinned_use_host_register_s] =
snapshot.config_metadata.pinned_use_host_register;
allocator_settings[graph_capture_record_stream_reuse_s] =
snapshot.config_metadata.graph_capture_record_stream_reuse;
unsigned int roundup_key = 1;
py::dict roundup_settings;
for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) {
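The snapshot hunks here and below drop graph_capture_record_stream_reuse from the allocator-settings metadata that CUDA memory snapshots expose. A hedged sketch of inspecting that metadata from Python (the allocator_settings key name comes from the code above; treat the exact snapshot layout as an assumption):

import torch

# Sketch: read allocator config metadata recorded in a CUDA memory snapshot.
snap = torch.cuda.memory._snapshot()
settings = snap.get("allocator_settings", {})
# Before this revert, a "graph_capture_record_stream_reuse" entry appeared
# alongside keys like these:
for key in ("release_lock_on_cudamalloc", "pinned_use_cuda_host_register"):
    print(key, settings.get(key))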

View File

@@ -458,8 +458,6 @@ std::string _memory_snapshot_pickled() {
IValue release_lock_on_malloc_s = "release_lock_on_cudamalloc";
IValue pinned_use_host_register_s = "pinned_use_cuda_host_register";
IValue roundup_power2_divisions_s = "roundup_power2_divisions";
IValue graph_capture_record_stream_reuse_s =
"graph_capture_record_stream_reuse";
allocator_settings.insert(
last_allocator_settings_s,
@@ -480,9 +478,6 @@ std::string _memory_snapshot_pickled() {
allocator_settings.insert(
pinned_use_host_register_s,
snapshot.config_metadata.pinned_use_host_register);
allocator_settings.insert(
graph_capture_record_stream_reuse_s,
snapshot.config_metadata.graph_capture_record_stream_reuse);
unsigned int roundup_key = 1;
auto roundup_settings = new_dict();
for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) {

View File

@@ -4332,8 +4332,6 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)),
("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)),
("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)),
("cudaStreamSetCaptureDependencies", ("hipStreamSetCaptureDependencies", CONV_STREAM, API_RUNTIME)),
("cudaStreamUpdateCaptureDependencies", ("hipStreamUpdateCaptureDependencies", CONV_STREAM, API_RUNTIME)),
("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)),
("cudaGraphInstantiateWithFlags", ("hipGraphInstantiateWithFlags", CONV_TYPE, API_RUNTIME)),
(