## Introduction

During CUDA Graph capture, the CUDA caching allocator currently defers reclaiming blocks until capture ends. This is because CUDA forbids querying events recorded during capture (the captured operations are not executed during the capture stage), so the allocator cannot use its normal event-based logic. However, capture records a DAG of work (we call it the **capturing graph**). We can use the capturing graph to determine when a block's old lifetime is fully ordered before any future work, and safely reuse the block within the same capture.

This PR adds an experimental flag `graph_capture_record_stream_reuse: True|False (default: False)`. When enabled, the allocator inserts lightweight free markers and uses capture ordering to decide whether a freed block is safe to reuse during capture. If safety cannot be proven, we fall back to the existing post-capture path.

## Terms

* **Free marker**: A capture-legal no-op (created with `cudaGraphAddEmptyNode`) inserted after the last captured use of the block on each stream that used it.
* **Terminal**: The set of the latest operations of a stream (or of the capturing graph). Any newly captured op on that stream will attach after all nodes in this set. For a stream currently capturing, it is the set of nodes returned in `dependencies_out` by `cudaStreamGetCaptureInfo`.

## When can we reuse a block during capture?

### Strong Rule (Graph-Wide Safety)

This rule provides a universal guarantee that a block is safe for reuse by any stream in the graph.

> A block is safe to reuse if every free marker is a predecessor of every terminal of all active streams in the graph.

Why it's safe: this rule establishes a strict global ordering. Since any new operation on any stream must be appended after that stream's terminals, the condition guarantees that the block's new lifetime begins only after its old lifetime has completely ended everywhere. This prevents lifetime overlaps when the graph is replayed, ensuring correctness.

### Per-stream Rule (A Practical Optimization)

The strong rule, while safe, is often unnecessarily restrictive. The `DeviceCachingAllocator` introduces a crucial constraint that allows a simpler check: `get_free_block` only returns blocks whose `block->stream == p.stream()`. In other words, we never reuse a block on a stream different from its allocation stream. This means we do not need to verify safety across the entire graph; we only need to confirm that the block is safe to reuse from the perspective of its own allocation stream.

> Reuse a block for allocations on stream S if every free marker is a predecessor of every node in the terminal set of S.

In short, a block is considered **reusable** on stream S as long as all the markers marking it "free" are guaranteed to complete before any new work on stream S that might need it.
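To make the terms and the per-stream rule concrete, here is a minimal, self-contained sketch of how a free marker could be spliced into the capturing graph and how the predecessor check could be carried out with the CUDA Graph APIs named above (`cudaGraphAddEmptyNode`, `cudaStreamGetCaptureInfo`) together with `cudaStreamUpdateCaptureDependencies` and `cudaGraphNodeGetDependencies`. The helper names (`insert_free_marker`, `is_predecessor`, `reusable_on_stream`) and the brute-force backwards walk are illustrative assumptions, not the PR's actual implementation; error checking is omitted for brevity.

```cpp
#include <cuda_runtime.h>
#include <cassert>
#include <queue>
#include <unordered_set>
#include <vector>

// Insert a "free marker" (an empty node) at the current capture tail of
// `stream` and make it the stream's new tail, so every op captured on this
// stream afterwards depends on it.
cudaGraphNode_t insert_free_marker(cudaStream_t stream) {
  cudaStreamCaptureStatus status{};
  cudaGraph_t graph{};
  const cudaGraphNode_t* terminals = nullptr;
  size_t num_terminals = 0;
  // CUDA 12-style signature; the same information is available on CUDA 11.3+
  // via cudaStreamGetCaptureInfo_v2.
  cudaStreamGetCaptureInfo(
      stream, &status, /*id_out=*/nullptr, &graph, &terminals, &num_terminals);
  assert(status == cudaStreamCaptureStatusActive);

  cudaGraphNode_t marker{};
  cudaGraphAddEmptyNode(&marker, graph, terminals, num_terminals);
  // Replace the stream's capture dependencies so the marker becomes the tail.
  cudaStreamUpdateCaptureDependencies(
      stream, &marker, 1, cudaStreamSetCaptureDependencies);
  return marker;
}

// Returns true if `marker` is a predecessor of `node`, i.e. reachable from
// `node` by walking dependency edges backwards.
bool is_predecessor(cudaGraphNode_t marker, cudaGraphNode_t node) {
  std::unordered_set<cudaGraphNode_t> visited;
  std::queue<cudaGraphNode_t> work;
  work.push(node);
  while (!work.empty()) {
    cudaGraphNode_t cur = work.front();
    work.pop();
    if (cur == marker) {
      return true;
    }
    if (!visited.insert(cur).second) {
      continue;
    }
    size_t num_deps = 0;
    cudaGraphNodeGetDependencies(cur, nullptr, &num_deps);
    std::vector<cudaGraphNode_t> deps(num_deps);
    cudaGraphNodeGetDependencies(cur, deps.data(), &num_deps);
    for (cudaGraphNode_t d : deps) {
      work.push(d);
    }
  }
  return false;
}

// Per-stream rule: every free marker must be a predecessor of every node in
// the terminal set of the allocation stream.
bool reusable_on_stream(
    cudaStream_t stream, const std::vector<cudaGraphNode_t>& free_markers) {
  cudaStreamCaptureStatus status{};
  const cudaGraphNode_t* terminals = nullptr;
  size_t num_terminals = 0;
  cudaStreamGetCaptureInfo(
      stream, &status, nullptr, nullptr, &terminals, &num_terminals);
  if (status != cudaStreamCaptureStatusActive) {
    return false; // not capturing: fall back to the event-based path
  }
  for (size_t t = 0; t < num_terminals; ++t) {
    for (cudaGraphNode_t m : free_markers) {
      if (!is_predecessor(m, terminals[t])) {
        return false;
      }
    }
  }
  return true;
}
```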
## Implementation

* On `free(block)` during capture
  * For each stream in `block->stream_uses` and the allocation stream, insert a free marker (empty node) and make it that stream's tail.
  * If we cannot place markers for all such streams (for example, a stream is not in capture), defer to the post-capture path.
  * Otherwise, store the marker handles and keep the block in the capture-private structures.
* On `allocate(stream)` during capture (attempt per-stream reclaim)
  * Query the allocation stream S's terminal via `cudaStreamGetCaptureInfo`.
  * For each deferred block, check that it was allocated on this stream and that each of its free markers is a predecessor of the terminal.
  * If yes, hand the block to S for immediate reuse within the same capture.
  * If no, keep it deferred; it will be reconsidered as capture progresses and S's terminal advances.
* On capture end
  * Any still-deferred blocks follow the existing post-capture reclamation (event insertion/polling). External behavior remains unchanged if we cannot prove safety during capture.

## Examples (2 streams)

<img width="641" height="801" alt="pytorch-remove-cudagraph-defer-reclaiming (6)" src="https://github.com/user-attachments/assets/41adc835-d448-483b-99ba-b4341cb7d2a2" />

* Case 0 — Unsafe
  The two frees are not ordered with respect to each other. For stream 1, the other stream's free marker does not precede this stream's terminal, so the per-stream condition fails. Counterexample intuition for the unsafe setups: imagine `f2(x)` runs for a long time. If `DeviceCachingAllocator` reused block `x` on a stream whose terminal is not ordered after the free markers, the new lifetime could overlap the old one on replay, risking use-after-free or data corruption. The per-stream rule prevents exactly this.
* Case 1 — Reusable on stream 1
  Stream 1's terminal is after both frees, so every free marker precedes stream 1's terminal. The block is reusable for allocations on stream 1.
* Case 2 — Not reusable on stream 2, but this cannot occur in `DeviceCachingAllocator`
  This depicts reusing the block on stream 2 while stream 1's free is not yet ordered before stream 2's terminal. Although the block is not safe to reuse on stream 2, `DeviceCachingAllocator` will never choose that block for stream 2 anyway: `get_free_block` rejects blocks whose `stream != p.stream()`. So this case is unreachable.
* Case 3 — Safe (strong rule holds)
  In this scenario, the terminal nodes of all streams are positioned after the block's free markers, satisfying the strong rule. This guarantees the block is safe for reuse by any stream in the capturing graph. However, since `DeviceCachingAllocator` only reuses a block on its original allocation stream, verifying this strong condition is unnecessary; we only need the per-stream rule to hold for the specific stream requesting the block.
* Case 4 — Freeing after a join
  See the note below.

## Edge Case: Freeing after a join

Our current dependency tracking has a limitation when a block is freed after a stream join (see @galv's [comments here](https://github.com/pytorch/pytorch/pull/158352#pullrequestreview-3112565198)). Case 4 shows a missed opportunity: because the block's individual uses are not explicitly marked, we cannot detect that its actual last use may have occurred much earlier, long before the join. The free markers are only placed at the point of the free, so we must wait for the join before the block can be reused.

## Thanks

Thanks to @galv for his great idea around graph parsing and empty nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158352
Approved by: https://github.com/ngimel, https://github.com/eqy

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
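For reference, the option is consumed by the allocator-config parser in the source shown below. A minimal sketch of enabling it, assuming the usual `PYTORCH_CUDA_ALLOC_CONF` environment variable or the `setAllocatorSettings` helper declared in `c10/cuda/CUDACachingAllocator.h` (the surrounding code is illustrative only):

```cpp
#include <c10/cuda/CUDACachingAllocator.h>

int main() {
  // Roughly equivalent to running with
  //   PYTORCH_CUDA_ALLOC_CONF=graph_capture_record_stream_reuse:True
  // The option defaults to False; when safety cannot be proven during
  // capture, behavior falls back to post-capture reclamation.
  c10::cuda::CUDACachingAllocator::setAllocatorSettings(
      "graph_capture_record_stream_reuse:True");
  // ... allocate, free, and capture CUDA graphs as usual ...
  return 0;
}
```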
476 lines · 16 KiB · C++
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/llvmMathExtras.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

namespace c10::cuda::CUDACachingAllocator {

constexpr size_t kRoundUpPowerOfTwoIntervals = 16;

CUDAAllocatorConfig::CUDAAllocatorConfig()
    : m_max_split_size(std::numeric_limits<size_t>::max()),
      m_max_non_split_rounding_size(kLargeBuffer),
      m_garbage_collection_threshold(0),
      m_pinned_num_register_threads(1),
      m_expandable_segments(false),
#if CUDA_VERSION >= 12030
      m_expandable_segments_handle_type(
          Expandable_Segments_Handle_Type::UNSPECIFIED),
#else
      m_expandable_segments_handle_type(
          Expandable_Segments_Handle_Type::POSIX_FD),
#endif
      m_release_lock_on_cudamalloc(false),
      m_pinned_use_cuda_host_register(false),
      m_graph_capture_record_stream_reuse(false),
      m_pinned_use_background_threads(false) {
  m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}

size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
  size_t log_size = (63 - llvm::countLeadingZeros(size));

  // Our intervals start at 1MB and end at 64GB
  const size_t interval_start =
      63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
  const size_t interval_end =
      63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
  TORCH_CHECK(
      (interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
      "kRoundUpPowerOfTwoIntervals mismatch");

  int index = static_cast<int>(log_size) - static_cast<int>(interval_start);

  index = std::max(0, index);
  index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
  return instance().m_roundup_power2_divisions[index];
}

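// lexArgs tokenizes a raw settings string such as
//   "max_split_size_mb:128,roundup_power2_divisions:[256:1,512:2]"
// into {"max_split_size_mb", ":", "128", ",", ...}: the delimiters
// ',', ':', '[' and ']' become single-character tokens and spaces are
// dropped.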
void CUDAAllocatorConfig::lexArgs(
    const std::string& env,
    std::vector<std::string>& config) {
  std::vector<char> buf;

  for (char ch : env) {
    if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
      if (!buf.empty()) {
        config.emplace_back(buf.begin(), buf.end());
        buf.clear();
      }
      config.emplace_back(1, ch);
    } else if (ch != ' ') {
      buf.emplace_back(ch);
    }
  }
  if (!buf.empty()) {
    config.emplace_back(buf.begin(), buf.end());
  }
}

void CUDAAllocatorConfig::consumeToken(
    const std::vector<std::string>& config,
    size_t i,
    const char c) {
  TORCH_CHECK(
      i < config.size() && config[i] == std::string(1, c),
      "Error parsing CachingAllocator settings, expected ",
      c,
      "");
}

size_t CUDAAllocatorConfig::parseMaxSplitSize(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  constexpr int mb = 1024 * 1024;
  if (++i < config.size()) {
    size_t val1 = stoi(config[i]);
    TORCH_CHECK(
        val1 > kLargeBuffer / mb,
        "CachingAllocator option max_split_size_mb too small, must be > ",
        kLargeBuffer / mb,
        "");
    val1 = std::max(val1, kLargeBuffer / mb);
    val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
    m_max_split_size = val1 * 1024 * 1024;
  } else {
    TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
  }
  return i;
}

size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  constexpr int mb = 1024 * 1024;
  if (++i < config.size()) {
    size_t val1 = stoi(config[i]);
    TORCH_CHECK(
        val1 > kLargeBuffer / mb,
        "CachingAllocator option max_non_split_rounding_mb too small, must be > ",
        kLargeBuffer / mb,
        "");
    val1 = std::max(val1, kLargeBuffer / mb);
    val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
    m_max_non_split_rounding_size = val1 * 1024 * 1024;
  } else {
    TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", "");
  }
  return i;
}

size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    double val1 = stod(config[i]);
    TORCH_CHECK(
        val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
    TORCH_CHECK(
        val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
    m_garbage_collection_threshold = val1;
  } else {
    TORCH_CHECK(
        false, "Error, expecting garbage_collection_threshold value", "");
  }
  return i;
}

size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  bool first_value = true;

  if (++i < config.size()) {
    if (std::string_view(config[i]) == "[") {
      size_t last_index = 0;
      // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
      while (++i < config.size() && std::string_view(config[i]) != "]") {
        const std::string& val1 = config[i];
        size_t val2 = 0;

        consumeToken(config, ++i, ':');
        if (++i < config.size()) {
          val2 = stoi(config[i]);
        } else {
          TORCH_CHECK(
              false, "Error parsing roundup_power2_divisions value", "");
        }
        TORCH_CHECK(
            val2 == 0 || llvm::isPowerOf2_64(val2),
            "For roundups, the divisions has to be power of 2 or 0 to disable roundup ",
            "");

        if (std::string_view(val1) == ">") {
          std::fill(
              std::next(
                  m_roundup_power2_divisions.begin(),
                  static_cast<std::vector<unsigned long>::difference_type>(
                      last_index)),
              m_roundup_power2_divisions.end(),
              val2);
        } else {
          size_t val1_long = stoul(val1);
          TORCH_CHECK(
              llvm::isPowerOf2_64(val1_long),
              "For roundups, the intervals have to be power of 2 ",
              "");

          size_t index = 63 - llvm::countLeadingZeros(val1_long);
          index = std::max((size_t)0, index);
          index = std::min(index, m_roundup_power2_divisions.size() - 1);

          if (first_value) {
            std::fill(
                m_roundup_power2_divisions.begin(),
                std::next(
                    m_roundup_power2_divisions.begin(),
                    static_cast<std::vector<unsigned long>::difference_type>(
                        index)),
                val2);
            first_value = false;
          }
          if (index < m_roundup_power2_divisions.size()) {
            m_roundup_power2_divisions[index] = val2;
          }
          last_index = index;
        }

        if (std::string_view(config[i + 1]) != "]") {
          consumeToken(config, ++i, ',');
        }
      }
    } else { // Keep this for backwards compatibility
      size_t val1 = stoi(config[i]);
      TORCH_CHECK(
          llvm::isPowerOf2_64(val1),
          "For roundups, the divisions has to be power of 2 ",
          "");
      std::fill(
          m_roundup_power2_divisions.begin(),
          m_roundup_power2_divisions.end(),
          val1);
    }
  } else {
    TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
  }
  return i;
}

size_t CUDAAllocatorConfig::parseAllocatorConfig(
    const std::vector<std::string>& config,
    size_t i,
    bool& used_cudaMallocAsync) {
  // For ease of maintenance and understanding, the CUDA and ROCm
  // implementations of this function are separated. This avoids having many
  // #ifdef's throughout.
#ifdef USE_ROCM
  // Ease burden on ROCm users by allowing either cuda or hip tokens.
  // cuda token is broken up to prevent hipify matching it.
#define PYTORCH_TOKEN1 \
  "cud"                \
  "aMallocAsync"
#define PYTORCH_TOKEN2 "hipMallocAsync"
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    TORCH_CHECK(
        ((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
         (config[i] == PYTORCH_TOKEN2)),
        "Unknown allocator backend, "
        "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
    used_cudaMallocAsync =
        (config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
    TORCH_INTERNAL_ASSERT(
        config[i] == get()->name() ||
            (config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
        "Allocator backend parsed at runtime != "
        "allocator backend parsed at load time, ",
        config[i],
        " != ",
        get()->name());
  } else {
    TORCH_CHECK(false, "Error parsing backend value", "");
  }
  return i;
#undef PYTORCH_TOKEN1
#undef PYTORCH_TOKEN2
#else // USE_ROCM
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    TORCH_CHECK(
        ((config[i] == "native") || (config[i] == "cudaMallocAsync")),
        "Unknown allocator backend, "
        "options are native and cudaMallocAsync");
    used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
    if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
      int version = 0;
      C10_CUDA_CHECK(cudaDriverGetVersion(&version));
      TORCH_CHECK(
          version >= 11040,
          "backend:cudaMallocAsync requires CUDA runtime "
          "11.4 or newer, but cudaDriverGetVersion returned ",
          version);
#else
      TORCH_CHECK(
          false,
          "backend:cudaMallocAsync requires PyTorch to be built with "
          "CUDA 11.4 or newer, but CUDA_VERSION is ",
          CUDA_VERSION);
#endif
    }
    TORCH_INTERNAL_ASSERT(
        config[i] == get()->name(),
        "Allocator backend parsed at runtime != "
        "allocator backend parsed at load time");
  } else {
    TORCH_CHECK(false, "Error parsing backend value", "");
  }
  return i;
#endif // USE_ROCM
}

void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
  // If empty, set the default values
  m_max_split_size = std::numeric_limits<size_t>::max();
  m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
  m_garbage_collection_threshold = 0;
  bool used_cudaMallocAsync = false;
  bool used_native_specific_option = false;

  if (!env.has_value()) {
    return;
  }
  {
    std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
    m_last_allocator_settings = env.value();
  }

  std::vector<std::string> config;
  lexArgs(env.value(), config);

  for (size_t i = 0; i < config.size(); i++) {
    std::string_view config_item_view(config[i]);
    if (config_item_view == "max_split_size_mb") {
      i = parseMaxSplitSize(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "max_non_split_rounding_mb") {
      i = parseMaxNonSplitRoundingSize(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "garbage_collection_threshold") {
      i = parseGarbageCollectionThreshold(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "roundup_power2_divisions") {
      i = parseRoundUpPower2Divisions(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "backend") {
      i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
    } else if (config_item_view == "expandable_segments") {
      used_native_specific_option = true;
      consumeToken(config, ++i, ':');
      ++i;
      TORCH_CHECK(
          i < config.size() &&
              (std::string_view(config[i]) == "True" ||
               std::string_view(config[i]) == "False"),
          "Expected a single True/False argument for expandable_segments");
      config_item_view = config[i];
      m_expandable_segments = (config_item_view == "True");
    } else if (
        // ROCm build's hipify step will change "cuda" to "hip", but for ease of
        // use, accept both. We must break up the string to prevent hipify here.
        config_item_view == "release_lock_on_hipmalloc" ||
        config_item_view ==
            "release_lock_on_c"
            "udamalloc") {
      used_native_specific_option = true;
      consumeToken(config, ++i, ':');
      ++i;
      TORCH_CHECK(
          i < config.size() &&
              (std::string_view(config[i]) == "True" ||
               std::string_view(config[i]) == "False"),
          "Expected a single True/False argument for release_lock_on_cudamalloc");
      config_item_view = config[i];
      m_release_lock_on_cudamalloc = (config_item_view == "True");
    } else if (
        // ROCm build's hipify step will change "cuda" to "hip", but for ease of
        // use, accept both. We must break up the string to prevent hipify here.
        config_item_view == "pinned_use_hip_host_register" ||
        config_item_view ==
            "pinned_use_c"
            "uda_host_register") {
      i = parsePinnedUseCudaHostRegister(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "pinned_num_register_threads") {
      i = parsePinnedNumRegisterThreads(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "pinned_use_background_threads") {
      i = parsePinnedUseBackgroundThreads(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "graph_capture_record_stream_reuse") {
      i = parseGraphCaptureRecordStreamReuse(config, i);
      used_native_specific_option = true;
    } else {
      TORCH_CHECK(
          false, "Unrecognized CachingAllocator option: ", config_item_view);
    }

    if (i + 1 < config.size()) {
      consumeToken(config, ++i, ',');
    }
  }

  if (used_cudaMallocAsync && used_native_specific_option) {
    TORCH_WARN(
        "backend:cudaMallocAsync ignores max_split_size_mb,"
        "roundup_power2_divisions, and garbage_collect_threshold.");
  }
}

size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    TORCH_CHECK(
        (config[i] == "True" || config[i] == "False"),
        "Expected a single True/False argument for pinned_use_cuda_host_register");
    m_pinned_use_cuda_host_register = (config[i] == "True");
  } else {
    TORCH_CHECK(
        false, "Error, expecting pinned_use_cuda_host_register value", "");
  }
  return i;
}

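// Parses the experimental "graph_capture_record_stream_reuse" option
// (True/False, default False) described in the PR text above: when enabled,
// the caching allocator may reuse freed blocks during CUDA Graph capture
// instead of always deferring reclamation until capture ends.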
size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    TORCH_CHECK(
        (config[i] == "True" || config[i] == "False"),
        "Expected a single True/False argument for graph_capture_record_stream_reuse");
    m_graph_capture_record_stream_reuse = (config[i] == "True");
  } else {
    TORCH_CHECK(
        false, "Error, expecting graph_capture_record_stream_reuse value", "");
  }

  return i;
}

size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    size_t val2 = stoi(config[i]);
    TORCH_CHECK(
        llvm::isPowerOf2_64(val2),
        "Number of register threads has to be power of 2 ",
        "");
    auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
    TORCH_CHECK(
        val2 <= maxThreads,
        "Number of register threads should be less than or equal to " +
            std::to_string(maxThreads),
        "");
    m_pinned_num_register_threads = val2;
  } else {
    TORCH_CHECK(
        false, "Error, expecting pinned_num_register_threads value", "");
  }
  return i;
}

size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads(
    const std::vector<std::string>& config,
    size_t i) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    TORCH_CHECK(
        (config[i] == "True" || config[i] == "False"),
        "Expected a single True/False argument for pinned_use_background_threads");
    m_pinned_use_background_threads = (config[i] == "True");
  } else {
    TORCH_CHECK(
        false, "Error, expecting pinned_use_background_threads value", "");
  }
  return i;
}

// General caching allocator utilities
void setAllocatorSettings(const std::string& env) {
  CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
}

} // namespace c10::cuda::CUDACachingAllocator