Compare commits

...

2 Commits

Author SHA1 Message Date
d3c02f5869 Remove support from CUDAMallocAsyncAllocator
Async allocator does not support other config options either
2025-11-02 01:20:29 -07:00
b944315b93 Add per_process_memory_fraction to PYTORCH_CUDA_ALLOC_CONF
torch.cuda.memory.set_per_process_memory_fraction allows setting
an upper bound on how much device memory is allocated. This PR
exposes this setting to an environment variable.

For example, PYTORCH_CUDA_ALLOC_CONF="per_process_memory_fraction:0.5"
will limit the device memory to half of the available memory.
2025-11-02 01:20:29 -07:00
7 changed files with 108 additions and 37 deletions

View File

@ -106,6 +106,9 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) {
} else if (key == "graph_capture_record_stream_reuse") {
i = parseGraphCaptureRecordStreamReuse(tokenizer, i);
used_native_specific_option = true;
} else if (key == "per_process_memory_fraction") {
i = parsePerProcessMemoryFraction(tokenizer, i);
used_native_specific_option = true;
} else {
const auto& keys =
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
@ -146,6 +149,18 @@ size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse(
return i;
}
double CUDAAllocatorConfig::parsePerProcessMemoryFraction(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i) {
tokenizer.checkToken(++i, ":");
double val_env = tokenizer.toDouble(++i);
TORCH_CHECK_VALUE(
val_env >= 0.0 && val_env <= 1.0,
"per_process_memory_fraction is invalid, set it in [0.0, 1.0]");
m_per_process_memory_fraction = val_env;
return i;
}
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i) {

View File

@ -61,6 +61,10 @@ class C10_CUDA_API CUDAAllocatorConfig {
return instance().m_graph_capture_record_stream_reuse;
}
static double per_process_memory_fraction() {
return instance().m_per_process_memory_fraction;
}
/** Pinned memory allocator settings */
static bool pinned_use_cuda_host_register() {
return instance().m_pinned_use_cuda_host_register;
@ -152,7 +156,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
"pinned_use_hip_host_register",
"graph_capture_record_stream_reuse",
"pinned_reserve_segment_size_mb",
"pinned_num_register_threads"};
"pinned_num_register_threads",
"per_process_memory_fraction"};
return keys;
}
@ -177,6 +182,9 @@ class C10_CUDA_API CUDAAllocatorConfig {
size_t parseGraphCaptureRecordStreamReuse(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i);
double parsePerProcessMemoryFraction(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i);
std::atomic<size_t> m_pinned_num_register_threads{1};
std::atomic<size_t> m_pinned_reserve_segment_size_mb{0};
@ -189,6 +197,7 @@ class C10_CUDA_API CUDAAllocatorConfig {
std::atomic<bool> m_release_lock_on_cudamalloc{false};
std::atomic<bool> m_pinned_use_cuda_host_register{false};
std::atomic<bool> m_graph_capture_record_stream_reuse{false};
std::atomic<double> m_per_process_memory_fraction{1.0};
};
// Keep this for backwards compatibility

View File

@ -1100,7 +1100,7 @@ class RingBuffer {
} // anonymous namespace
} // namespace Native
static std::string reportProcessMemoryInfo(c10::DeviceIndex device) {
static std::string reportProcessMemoryInfo(const cudaDeviceProp& prop) {
#ifdef PYTORCH_C10_DRIVER_API_SUPPORTED
void* nvml_handle = DriverAPI::get_nvml_handle();
if (!nvml_handle) {
@ -1111,9 +1111,6 @@ static std::string reportProcessMemoryInfo(c10::DeviceIndex device) {
return true;
}();
cudaDeviceProp prop{};
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
// NOLINTNEXTLINE(*-c-arrays)
char pci_id[80];
snprintf(
@ -1215,14 +1212,16 @@ class DeviceCachingAllocator {
// record used memory.
size_t total_allocated_memory = 0;
size_t allowed_memory_maximum = 0;
cudaDeviceProp device_prop;
// maximum amount of memory that device is allowed to
// allocate. This is set iff memory fraction is less than 1
std::optional<size_t> allowed_memory_maximum{std::nullopt};
// all live expandable segments
std::vector<ExpandableSegment*> expandable_segments_;
std::vector<c10::DeviceIndex> devices_with_peer_access_;
bool set_fraction = false;
bool record_history = false;
std::atomic<CreateContextFn> context_recorder_;
@ -1264,6 +1263,9 @@ class DeviceCachingAllocator {
: device_id(id),
large_blocks(/*small=*/false),
small_blocks(/*small=*/true) {
C10_CUDA_CHECK(cudaGetDeviceProperties(&device_prop, id));
setMemoryFraction(CUDAAllocatorConfig::per_process_memory_fraction());
stats.max_split_size =
static_cast<int64_t>(AcceleratorAllocatorConfig::max_split_size());
context_recorder_.store(nullptr);
@ -1399,7 +1401,7 @@ class DeviceCachingAllocator {
if (!block_found) {
// Do garbage collection if the flag is set.
if (C10_UNLIKELY(
set_fraction &&
allowed_memory_maximum.has_value() &&
AcceleratorAllocatorConfig::garbage_collection_threshold() >
0.0)) {
garbage_collect_cached_blocks(context);
@ -1456,11 +1458,12 @@ class DeviceCachingAllocator {
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
std::string allowed_info;
if (set_fraction) {
allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
if (allowed_memory_maximum.has_value()) {
allowed_info =
format_size(allowed_memory_maximum.value()) + " allowed; ";
}
std::string proc_info = reportProcessMemoryInfo(device_id);
std::string proc_info = reportProcessMemoryInfo(device_prop);
record_trace(
TraceEntry::OOM,
@ -1518,7 +1521,7 @@ class DeviceCachingAllocator {
for (const auto& obs : observers_local) {
obs(device_id,
alloc_size,
set_fraction ? allowed_memory_maximum : device_total,
allowed_memory_maximum.value_or(device_total),
device_free);
}
@ -2015,25 +2018,26 @@ class DeviceCachingAllocator {
/** get memory fraction limiting maximum allocated memory **/
double getMemoryFraction() {
if (!set_fraction) {
if (!allowed_memory_maximum.has_value()) {
return 1.0;
}
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_total);
return static_cast<double>(allowed_memory_maximum.value()) /
static_cast<double>(device_prop.totalGlobalMem);
}
/** set memory fraction to limit maximum allocated memory **/
void setMemoryFraction(double fraction) {
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
allowed_memory_maximum =
static_cast<size_t>(fraction * static_cast<double>(device_total));
set_fraction = true;
TORCH_CHECK(
0 <= fraction && fraction <= 1,
"invalid fraction:",
fraction,
". Please set within [0, 1].");
allowed_memory_maximum = std::nullopt;
if (fraction < 1.0) {
allowed_memory_maximum = static_cast<size_t>(
fraction * static_cast<double>(device_prop.totalGlobalMem));
}
}
/** get expandable segment size for all the streams on device **/
@ -3010,7 +3014,7 @@ class DeviceCachingAllocator {
BlockPool& pool = *p.pool;
if (C10_UNLIKELY(
set_fraction &&
allowed_memory_maximum.has_value() &&
AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
// Track block reuse interval only when garbage collection is enabled.
++pool.get_free_blocks_call_count;
@ -3083,7 +3087,7 @@ class DeviceCachingAllocator {
size_t gc_threshold = static_cast<size_t>(
AcceleratorAllocatorConfig::garbage_collection_threshold() *
static_cast<double>(allowed_memory_maximum));
static_cast<double>(allowed_memory_maximum.value()));
// No need to trigger GC yet
if (total_allocated_memory <= gc_threshold) {
return;
@ -3161,8 +3165,8 @@ class DeviceCachingAllocator {
bool active_pool =
p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
if (set_fraction &&
total_allocated_memory + size > allowed_memory_maximum) {
if (allowed_memory_maximum.has_value() &&
total_allocated_memory + size > allowed_memory_maximum.value()) {
p.err = cudaErrorMemoryAllocation;
return false;
// Temporarily disable checkpointing & cudagraphs internally
@ -3859,7 +3863,6 @@ class NativeCachingAllocator : public CUDAAllocator {
"Allocator not initialized for device ",
device,
": did you call init?");
C10_CUDA_CHECK(c10::cuda::SetDevice(device));
return device_allocator[device]->getMemoryFraction();
}
@ -3869,12 +3872,6 @@ class NativeCachingAllocator : public CUDAAllocator {
"Allocator not initialized for device ",
device,
": did you call init?");
TORCH_CHECK(
0 <= fraction && fraction <= 1,
"invalid fraction:",
fraction,
". Please set within [0, 1].");
C10_CUDA_CHECK(c10::cuda::SetDevice(device));
device_allocator[device]->setMemoryFraction(fraction);
}

View File

@ -2,6 +2,7 @@
#include <c10/core/AllocatorConfig.h>
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>

View File

@ -427,7 +427,6 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// on the current device each later call sees.
void init(int dev_count) override {
static bool called = [](int dev_count) {
;
// Are there external guarantees init will be called before
// any of the allocator's other functions?
// std::lock_guard<std::mutex> lk(general_mutex);

View File

@ -619,6 +619,10 @@ Available options:
and reallocate buffers across multiple streams, especially when the capture DAG frequently
reaches joined frontiers.
* ``per_process_memory_fraction`` option limits the amount of memory that can be allocated
on all the CUDA devices to a specified fraction of the available memory. This is a value
between 0 and 1. Attempting to allocate more memory will raise an out of memory error.
.. note::
Some stats reported by the

View File

@ -4633,6 +4633,52 @@ print(torch.cuda.get_allocator_backend())
rc = check_output(test_script)
self.assertEqual(rc, "cudaMallocAsync")
def test_allocator_memory_fraction_setting(self):
def make_env(fraction):
env = os.environ.copy()
var = "PYTORCH_CUDA_ALLOC_CONF"
key = "per_process_memory_fraction"
value = [
x
for x in env.get(var, "").split(",")
if len(x) > 0 and not x.startswith(f"{key}:")
]
value.append(f"{key}:{fraction}")
env[var] = ",".join(value)
return env
def run_test(value):
test_script = """\
import os
import torch
device = torch._C._cuda_getDevice()
value = torch.cuda.memory.get_per_process_memory_fraction(device)
print(value, end="")
"""
return subprocess.run(
[sys.executable, "-c", test_script],
env=make_env(value),
text=True,
check=True,
capture_output=True,
)
self.assertEqual(run_test(0.0).stdout, "0.0")
self.assertEqual(run_test(0.5).stdout, "0.5")
self.assertEqual(run_test(1.0).stdout, "1.0")
with self.assertRaises(subprocess.CalledProcessError) as e:
run_test(-0.1)
assert "per_process_memory_fraction is invalid" in e.exception.stderr, (
e.exception.stderr
)
with self.assertRaises(subprocess.CalledProcessError) as e:
run_test(1.1)
assert "per_process_memory_fraction is invalid" in e.exception.stderr, (
e.exception.stderr
)
def test_cachingAllocator_raw_alloc(self):
# Test that raw_alloc respects the setting that
# activates/deactivates the caching allocator