mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CUDA] Reuse blocks with record_stream during CUDA Graph capture in the CUDACachingAllocator (#158352)
## Introduction During CUDA Graph capture, the CUDA caching allocator currently defers reclaiming blocks until capture ends. This is because CUDA forbids querying events recorded during capture (the CUDA operation is not executed during the capture stage), so the allocator cannot use its normal event-based logic. However, capture records an DAG (we call it **capturing graph**) of work. We can use the capturing graph to determine when a block’s old lifetime is fully before future work, and safely reuse it within the same capture. This PR adds an experimental flag `graph_capture_record_stream_reuse: True|False (default: False)`. When enabled, the allocator inserts lightweight free markers and uses capture ordering to decide if a freed block is safe to reuse during capture. If the proof cannot be established, we fall back to the existing post-capture path. ## Terms * **Free marker**: A capture-legal no-op (created with `cudaGraphAddEmptyNode`) inserted after the last captured use of the block on each stream that used it. * **Terminal**: The set of the lastest operations of the stream (or the capturing graph). Any newly captured op on that stream will attach after all nodes in this set. For a stream currently capturing, it is the set of nodes returned in `dependencies_out` by `cudaStreamGetCaptureInfo`. ## When can we reuse a block during capture? ### Strong Rule (Graph-Wide Safety) This rule provides a universal guarantee that a block is safe for reuse by any stream in the graph. > A block is safe to reuse if every free marker is a predecessor of every terminal of all active streams in the graph. Why it's safe: This rule establishes a strict global ordering. Since any new operation on any stream must be appended after that stream's terminals, this condition guarantees that the block's new lifetime begins only after its old lifetime has completely ended everywhere. This prevents lifetime overlaps when the graph is replayed, ensuring correctness. ### Per-stream Rule (A Practical Optimization) The strong rule, while safe, is often unnecessarily restrictive. The `DeviceCachingAllocator` introduces a crucial constraint that allows for a simpler check. In `DeviceCachingAllocator`, `get_free_block` only returns blocks whose `block->stream == p.stream()`. In other words, we never reuse a block on a stream different from the allocation stream. This means we don't need to verify safety across the entire graph. We only need to confirm that the block is safe to reuse from the perspective of its own allocation stream. > Reuse a block for allocations on stream S if every free marker is a predecessor of every node in the terminal set of S. In short, a block is considered **reusable** on stream S as long as all marker marking it "free" are guaranteed to complete before any new work that might need it on stream S begins. ## Implementation * On `free(block)` during capture * For each stream in `block->stream_uses` and the allocation stream, insert a free marker (empty node) and make it that stream’s tail. * If we cannot place markers for all such streams (for example, a stream is not in capture), defer to the post-capture path. * Otherwise, store the marker handles and keep the block in the capture-private structures. * On `allocate(stream)` during capture (attempt per-stream reclaim) * Query the allocation stream S’s terminal via `cudaStreamGetCaptureInfo`. * For each deferred block, check whether it is allocated on this stream, and each of its free markers is a predecessor of the terminal. * If yes, hand the block to S for immediate reuse within the same capture. * If no, keep it deferred; it will be reconsidered as capture progresses and S’s terminal advances. * On capture end * Any still-deferred blocks follow the existing post-capture reclamation (event insertion/polling). External behavior remains unchanged if we cannot prove safety during capture. ## Examples (2 streams) <img width="641" height="801" alt="pytorch-remove-cudagraph-defer-reclaiming (6)" src="https://github.com/user-attachments/assets/41adc835-d448-483b-99ba-b4341cb7d2a2" /> * Case 0 — Unsafe The two frees are not ordered with respect to each other. For stream 1, the other stream’s free marker does not precede this stream’s terminal, so the per-stream condition fails. Counterexample intuition for the unsafe setups: imagine `f2(x)` runs for a long time. If DeviceCachingAllocator reused block `x` on a stream whose terminal is not ordered after the free markers, the new lifetime could overlap the old one on replay, risking use-after-free or data corruption. The per-stream rule prevents exactly this. * Case 1 — Reusable on stream 1 Stream 1’s terminal is after both frees, so every free marker precedes stream 1’s terminal. The block is reusable for allocations on stream 1. * Case 2 — Not reusable on stream 2, but this cannot occur in `DeviceCachingAllocator` This depicts reusing the block on stream 2 while stream 1’s free is not yet ordered before stream 2’s terminal. Though the block is not safe to reuse on stream 2, DeviceCachingAllocator will not choose that block for stream 2 anyway: `get_free_block` rejects blocks whose `stream != p.stream()`. So this case is unreachable. * Case 3 — Safe (strong rule holds) In this scenario, the terminal nodes of all streams are positioned after the block's free markers, satisfying the strong rule. This guarantees the block is safe for reuse by any stream in the capturing graph. However, since `DeviceCachingAllocator ` only reuses a block on its original allocation stream, verifying this strong condition is unnecessary. We only need to ensure the per-stream rule is met for the specific stream requesting the block. * Case 4 — Freeing after a join See the note below. ## Edge Case: Freeing after a join Our current dependency tracking has a limitation in scenarios where a block is freed after a stream join, see @galv's [comments here](https://github.com/pytorch/pytorch/pull/158352#pullrequestreview-3112565198)). In the case 4, we have a missed opportunity. Because the block's usage is not explicitly marked, we cannot determine that the block's actual last use may have occurred much earlier, long before the join. Then, we must wait for the subsequent join before the block can be reused. ## Thanks Thanks to @galv for his great idea around graph parsing and empty nodes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158352 Approved by: https://github.com/ngimel, https://github.com/eqy Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
f36f285953
commit
0c0e056a9e
@ -25,6 +25,7 @@ CUDAAllocatorConfig::CUDAAllocatorConfig()
|
||||
#endif
|
||||
m_release_lock_on_cudamalloc(false),
|
||||
m_pinned_use_cuda_host_register(false),
|
||||
m_graph_capture_record_stream_reuse(false),
|
||||
m_pinned_use_background_threads(false) {
|
||||
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
|
||||
}
|
||||
@ -373,6 +374,9 @@ void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
|
||||
} else if (config_item_view == "pinned_use_background_threads") {
|
||||
i = parsePinnedUseBackgroundThreads(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "graph_capture_record_stream_reuse") {
|
||||
i = parseGraphCaptureRecordStreamReuse(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Unrecognized CachingAllocator option: ", config_item_view);
|
||||
@ -406,6 +410,23 @@ size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
(config[i] == "True" || config[i] == "False"),
|
||||
"Expected a single True/False argument for graph_capture_record_stream_reuse");
|
||||
m_graph_capture_record_stream_reuse = (config[i] == "True");
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting graph_capture_record_stream_reuse value", "");
|
||||
}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
|
@ -53,6 +53,10 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
return instance().m_release_lock_on_cudamalloc;
|
||||
}
|
||||
|
||||
static bool graph_capture_record_stream_reuse() {
|
||||
return instance().m_graph_capture_record_stream_reuse;
|
||||
}
|
||||
|
||||
/** Pinned memory allocator settings */
|
||||
static bool pinned_use_cuda_host_register() {
|
||||
return instance().m_pinned_use_cuda_host_register;
|
||||
@ -142,6 +146,9 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
size_t parsePinnedUseBackgroundThreads(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parseGraphCaptureRecordStreamReuse(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
|
||||
std::atomic<size_t> m_max_split_size;
|
||||
std::atomic<size_t> m_max_non_split_rounding_size;
|
||||
@ -153,6 +160,7 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
m_expandable_segments_handle_type;
|
||||
std::atomic<bool> m_release_lock_on_cudamalloc;
|
||||
std::atomic<bool> m_pinned_use_cuda_host_register;
|
||||
std::atomic<bool> m_graph_capture_record_stream_reuse;
|
||||
std::atomic<bool> m_pinned_use_background_threads;
|
||||
std::string m_last_allocator_settings;
|
||||
std::mutex m_last_allocator_settings_mutex;
|
||||
|
@ -1167,8 +1167,13 @@ class DeviceCachingAllocator {
|
||||
// tracks which pools we can use as a last resort before ooming
|
||||
ska::flat_hash_set<MempoolId_t, MempoolIdHash> use_on_oom_pools;
|
||||
|
||||
// See free() for this thing's purpose
|
||||
std::vector<Block*> needs_events_deferred_until_no_capture;
|
||||
// Map of blocks whose freeing is deferred until after CUDA graph capture.
|
||||
// - Key: Block* to be freed.
|
||||
// - Value: List of "empty nodes" inserted as free markers during capture.
|
||||
// If the vector is empty, the block must always be deferred until capture
|
||||
// ends.
|
||||
ska::flat_hash_map<Block*, std::vector<cudaGraphNode_t>> deferred_blocks;
|
||||
|
||||
// outstanding cuda events
|
||||
ska::flat_hash_map<
|
||||
cuda::CUDAStream,
|
||||
@ -1329,6 +1334,11 @@ class DeviceCachingAllocator {
|
||||
// capture. Cross-stream memory use is uncommon, so the deferral's
|
||||
// effect on memory use during capture should be small.
|
||||
process_events(context);
|
||||
} else {
|
||||
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
|
||||
// We check if there is some block that is safe to reuse on this stream
|
||||
free_safe_blocks_in_capture(context, stream);
|
||||
}
|
||||
}
|
||||
size_t size = round_size(orig_size);
|
||||
auto& pool = get_pool(size, stream);
|
||||
@ -1619,6 +1629,248 @@ class DeviceCachingAllocator {
|
||||
return block;
|
||||
}
|
||||
|
||||
// Insert "free marker" (empty nodes) into the CUDA graph for all streams that
|
||||
// have used the block, including the allocation stream. These nodes mark the
|
||||
// last use of the block in the capture graph. Returns a vector of the
|
||||
// inserted nodes, or an empty vector if any stream is not capturing.
|
||||
std::vector<cudaGraphNode_t> insert_free_marker(Block* block) {
|
||||
std::vector<cudaGraphNode_t> empty_nodes;
|
||||
|
||||
auto try_add_empty_node = [&](cudaStream_t stream) -> bool {
|
||||
cudaStreamCaptureStatus status{};
|
||||
cudaGraph_t graph{};
|
||||
const cudaGraphNode_t* deps = nullptr;
|
||||
size_t num_deps = 0;
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
|
||||
stream, &status, nullptr, &graph, &deps, nullptr, &num_deps));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
|
||||
stream, &status, nullptr, &graph, &deps, &num_deps));
|
||||
#endif
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
status != cudaStreamCaptureStatusInvalidated,
|
||||
"Invalid stream capture status");
|
||||
|
||||
if (status == cudaStreamCaptureStatusNone) {
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaGraphNode_t node{};
|
||||
C10_CUDA_CHECK(cudaGraphAddEmptyNode(&node, graph, deps, num_deps));
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
|
||||
stream, &node, nullptr, 1, cudaStreamSetCaptureDependencies));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
|
||||
stream, &node, 1, cudaStreamSetCaptureDependencies));
|
||||
#endif
|
||||
empty_nodes.push_back(node);
|
||||
return true;
|
||||
};
|
||||
|
||||
// If any stream is not currently capturing, return an empty node vector.
|
||||
// An empty vector indicates that the block should be deferred for freeing
|
||||
// until after capture.
|
||||
|
||||
// Attempt to add an empty node for the allocation stream.
|
||||
if (!try_add_empty_node(block->stream)) {
|
||||
return {};
|
||||
}
|
||||
// Attempt to add empty nodes for all streams that have used the block.
|
||||
for (const auto& s : block->stream_uses) {
|
||||
if (!try_add_empty_node(s.stream())) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
return empty_nodes;
|
||||
}
|
||||
|
||||
// Returns the current set of "terminal" nodes in the CUDA graph for a given
|
||||
// stream. These represent the current endpoints of the stream, and may
|
||||
// include additional nodes if the graph branches. Any new work captured will
|
||||
// be attached after one or more of these terminals.
|
||||
std::vector<cudaGraphNode_t> get_terminals(cudaStream_t stream) {
|
||||
std::vector<cudaGraphNode_t> result;
|
||||
|
||||
cudaStreamCaptureStatus status{};
|
||||
cudaGraph_t graph{};
|
||||
const cudaGraphNode_t* dependencies = nullptr;
|
||||
size_t num_dependencies = 0;
|
||||
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(
|
||||
stream,
|
||||
&status,
|
||||
nullptr,
|
||||
&graph,
|
||||
&dependencies,
|
||||
nullptr,
|
||||
&num_dependencies));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaStreamGetCaptureInfo_v2(
|
||||
stream, &status, nullptr, &graph, &dependencies, &num_dependencies));
|
||||
#endif
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
status == cudaStreamCaptureStatusActive,
|
||||
"Invalid stream capture status");
|
||||
|
||||
for (size_t i = 0; i < num_dependencies; i++) {
|
||||
auto node = dependencies[i];
|
||||
if (node != nullptr) {
|
||||
result.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the set of "reusable" free markers (empty nodes) in the current
|
||||
// CUDA graph capture. A free marker is considered reusable if it is a
|
||||
// predecessor of every terminal node.
|
||||
// This ensures that all future captured work will occur after the free
|
||||
// marker, making it safe to reuse.
|
||||
ska::flat_hash_set<cudaGraphNode_t> get_reusable_empty_nodes(
|
||||
cudaStream_t stream) {
|
||||
auto terminals = get_terminals(stream);
|
||||
if (terminals.empty()) {
|
||||
// No terminal nodes found; nothing to free.
|
||||
return {};
|
||||
}
|
||||
|
||||
auto get_dependencies = [](cudaGraphNode_t node,
|
||||
cudaGraphNode_t* pDependencies,
|
||||
size_t* pNumDependencies) -> void {
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000)
|
||||
C10_CUDA_CHECK(cudaGraphNodeGetDependencies(
|
||||
node, pDependencies, nullptr, pNumDependencies));
|
||||
#else
|
||||
C10_CUDA_CHECK(
|
||||
cudaGraphNodeGetDependencies(node, pDependencies, pNumDependencies));
|
||||
#endif
|
||||
};
|
||||
|
||||
// Helper to retrieve all parent nodes (dependencies) of a given node.
|
||||
auto get_parents =
|
||||
[&](cudaGraphNode_t node) -> std::vector<cudaGraphNode_t> {
|
||||
size_t count = 0;
|
||||
get_dependencies(node, nullptr, &count);
|
||||
std::vector<cudaGraphNode_t> out(count);
|
||||
if (count) {
|
||||
get_dependencies(node, out.data(), &count);
|
||||
out.resize(count);
|
||||
}
|
||||
return out;
|
||||
};
|
||||
|
||||
// Helper to determine if a node is an empty node (used as a free marker).
|
||||
auto is_empty_node = [](cudaGraphNode_t n) -> bool {
|
||||
cudaGraphNodeType type{};
|
||||
C10_CUDA_CHECK(cudaGraphNodeGetType(n, &type));
|
||||
return type == cudaGraphNodeTypeEmpty;
|
||||
};
|
||||
|
||||
// For each terminal node, perform a reverse DFS to count, for each empty
|
||||
// node, how many terminals it can reach (i.e., for how many terminals it is
|
||||
// a predecessor). An empty node is reusable if it is a predecessor of all
|
||||
// terminal nodes.
|
||||
ska::flat_hash_map<cudaGraphNode_t, size_t> num_terminals_reachable;
|
||||
|
||||
for (auto terminal : terminals) {
|
||||
ska::flat_hash_set<cudaGraphNode_t> visited;
|
||||
ska::flat_hash_set<cudaGraphNode_t> empty_nodes;
|
||||
|
||||
std::function<void(cudaGraphNode_t)> reverse_dfs =
|
||||
[&](cudaGraphNode_t node) {
|
||||
if (!visited.insert(node).second)
|
||||
return;
|
||||
|
||||
if (is_empty_node(node)) {
|
||||
num_terminals_reachable[node]++;
|
||||
empty_nodes.insert(node);
|
||||
}
|
||||
auto parents = get_parents(node);
|
||||
for (auto p : parents) {
|
||||
reverse_dfs(p);
|
||||
}
|
||||
};
|
||||
|
||||
reverse_dfs(terminal);
|
||||
}
|
||||
|
||||
ska::flat_hash_set<cudaGraphNode_t> reusable_empty_nodes;
|
||||
for (auto [node, count] : num_terminals_reachable) {
|
||||
if (count == terminals.size()) {
|
||||
reusable_empty_nodes.insert(node);
|
||||
}
|
||||
}
|
||||
|
||||
return reusable_empty_nodes;
|
||||
}
|
||||
|
||||
// A block is considered reusable during CUDA graph capture if every free
|
||||
// marker (empty node) associated with the block is a predecessor of every
|
||||
// terminal node.
|
||||
//
|
||||
// This ensures that any new operation added to the graph will be attached
|
||||
// after all terminal nodes, which themselves are after all free markers. As a
|
||||
// result, all future work is guaranteed to occur after the block's last use
|
||||
// on every stream, so the block's previous lifetime ends before any new
|
||||
// lifetime begins. This check relies solely on the DAG topology and does not
|
||||
// require event queries, making it safe to use during capture.
|
||||
//
|
||||
// This function iterates over all deferred blocks, determines if their empty
|
||||
// nodes are reusable according to the above criteria, and frees the block if
|
||||
// so.
|
||||
void free_safe_blocks_in_capture(
|
||||
const std::shared_ptr<GatheredContext>& context,
|
||||
cudaStream_t stream) {
|
||||
auto reusable_empty_nodes = get_reusable_empty_nodes(stream);
|
||||
|
||||
// If there are no reusable empty nodes (e.g., not currently capturing),
|
||||
// there is nothing to do.
|
||||
if (reusable_empty_nodes.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<Block*> blocks_to_erase;
|
||||
|
||||
for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
|
||||
// Skip this block if it has no empty nodes, as we defer its freeing until
|
||||
// after graph capture. Also skip if the block was not allocated on the
|
||||
// current stream; such blocks will be freed when
|
||||
// free_safe_blocks_in_capture is attempted on that stream.
|
||||
if (inserted_empty_nodes.empty() || block->stream != stream) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_reusable = true;
|
||||
|
||||
for (const auto& node : inserted_empty_nodes) {
|
||||
if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
|
||||
is_reusable = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_reusable) {
|
||||
// Clear stream uses since the graph ensures proper synchronization.
|
||||
// No need to insert events.
|
||||
block->stream_uses.clear();
|
||||
|
||||
free_block(block, context);
|
||||
blocks_to_erase.push_back(block);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove blocks that were freed from the deferred_blocks map.
|
||||
for (auto* block : blocks_to_erase) {
|
||||
deferred_blocks.erase(block);
|
||||
}
|
||||
}
|
||||
|
||||
void free(Block* block) {
|
||||
std::shared_ptr<GatheredContext> context =
|
||||
maybeGatherContext(RecordContext::ALL);
|
||||
@ -1654,14 +1906,22 @@ class DeviceCachingAllocator {
|
||||
if (block->size >= CUDAAllocatorConfig::max_split_size())
|
||||
stats.oversize_allocations.decrease(1);
|
||||
|
||||
// If the block has been used on more than one stream, handle accordingly.
|
||||
if (!block->stream_uses.empty()) {
|
||||
if (C10_UNLIKELY(!captures_underway.empty())) {
|
||||
// It's forbidden to cudaEventQuery an event recorded during CUDA graph
|
||||
// capture. We conservatively defer recording end-of-life events until
|
||||
// the next call to process_events() (which won't happen until no
|
||||
// captures are underway)
|
||||
needs_events_deferred_until_no_capture.push_back(block);
|
||||
if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
|
||||
// insert_free_marker returns a vector of free markers,
|
||||
// or an empty vector if any associated stream is not currently
|
||||
// capturing. The empty vector means that we will defer the free until
|
||||
// capture is finished.
|
||||
deferred_blocks.emplace(block, insert_free_marker(block));
|
||||
} else {
|
||||
// If graph_capture_record_stream_reuse is not enabled, always defer
|
||||
// the free until capture is finished.
|
||||
deferred_blocks.emplace(block, std::vector<cudaGraphNode_t>{});
|
||||
}
|
||||
} else {
|
||||
// If not in a capture, insert events for the block.
|
||||
insert_events(block);
|
||||
}
|
||||
} else {
|
||||
@ -3287,8 +3547,8 @@ class DeviceCachingAllocator {
|
||||
|
||||
void insert_events_deferred_until_no_capture(
|
||||
const std::shared_ptr<GatheredContext>& context) {
|
||||
if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) {
|
||||
for (auto* block : needs_events_deferred_until_no_capture) {
|
||||
if (C10_UNLIKELY(!deferred_blocks.empty())) {
|
||||
for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
|
||||
TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
|
||||
// only streams recorded before cudagraph will be used to insert events
|
||||
// since we know all streams recorded during cudagraph must have
|
||||
@ -3300,7 +3560,7 @@ class DeviceCachingAllocator {
|
||||
free_block(block, context);
|
||||
}
|
||||
}
|
||||
needs_events_deferred_until_no_capture.clear();
|
||||
deferred_blocks.clear();
|
||||
}
|
||||
}
|
||||
|
||||
@ -3731,6 +3991,8 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
md.pinned_use_host_register =
|
||||
CUDAAllocatorConfig::pinned_use_cuda_host_register();
|
||||
md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
|
||||
md.graph_capture_record_stream_reuse =
|
||||
CUDAAllocatorConfig::graph_capture_record_stream_reuse();
|
||||
md.roundup_power2_divisions =
|
||||
CUDAAllocatorConfig::roundup_power2_divisions();
|
||||
|
||||
|
@ -163,6 +163,7 @@ struct AllocatorConfigInfo {
|
||||
bool expandable_segments;
|
||||
bool release_lock_on_malloc;
|
||||
bool pinned_use_host_register;
|
||||
bool graph_capture_record_stream_reuse;
|
||||
std::string last_allocator_settings;
|
||||
std::vector<size_t> roundup_power2_divisions;
|
||||
};
|
||||
|
@ -608,6 +608,14 @@ Available options:
|
||||
for processing events. This avoids any slow path associated with querying/processing of
|
||||
events in the fast allocation path. This feature is disabled by default.
|
||||
|
||||
* ``graph_capture_record_stream_reuse`` (experimental, default: `False`)
|
||||
If set to `True`, the CUDA caching allocator will attempt to reclaim device memory during
|
||||
CUDA Graph capture by using the graph topology (instead of CUDA events) to determine
|
||||
when a freed block is safe to reuse. This can reduce peak memory during long captures that free
|
||||
and reallocate buffers across multiple streams, especially when the capture DAG frequently
|
||||
reaches joined frontiers. Note: Enabling this option can significantly increase the time spent
|
||||
capturing the graph.
|
||||
|
||||
.. note::
|
||||
|
||||
Some stats reported by the
|
||||
|
@ -5613,6 +5613,149 @@ class TestMemPool(TestCase):
|
||||
s = p.snapshot()
|
||||
self.assertEqual(len(s), 1, "Expected to have a single segment")
|
||||
|
||||
@unittest.skipIf(
|
||||
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
|
||||
)
|
||||
def test_graph_capture_reclaim_2_streams(self):
|
||||
torch.cuda.memory._set_allocator_settings(
|
||||
"graph_capture_record_stream_reuse:True"
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
|
||||
g = torch.cuda.CUDAGraph(keep_graph=True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.stream(s1):
|
||||
g.capture_begin()
|
||||
|
||||
# A sink node allocated up-front so it doesn't steal data1's block later.
|
||||
sink1 = torch.empty(8, device="cuda")
|
||||
|
||||
# Source tensor on s1; this block is the reuse candidate.
|
||||
data1 = torch.empty(8, device="cuda")
|
||||
data1_ptr = data1.data_ptr()
|
||||
|
||||
# Fork: do real work on s2 that READS data1 and writes to its own buffer.
|
||||
s2.wait_stream(s1)
|
||||
with torch.cuda.stream(s2):
|
||||
buf2 = torch.empty_like(data1)
|
||||
torch.add(data1, 2.0, out=buf2)
|
||||
data1.record_stream(s2)
|
||||
|
||||
del data1
|
||||
|
||||
# BEFORE JOIN: must NOT reuse
|
||||
data2 = torch.empty(8, device="cuda")
|
||||
data2_ptr = data2.data_ptr()
|
||||
|
||||
# Join s2 -> s1 and add a sink node on s1.
|
||||
s1.wait_stream(s2)
|
||||
sink1.fill_(1.0)
|
||||
|
||||
# AFTER JOIN: now reuse is allowed
|
||||
data3 = torch.empty(8, device="cuda")
|
||||
data3_ptr = data3.data_ptr()
|
||||
|
||||
g.capture_end()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# No reuse before join; reuse after join.
|
||||
self.assertNotEqual(data1_ptr, data2_ptr)
|
||||
self.assertEqual(data1_ptr, data3_ptr)
|
||||
|
||||
torch.cuda.memory._set_allocator_settings(
|
||||
"graph_capture_record_stream_reuse:False"
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
|
||||
)
|
||||
def test_graph_capture_reclaim_4_streams(self):
|
||||
torch.cuda.memory._set_allocator_settings(
|
||||
"graph_capture_record_stream_reuse:True"
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
s1, s2, s3, s4 = (
|
||||
torch.cuda.Stream(),
|
||||
torch.cuda.Stream(),
|
||||
torch.cuda.Stream(),
|
||||
torch.cuda.Stream(),
|
||||
)
|
||||
g = torch.cuda.CUDAGraph(keep_graph=True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.stream(s1):
|
||||
g.capture_begin()
|
||||
|
||||
# Source tensor allocated on s1. This block is the candidate for reuse.
|
||||
data1 = torch.ones(8, device="cuda")
|
||||
data1_ptr = data1.data_ptr()
|
||||
sink1 = torch.empty_like(data1)
|
||||
sink3 = torch.empty_like(data1)
|
||||
|
||||
s2.wait_stream(s1)
|
||||
with torch.cuda.stream(s2):
|
||||
buf2 = torch.empty_like(data1)
|
||||
torch.add(data1, 2.0, out=buf2)
|
||||
data1.record_stream(s2)
|
||||
|
||||
s3.wait_stream(s1)
|
||||
with torch.cuda.stream(s3):
|
||||
buf3 = torch.empty_like(data1)
|
||||
torch.add(data1, 3.0, out=buf3)
|
||||
data1.record_stream(s3)
|
||||
|
||||
s4.wait_stream(s1)
|
||||
with torch.cuda.stream(s4):
|
||||
buf4 = torch.empty_like(data1)
|
||||
torch.add(data1, 4.0, out=buf4)
|
||||
data1.record_stream(s4)
|
||||
|
||||
# Free data1 inside capture; allocator may reuse later when it's safe.
|
||||
del data1
|
||||
|
||||
# PARTIAL JOINS: should NOT allow reuse yet
|
||||
# Join s2 -> s1 and add a sink node on s1.
|
||||
s1.wait_stream(s2)
|
||||
sink1.fill_(1.0)
|
||||
|
||||
# Join s4 -> s3 and add a sink node on s3.
|
||||
s3.wait_stream(s4)
|
||||
with torch.cuda.stream(s3):
|
||||
sink3.fill_(3.0)
|
||||
sink3.record_stream(s3)
|
||||
|
||||
# At this point, s1 and s3 subgraphs are NOT yet joined together.
|
||||
# Allocating data2 here must NOT reuse data1's block.
|
||||
data2 = torch.empty(8, device="cuda")
|
||||
data2_ptr = data2.data_ptr()
|
||||
|
||||
# FINAL JOIN: now reuse is allowed
|
||||
# Join s3 -> s1 and add a sink node on s1.
|
||||
s1.wait_stream(s3)
|
||||
sink1.add_(sink3)
|
||||
|
||||
# Now allocator should safely reuse data1's block.
|
||||
data3 = torch.empty(8, device="cuda")
|
||||
data3_ptr = data3.data_ptr()
|
||||
|
||||
g.capture_end()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# No reuse before full join; reuse after full join.
|
||||
self.assertNotEqual(data1_ptr, data2_ptr)
|
||||
self.assertEqual(data1_ptr, data3_ptr)
|
||||
|
||||
torch.cuda.memory._set_allocator_settings(
|
||||
"graph_capture_record_stream_reuse:False"
|
||||
)
|
||||
|
||||
@skipIfRocm(msg="expandable_segments mode is not supported on ROCm")
|
||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode")
|
||||
def test_mempool_expandable(self):
|
||||
|
@ -907,6 +907,8 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
|
||||
py::str release_lock_on_malloc_s = "release_lock_on_cudamalloc";
|
||||
py::str pinned_use_host_register_s = "pinned_use_cuda_host_register";
|
||||
py::str roundup_power2_divisions_s = "roundup_power2_divisions";
|
||||
py::str graph_capture_record_stream_reuse_s =
|
||||
"graph_capture_record_stream_reuse";
|
||||
|
||||
allocator_settings[last_allocator_settings_s] =
|
||||
snapshot.config_metadata.last_allocator_settings;
|
||||
@ -922,6 +924,8 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
|
||||
snapshot.config_metadata.release_lock_on_malloc;
|
||||
allocator_settings[pinned_use_host_register_s] =
|
||||
snapshot.config_metadata.pinned_use_host_register;
|
||||
allocator_settings[graph_capture_record_stream_reuse_s] =
|
||||
snapshot.config_metadata.graph_capture_record_stream_reuse;
|
||||
unsigned int roundup_key = 1;
|
||||
py::dict roundup_settings;
|
||||
for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) {
|
||||
|
@ -458,6 +458,8 @@ std::string _memory_snapshot_pickled() {
|
||||
IValue release_lock_on_malloc_s = "release_lock_on_cudamalloc";
|
||||
IValue pinned_use_host_register_s = "pinned_use_cuda_host_register";
|
||||
IValue roundup_power2_divisions_s = "roundup_power2_divisions";
|
||||
IValue graph_capture_record_stream_reuse_s =
|
||||
"graph_capture_record_stream_reuse";
|
||||
|
||||
allocator_settings.insert(
|
||||
last_allocator_settings_s,
|
||||
@ -478,6 +480,9 @@ std::string _memory_snapshot_pickled() {
|
||||
allocator_settings.insert(
|
||||
pinned_use_host_register_s,
|
||||
snapshot.config_metadata.pinned_use_host_register);
|
||||
allocator_settings.insert(
|
||||
graph_capture_record_stream_reuse_s,
|
||||
snapshot.config_metadata.graph_capture_record_stream_reuse);
|
||||
unsigned int roundup_key = 1;
|
||||
auto roundup_settings = new_dict();
|
||||
for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) {
|
||||
|
@ -4333,6 +4333,8 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
|
||||
("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaStreamSetCaptureDependencies", ("hipStreamSetCaptureDependencies", CONV_STREAM, API_RUNTIME)),
|
||||
("cudaStreamUpdateCaptureDependencies", ("hipStreamUpdateCaptureDependencies", CONV_STREAM, API_RUNTIME)),
|
||||
("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)),
|
||||
("cudaGraphInstantiateWithFlags", ("hipGraphInstantiateWithFlags", CONV_TYPE, API_RUNTIME)),
|
||||
(
|
||||
|
Reference in New Issue
Block a user