Refine CUDA BackendStaticInitializer for allocator selection (#165298)

* #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165298
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289, #165291
Author: Yu, Guangye
Date: 2025-10-17 17:16:46 +00:00
Committed by: PyTorch MergeBot
parent b2f5c25b27
commit 1ba808dd97
2 changed files with 46 additions and 26 deletions

c10/cuda/CUDACachingAllocator.cpp

@@ -4453,11 +4453,12 @@ CUDAAllocator* allocator();
 } // namespace CudaMallocAsync
 
 struct BackendStaticInitializer {
-  // Parses env for backend at load time, duplicating some logic from
-  // CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at
-  // runtime). Defers verbose exceptions and error checks, including Cuda
-  // version checks, to CUDAAllocatorConfig's runtime doublecheck. If this
-  // works, maybe we should move all of CUDAAllocatorConfig here?
+  // Parses the environment configuration for CUDA/ROCm allocator backend at
+  // load time. This duplicates some logic from CUDAAllocatorConfig to ensure
+  // lazy initialization without triggering global static constructors. The
+  // function looks for the key "backend" and returns the appropriate allocator
+  // instance based on its value. If no valid configuration is found, it falls
+  // back to the default Native allocator.
   CUDAAllocator* parseEnvForBackend() {
     auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
 #ifdef USE_ROCM
@@ -4466,34 +4467,35 @@ struct BackendStaticInitializer {
       val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
     }
 #endif
+    if (!val.has_value()) {
+      val = c10::utils::get_env("PYTORCH_ALLOC_CONF");
+    }
     if (val.has_value()) {
-      const std::string& config = val.value();
-
-      std::regex exp("[\\s,]+");
-      std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
-      std::sregex_token_iterator end;
-      std::vector<std::string> options(it, end);
-
-      for (auto option : options) {
-        std::regex exp2("[:]+");
-        std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
-        std::sregex_token_iterator end2;
-        std::vector<std::string> kv(it2, end2);
-        if (kv.size() >= 2) {
-          if (kv[0] == "backend") {
+      c10::CachingAllocator::ConfigTokenizer tokenizer(val.value());
+      for (size_t i = 0; i < tokenizer.size(); i++) {
+        const auto& key = tokenizer[i];
+        if (key == "backend") {
+          tokenizer.checkToken(++i, ":");
+          i++; // Move to the value after the colon
+          if (tokenizer[i] == "cudaMallocAsync"
 #ifdef USE_ROCM
               // convenience for ROCm users to allow either CUDA or HIP env var
-            if (kv[1] == "cudaMallocAsync" || kv[1] == "hipMallocAsync")
-#else
-            if (kv[1] == "cudaMallocAsync")
+              || tokenizer[i] == "hipMallocAsync"
 #endif
-              return CudaMallocAsync::allocator();
-            if (kv[1] == "native")
-              return &Native::allocator;
-          }
+          ) {
+            return CudaMallocAsync::allocator();
+          }
+          break;
+        } else {
+          // Skip the key and its value
+          i = tokenizer.skipKey(i);
+        }
+        if (i + 1 < tokenizer.size()) {
+          tokenizer.checkToken(++i, ",");
+        }
         }
       }
     }
+    // Default fallback allocator.
    return &Native::allocator;
  }
 
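For orientation, the config string being parsed is a comma-separated list of key:value options. Below is a rough standalone model of the selection logic above, in Python for brevity. It is a sketch only: the helper names read_alloc_conf and parse_backend are illustrative, and the real implementation is the c10::CachingAllocator::ConfigTokenizer loop shown in the diff.

import os

# Sketch of the environment lookup order in parseEnvForBackend. Note that
# PYTORCH_HIP_ALLOC_CONF is consulted only on ROCm builds, and the
# device-generic PYTORCH_ALLOC_CONF is the new fallback added here.
def read_alloc_conf():
    for name in (
        "PYTORCH_CUDA_ALLOC_CONF",
        "PYTORCH_HIP_ALLOC_CONF",  # ROCm builds only
        "PYTORCH_ALLOC_CONF",
    ):
        val = os.environ.get(name)
        if val is not None:
            return val
    return None

# Simplified model of the "backend" lookup: the first "backend" key wins,
# and unknown keys are skipped (the C++ code does this with skipKey and
# checkToken rather than string splitting).
def parse_backend(conf):
    for option in conf.split(","):
        key, _, value = option.partition(":")
        if key == "backend":
            # "hipMallocAsync" is accepted only on ROCm builds, as a
            # convenience alias for the same async allocator.
            if value in ("cudaMallocAsync", "hipMallocAsync"):
                return "cudaMallocAsync"
            break
    return "native"  # default fallback, mirroring &Native::allocator

assert parse_backend("max_split_size_mb:20,backend:cudaMallocAsync") == "cudaMallocAsync"
assert parse_backend("max_split_size_mb:20") == "native"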

test/test_cuda.py

@@ -4613,6 +4613,24 @@ class TestCudaMallocAsync(TestCase):
             "pinned_num_register_threads:1024"
         )
 
+    def test_allocator_backend(self):
+        def check_output(script: str) -> str:
+            return (
+                subprocess.check_output([sys.executable, "-c", script])
+                .decode("ascii")
+                .strip()
+            )
+
+        test_script = """\
+import os
+os.environ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:20,backend:cudaMallocAsync,release_lock_on_cudamalloc:none"
+import torch
+torch.cuda.init()
+print(torch.cuda.get_allocator_backend())
+"""
+        rc = check_output(test_script)
+        self.assertEqual(rc, "cudaMallocAsync")
+
     def test_cachingAllocator_raw_alloc(self):
         # Test that raw_alloc respects the setting that
         # activates/deactivates the caching allocator
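The test drives the selection end to end in a subprocess, deliberately placing backend between two unrelated keys so that the skip-key and first-match paths are both exercised. A minimal interactive check along the same lines, for reference (assumes a CUDA-enabled build; the variable must be set before torch is imported, since the backend is chosen at load time):

import os

# Must be set before importing torch; the allocator backend is fixed once
# CUDA is initialized.
os.environ["PYTORCH_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch

torch.cuda.init()
print(torch.cuda.get_allocator_backend())  # expected: cudaMallocAsync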