Remove device_id param from DeviceCachingAllocator::malloc (#164798)

The `malloc` call in DeviceCachingAllocator accepts a DeviceIndex param which
can be confusing because the allocator can only allocate memory for the device
that it corresponds to. This associated device is fixed at construction time
and the runtime param can be misleading.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164798
Approved by: https://github.com/ngimel, https://github.com/cyyever, https://github.com/eqy
This commit is contained in:
Lakshay Garg
2025-10-07 16:42:04 +00:00
committed by PyTorch MergeBot
parent ee5389d520
commit 5e47b4dd60

View File

@ -1183,6 +1183,8 @@ class DeviceCachingAllocator {
// device statistics // device statistics
DeviceStats stats; DeviceStats stats;
c10::DeviceIndex device_id;
// unallocated cached blocks larger than 1 MB // unallocated cached blocks larger than 1 MB
BlockPool large_blocks; BlockPool large_blocks;
@ -1271,8 +1273,10 @@ class DeviceCachingAllocator {
public: public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
DeviceCachingAllocator() explicit DeviceCachingAllocator(c10::DeviceIndex id)
: large_blocks(/*small=*/false), small_blocks(/*small=*/true) { : device_id(id),
large_blocks(/*small=*/false),
small_blocks(/*small=*/true) {
stats.max_split_size = stats.max_split_size =
static_cast<int64_t>(CUDAAllocatorConfig::max_split_size()); static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
context_recorder_.store(nullptr); context_recorder_.store(nullptr);
@ -1358,10 +1362,7 @@ class DeviceCachingAllocator {
// All public methods (except the above) acquire the allocator mutex. // All public methods (except the above) acquire the allocator mutex.
// Thus, do not call a public method from another public method. // Thus, do not call a public method from another public method.
Block* malloc( Block* malloc(size_t orig_size, cudaStream_t stream) {
c10::DeviceIndex device,
size_t orig_size,
cudaStream_t stream) {
// done outside the lock because we don't know what locks the recorder needs // done outside the lock because we don't know what locks the recorder needs
// to have... // to have...
auto context = maybeGatherContext(RecordContext::STATE); auto context = maybeGatherContext(RecordContext::STATE);
@ -1389,7 +1390,7 @@ class DeviceCachingAllocator {
size_t size = round_size(orig_size); size_t size = round_size(orig_size);
auto& pool = get_pool(size, stream); auto& pool = get_pool(size, stream);
const size_t alloc_size = get_allocation_size(size); const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, stream, &pool, alloc_size); AllocParams params(device_id, size, stream, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool); params.stat_types = get_stat_types_for_pool(pool);
// First, try to get a block from the existing pool. // First, try to get a block from the existing pool.
@ -1436,7 +1437,7 @@ class DeviceCachingAllocator {
beginAllocateToPool(mempool_id, filter); beginAllocateToPool(mempool_id, filter);
auto& mempool = get_pool(size, stream); auto& mempool = get_pool(size, stream);
AllocParams mempool_params( AllocParams mempool_params(
device, size, stream, &mempool, alloc_size); device_id, size, stream, &mempool, alloc_size);
mempool_params.stat_types = get_stat_types_for_pool(mempool); mempool_params.stat_types = get_stat_types_for_pool(mempool);
block_found = get_free_block(mempool_params); block_found = get_free_block(mempool_params);
endAllocateToPool(mempool_id); endAllocateToPool(mempool_id);
@ -1463,7 +1464,7 @@ class DeviceCachingAllocator {
allowed_info = format_size(allowed_memory_maximum) + " allowed; "; allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
} }
std::string proc_info = reportProcessMemoryInfo(device); std::string proc_info = reportProcessMemoryInfo(device_id);
record_trace( record_trace(
TraceEntry::OOM, TraceEntry::OOM,
@ -1481,7 +1482,7 @@ class DeviceCachingAllocator {
.current, .current,
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)] stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current, .current,
c10::Device(c10::DeviceType::CUDA, device)); c10::Device(c10::DeviceType::CUDA, device_id));
auto allocated_bytes = auto allocated_bytes =
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)] stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
@ -1519,7 +1520,7 @@ class DeviceCachingAllocator {
lock.unlock(); lock.unlock();
for (const auto& obs : observers_local) { for (const auto& obs : observers_local) {
obs(device, obs(device_id,
alloc_size, alloc_size,
set_fraction ? allowed_memory_maximum : device_total, set_fraction ? allowed_memory_maximum : device_total,
device_free); device_free);
@ -1549,7 +1550,7 @@ class DeviceCachingAllocator {
"CUDA out of memory. Tried to allocate ", "CUDA out of memory. Tried to allocate ",
format_size(alloc_size), format_size(alloc_size),
". GPU ", ". GPU ",
static_cast<int>(device), static_cast<int>(device_id),
" has a total capacity of ", " has a total capacity of ",
format_size(device_total), format_size(device_total),
" of which ", " of which ",
@ -3809,7 +3810,8 @@ class NativeCachingAllocator : public CUDAAllocator {
if (size < device_count) { if (size < device_count) {
device_allocator.resize(device_count); device_allocator.resize(device_count);
for (const auto i : c10::irange(size, device_count)) { for (const auto i : c10::irange(size, device_count)) {
device_allocator[i] = std::make_unique<DeviceCachingAllocator>(); device_allocator[i] =
std::make_unique<DeviceCachingAllocator>(c10::DeviceIndex(i));
} }
} }
} }
@ -3829,7 +3831,7 @@ class NativeCachingAllocator : public CUDAAllocator {
"Allocator not initialized for device ", "Allocator not initialized for device ",
device, device,
": did you call init?"); ": did you call init?");
Block* block = device_allocator[device]->malloc(device, size, stream); Block* block = device_allocator[device]->malloc(size, stream);
add_allocated_block(block); add_allocated_block(block);
*devPtr = block->ptr; *devPtr = block->ptr;
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();