Revert "Refactor CUDAAllocatorConfig to reuse AcceleratorAllocatorConfig (#150312)"

This reverts commit dfacf11f66d6512396382bdf5088f0ba9de00406.

Reverted https://github.com/pytorch/pytorch/pull/150312 on behalf of https://github.com/guangyey due to Static initialization order issue impact the downstream repo ([comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3142035444))
This commit is contained in:
PyTorch MergeBot
2025-08-01 03:24:54 +00:00
parent 90f13f3b2a
commit 5cc6a0abc1
4 changed files with 496 additions and 159 deletions

View File

@ -1,119 +1,389 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/llvmMathExtras.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif
#include <cuda_runtime_api.h>
namespace c10::cuda::CUDACachingAllocator {
size_t CUDAAllocatorConfig::parseAllocatorConfig(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
CUDAAllocatorConfig::CUDAAllocatorConfig()
: m_max_split_size(std::numeric_limits<size_t>::max()),
m_max_non_split_rounding_size(kLargeBuffer),
m_garbage_collection_threshold(0),
m_pinned_num_register_threads(1),
m_expandable_segments(false),
#if CUDA_VERSION >= 12030
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::UNSPECIFIED),
#else
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::POSIX_FD),
#endif
m_release_lock_on_cudamalloc(false),
m_pinned_use_cuda_host_register(false),
m_pinned_use_background_threads(false) {
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}
size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
size_t log_size = (63 - llvm::countLeadingZeros(size));
// Our intervals start at 1MB and end at 64GB
const size_t interval_start =
63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
const size_t interval_end =
63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
TORCH_CHECK(
(interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
"kRoundUpPowerOfTwoIntervals mismatch");
int index = static_cast<int>(log_size) - static_cast<int>(interval_start);
index = std::max(0, index);
index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
return instance().m_roundup_power2_divisions[index];
}
void CUDAAllocatorConfig::lexArgs(
const std::string& env,
std::vector<std::string>& config) {
std::vector<char> buf;
for (char ch : env) {
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
buf.clear();
}
config.emplace_back(1, ch);
} else if (ch != ' ') {
buf.emplace_back(ch);
}
}
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
}
}
void CUDAAllocatorConfig::consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c) {
TORCH_CHECK(
i < config.size() && config[i] == std::string(1, c),
"Error parsing CachingAllocator settings, expected ",
c,
"");
}
size_t CUDAAllocatorConfig::parseMaxSplitSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_split_size_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_split_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_non_split_rounding_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_non_split_rounding_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
double val1 = stod(config[i]);
TORCH_CHECK(
val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
TORCH_CHECK(
val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
m_garbage_collection_threshold = val1;
} else {
TORCH_CHECK(
false, "Error, expecting garbage_collection_threshold value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
bool first_value = true;
if (++i < config.size()) {
if (std::string_view(config[i]) == "[") {
size_t last_index = 0;
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
while (++i < config.size() && std::string_view(config[i]) != "]") {
const std::string& val1 = config[i];
size_t val2 = 0;
consumeToken(config, ++i, ':');
if (++i < config.size()) {
val2 = stoi(config[i]);
} else {
TORCH_CHECK(
false, "Error parsing roundup_power2_divisions value", "");
}
TORCH_CHECK(
val2 == 0 || llvm::isPowerOf2_64(val2),
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ",
"");
if (std::string_view(val1) == ">") {
std::fill(
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
last_index)),
m_roundup_power2_divisions.end(),
val2);
} else {
size_t val1_long = stoul(val1);
TORCH_CHECK(
llvm::isPowerOf2_64(val1_long),
"For roundups, the intervals have to be power of 2 ",
"");
size_t index = 63 - llvm::countLeadingZeros(val1_long);
index = std::max((size_t)0, index);
index = std::min(index, m_roundup_power2_divisions.size() - 1);
if (first_value) {
std::fill(
m_roundup_power2_divisions.begin(),
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
index)),
val2);
first_value = false;
}
if (index < m_roundup_power2_divisions.size()) {
m_roundup_power2_divisions[index] = val2;
}
last_index = index;
}
if (std::string_view(config[i + 1]) != "]") {
consumeToken(config, ++i, ',');
}
}
} else { // Keep this for backwards compatibility
size_t val1 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val1),
"For roundups, the divisions has to be power of 2 ",
"");
std::fill(
m_roundup_power2_divisions.begin(),
m_roundup_power2_divisions.end(),
val1);
}
} else {
TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync) {
// For ease of maintenance and understanding, the CUDA and ROCm
// implementations of this function are separated. This avoids having many
// #ifdef's throughout.
#ifdef USE_ROCM
// Ease burden on ROCm users by allowing either cuda or hip tokens.
// cuda token is broken up to prevent hipify matching it.
#define PYTORCH_TOKEN1 \
"cud" \
"aMallocAsync"
#define PYTORCH_TOKEN2 "hipMallocAsync"
tokenizer.checkToken(++i, ":");
i++; // Move to the value after the colon
TORCH_CHECK(
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
(tokenizer[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
if (m_is_allocator_loaded) {
bool aync_allocator_at_runtime = (tokenizer[i] != "native");
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
aync_allocator_at_runtime == m_use_async_allocator,
"Allocator async backend parsed at runtime != allocator async backend parsed at load time, ",
aync_allocator_at_runtime,
((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
(config[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
used_cudaMallocAsync =
(config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
TORCH_INTERNAL_ASSERT(
config[i] == get()->name() ||
(config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time, ",
config[i],
" != ",
m_use_async_allocator);
get()->name());
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
}
m_use_async_allocator =
(tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2);
// CUDA allocator is always loaded at the start of the program
m_is_allocator_loaded = true;
#if defined(CUDA_VERSION)
if (m_use_async_allocator) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
#endif
return i;
#undef PYTORCH_TOKEN1
#undef PYTORCH_TOKEN2
#else // USE_ROCM
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
((config[i] == "native") || (config[i] == "cudaMallocAsync")),
"Unknown allocator backend, "
"options are native and cudaMallocAsync");
used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
TORCH_INTERNAL_ASSERT(
config[i] == get()->name(),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time");
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
}
return i;
#endif // USE_ROCM
}
void CUDAAllocatorConfig::parseArgs(const std::string& env) {
void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
// If empty, set the default values
m_max_split_size = std::numeric_limits<size_t>::max();
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
m_garbage_collection_threshold = 0;
bool used_cudaMallocAsync = false;
bool used_native_specific_option = false;
c10::CachingAllocator::ConfigTokenizer tokenizer(env);
for (size_t i = 0; i < tokenizer.size(); i++) {
const auto& key = tokenizer[i];
if (key == "backend") {
i = parseAllocatorConfig(tokenizer, i);
if (!env.has_value()) {
return;
}
{
std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
m_last_allocator_settings = env.value();
}
std::vector<std::string> config;
lexArgs(env.value(), config);
for (size_t i = 0; i < config.size(); i++) {
std::string_view config_item_view(config[i]);
if (config_item_view == "max_split_size_mb") {
i = parseMaxSplitSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "max_non_split_rounding_mb") {
i = parseMaxNonSplitRoundingSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "garbage_collection_threshold") {
i = parseGarbageCollectionThreshold(config, i);
used_native_specific_option = true;
} else if (config_item_view == "roundup_power2_divisions") {
i = parseRoundUpPower2Divisions(config, i);
used_native_specific_option = true;
} else if (config_item_view == "backend") {
i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
} else if (config_item_view == "expandable_segments") {
used_native_specific_option = true;
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for expandable_segments");
config_item_view = config[i];
m_expandable_segments = (config_item_view == "True");
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
key == "release_lock_on_hipmalloc" ||
key ==
config_item_view == "release_lock_on_hipmalloc" ||
config_item_view ==
"release_lock_on_c"
"udamalloc") {
used_native_specific_option = true;
tokenizer.checkToken(++i, ":");
m_release_lock_on_cudamalloc = tokenizer.toBool(++i);
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for release_lock_on_cudamalloc");
config_item_view = config[i];
m_release_lock_on_cudamalloc = (config_item_view == "True");
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
key == "pinned_use_hip_host_register" ||
key ==
config_item_view == "pinned_use_hip_host_register" ||
config_item_view ==
"pinned_use_c"
"uda_host_register") {
i = parsePinnedUseCudaHostRegister(tokenizer, i);
i = parsePinnedUseCudaHostRegister(config, i);
used_native_specific_option = true;
} else if (key == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(tokenizer, i);
} else if (config_item_view == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(config, i);
used_native_specific_option = true;
} else if (config_item_view == "pinned_use_background_threads") {
i = parsePinnedUseBackgroundThreads(config, i);
used_native_specific_option = true;
} else {
const auto& keys =
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
TORCH_CHECK(
keys.find(key) != keys.end(),
"Unrecognized key '",
key,
"' in Accelerator allocator config.");
i = tokenizer.skipKey(i);
false, "Unrecognized CachingAllocator option: ", config_item_view);
}
if (i + 1 < tokenizer.size()) {
tokenizer.checkToken(++i, ",");
if (i + 1 < config.size()) {
consumeToken(config, ++i, ',');
}
}
if (m_use_async_allocator && used_native_specific_option) {
if (used_cudaMallocAsync && used_native_specific_option) {
TORCH_WARN(
"backend:cudaMallocAsync ignores max_split_size_mb,"
"roundup_power2_divisions, and garbage_collect_threshold.");
@ -121,33 +391,64 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) {
}
size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i) {
tokenizer.checkToken(++i, ":");
m_pinned_use_cuda_host_register = tokenizer.toBool(++i);
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_cuda_host_register");
m_pinned_use_cuda_host_register = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_cuda_host_register value", "");
}
return i;
}
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i) {
tokenizer.checkToken(++i, ":");
size_t val2 = tokenizer.toSizeT(++i);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
consumeToken(config, ++i, ':');
if (++i < config.size()) {
size_t val2 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
} else {
TORCH_CHECK(
false, "Error, expecting pinned_num_register_threads value", "");
}
return i;
}
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig)
size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_background_threads");
m_pinned_use_background_threads = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_background_threads value", "");
}
return i;
}
// General caching allocator utilities
void setAllocatorSettings(const std::string& env) {
CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
}
} // namespace c10::cuda::CUDACachingAllocator

View File

@ -1,11 +1,16 @@
#pragma once
#include <c10/core/AllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <mutex>
#include <string>
#include <vector>
namespace c10::cuda::CUDACachingAllocator {
enum class Expandable_Segments_Handle_Type : int {
@ -18,23 +23,20 @@ enum class Expandable_Segments_Handle_Type : int {
class C10_CUDA_API CUDAAllocatorConfig {
public:
static size_t max_split_size() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
return instance().m_max_split_size;
}
static double garbage_collection_threshold() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
garbage_collection_threshold();
return instance().m_garbage_collection_threshold;
}
static bool expandable_segments() {
bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig::
use_expandable_segments();
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
if (enabled) {
if (instance().m_expandable_segments) {
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
}
return false;
#else
return enabled;
return instance().m_expandable_segments;
#endif
}
@ -61,8 +63,7 @@ class C10_CUDA_API CUDAAllocatorConfig {
}
static bool pinned_use_background_threads() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
pinned_use_background_threads();
return instance().m_pinned_use_background_threads;
}
static size_t pinned_max_register_threads() {
@ -76,97 +77,88 @@ class C10_CUDA_API CUDAAllocatorConfig {
// More description below in function roundup_power2_next_division
// As an example, if we want 4 divisions between 2's power, this can be done
// using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
static size_t roundup_power2_divisions(size_t size) {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions(size);
}
static size_t roundup_power2_divisions(size_t size);
static std::vector<size_t> roundup_power2_divisions() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions();
return instance().m_roundup_power2_divisions;
}
static size_t max_non_split_rounding_size() {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
max_non_split_rounding_size();
return instance().m_max_non_split_rounding_size;
}
static std::string last_allocator_settings() {
return c10::CachingAllocator::getAllocatorSettings();
}
static bool use_async_allocator() {
return instance().m_use_async_allocator;
}
static const std::unordered_set<std::string>& getKeys() {
return keys_;
std::lock_guard<std::mutex> lock(
instance().m_last_allocator_settings_mutex);
return instance().m_last_allocator_settings;
}
static CUDAAllocatorConfig& instance() {
static CUDAAllocatorConfig* s_instance = ([]() {
auto inst = new CUDAAllocatorConfig();
auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
if (!env.has_value()) {
// For backward compatibility, check for the old environment variable
// PYTORCH_CUDA_ALLOC_CONF.
env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
}
auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
#ifdef USE_ROCM
// convenience for ROCm users, allow alternative HIP token
if (!env.has_value()) {
env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
if (env.has_value()) {
inst->parseArgs(env.value());
}
inst->parseArgs(env);
return inst;
})();
return *s_instance;
}
void parseArgs(const std::string& env);
void parseArgs(const std::optional<std::string>& env);
private:
CUDAAllocatorConfig() = default;
CUDAAllocatorConfig();
size_t parseAllocatorConfig(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
static void lexArgs(const std::string& env, std::vector<std::string>& config);
static void consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c);
size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
size_t parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i);
size_t parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i);
size_t parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i);
size_t parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync);
size_t parsePinnedUseCudaHostRegister(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedNumRegisterThreads(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i);
std::atomic<size_t> m_pinned_num_register_threads{1};
std::atomic<Expandable_Segments_Handle_Type> m_expandable_segments_handle_type
#if CUDA_VERSION >= 12030
{Expandable_Segments_Handle_Type::UNSPECIFIED};
#else
{Expandable_Segments_Handle_Type::POSIX_FD};
#endif
std::atomic<bool> m_release_lock_on_cudamalloc{false};
std::atomic<bool> m_pinned_use_cuda_host_register{false};
std::atomic<bool> m_use_async_allocator{false};
std::atomic<bool> m_is_allocator_loaded{false};
inline static std::unordered_set<std::string> keys_{
"backend",
// keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_cud"
"amalloc",
"pinned_use_cud"
"a_host_register",
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_hipmalloc",
"pinned_use_hip_host_register",
"pinned_num_register_threads"};
std::atomic<size_t> m_max_split_size;
std::atomic<size_t> m_max_non_split_rounding_size;
std::vector<size_t> m_roundup_power2_divisions;
std::atomic<double> m_garbage_collection_threshold;
std::atomic<size_t> m_pinned_num_register_threads;
std::atomic<bool> m_expandable_segments;
std::atomic<Expandable_Segments_Handle_Type>
m_expandable_segments_handle_type;
std::atomic<bool> m_release_lock_on_cudamalloc;
std::atomic<bool> m_pinned_use_cuda_host_register;
std::atomic<bool> m_pinned_use_background_threads;
std::string m_last_allocator_settings;
std::mutex m_last_allocator_settings_mutex;
};
// Keep this for backwards compatibility
using c10::CachingAllocator::setAllocatorSettings;
// General caching allocator utilities
C10_CUDA_API void setAllocatorSettings(const std::string& env);
} // namespace c10::cuda::CUDACachingAllocator

View File

@ -1,6 +1,7 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
@ -63,6 +64,10 @@ namespace cuda::CUDACachingAllocator {
using namespace c10::CachingAllocator;
using namespace c10::CachingDeviceAllocator;
// Included here as this is externally used in CUDAAllocatorConfig
const size_t kLargeBuffer =
20971520; // "large" allocations may be packed in 20 MiB blocks
namespace Native {
//
@ -4123,10 +4128,49 @@ CUDAAllocator* allocator();
} // namespace CudaMallocAsync
struct BackendStaticInitializer {
// Parses env for backend at load time, duplicating some logic from
// CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at
// runtime). Defers verbose exceptions and error checks, including Cuda
// version checks, to CUDAAllocatorConfig's runtime doublecheck. If this
// works, maybe we should move all of CUDAAllocatorConfig here?
CUDAAllocator* parseEnvForBackend() {
// If the environment variable is set, we use the CudaMallocAsync allocator.
if (CUDAAllocatorConfig::use_async_allocator()) {
return CudaMallocAsync::allocator();
auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (!val.has_value()) {
val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
if (val.has_value()) {
const std::string& config = val.value();
std::regex exp("[\\s,]+");
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);
for (auto option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
std::sregex_token_iterator end2;
std::vector<std::string> kv(it2, end2);
if (kv.size() >= 2) {
if (kv[0] == "backend") {
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (kv[1] ==
"cud"
"aMallocAsync" ||
kv[1] == "hipMallocAsync")
#else
if (kv[1] == "cudaMallocAsync")
#endif
return CudaMallocAsync::allocator();
if (kv[1] == "native")
return &Native::allocator;
}
}
}
}
return &Native::allocator;
}

View File

@ -1,7 +1,6 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
@ -50,9 +49,10 @@ namespace c10::cuda::CUDACachingAllocator {
// Preserved only for BC reasons
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::CachingAllocator::kLargeBuffer;
using c10::CachingDeviceAllocator::DeviceStats;
extern const size_t kLargeBuffer;
typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();
// Struct containing info of an allocation block (i.e. a fractional part of a