Remove device_id param from DeviceCachingAllocator::malloc (#164798)

The `malloc` call in DeviceCachingAllocator accepts a DeviceIndex param, which
can be confusing because the allocator can only allocate memory for the device
that it corresponds to. This associated device is fixed at construction time,
so the runtime param can be misleading.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164798
Approved by: https://github.com/ngimel, https://github.com/cyyever, https://github.com/eqy
This commit is contained in:
Lakshay Garg
2025-10-07 16:42:04 +00:00
committed by PyTorch MergeBot
parent ee5389d520
commit 5e47b4dd60

View File

@ -1183,6 +1183,8 @@ class DeviceCachingAllocator {
// device statistics
DeviceStats stats;
c10::DeviceIndex device_id;
// unallocated cached blocks larger than 1 MB
BlockPool large_blocks;
@ -1271,8 +1273,10 @@ class DeviceCachingAllocator {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
DeviceCachingAllocator()
: large_blocks(/*small=*/false), small_blocks(/*small=*/true) {
explicit DeviceCachingAllocator(c10::DeviceIndex id)
: device_id(id),
large_blocks(/*small=*/false),
small_blocks(/*small=*/true) {
stats.max_split_size =
static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
context_recorder_.store(nullptr);
@ -1358,10 +1362,7 @@ class DeviceCachingAllocator {
// All public methods (except the above) acquire the allocator mutex.
// Thus, do not call a public method from another public method.
Block* malloc(
c10::DeviceIndex device,
size_t orig_size,
cudaStream_t stream) {
Block* malloc(size_t orig_size, cudaStream_t stream) {
// done outside the lock because we don't know what locks the recorder needs
// to have...
auto context = maybeGatherContext(RecordContext::STATE);
@ -1389,7 +1390,7 @@ class DeviceCachingAllocator {
size_t size = round_size(orig_size);
auto& pool = get_pool(size, stream);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, stream, &pool, alloc_size);
AllocParams params(device_id, size, stream, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool);
// First, try to get a block from the existing pool.
@ -1436,7 +1437,7 @@ class DeviceCachingAllocator {
beginAllocateToPool(mempool_id, filter);
auto& mempool = get_pool(size, stream);
AllocParams mempool_params(
device, size, stream, &mempool, alloc_size);
device_id, size, stream, &mempool, alloc_size);
mempool_params.stat_types = get_stat_types_for_pool(mempool);
block_found = get_free_block(mempool_params);
endAllocateToPool(mempool_id);
@ -1463,7 +1464,7 @@ class DeviceCachingAllocator {
allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
}
std::string proc_info = reportProcessMemoryInfo(device);
std::string proc_info = reportProcessMemoryInfo(device_id);
record_trace(
TraceEntry::OOM,
@ -1481,7 +1482,7 @@ class DeviceCachingAllocator {
.current,
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current,
c10::Device(c10::DeviceType::CUDA, device));
c10::Device(c10::DeviceType::CUDA, device_id));
auto allocated_bytes =
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
@ -1519,7 +1520,7 @@ class DeviceCachingAllocator {
lock.unlock();
for (const auto& obs : observers_local) {
obs(device,
obs(device_id,
alloc_size,
set_fraction ? allowed_memory_maximum : device_total,
device_free);
@ -1549,7 +1550,7 @@ class DeviceCachingAllocator {
"CUDA out of memory. Tried to allocate ",
format_size(alloc_size),
". GPU ",
static_cast<int>(device),
static_cast<int>(device_id),
" has a total capacity of ",
format_size(device_total),
" of which ",
@ -3809,7 +3810,8 @@ class NativeCachingAllocator : public CUDAAllocator {
if (size < device_count) {
device_allocator.resize(device_count);
for (const auto i : c10::irange(size, device_count)) {
device_allocator[i] = std::make_unique<DeviceCachingAllocator>();
device_allocator[i] =
std::make_unique<DeviceCachingAllocator>(c10::DeviceIndex(i));
}
}
}
@ -3829,7 +3831,7 @@ class NativeCachingAllocator : public CUDAAllocator {
"Allocator not initialized for device ",
device,
": did you call init?");
Block* block = device_allocator[device]->malloc(device, size, stream);
Block* block = device_allocator[device]->malloc(size, stream);
add_allocated_block(block);
*devPtr = block->ptr;
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();