Resubmit Remove MemPoolContext (#154042) (#154746)

Summary: Per title

Test Plan: Added tests + existing tests

Differential Revision: D75695030

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154746
Approved by: https://github.com/malfet
This commit is contained in:
Natalia Gimelshein
2025-05-31 01:21:54 +00:00
committed by PyTorch MergeBot
parent 932733e0e6
commit f01e628e3b
14 changed files with 138 additions and 243 deletions

View File

@ -833,8 +833,9 @@ class EventPool {
// CUDA graphs helper // CUDA graphs helper
struct PrivatePool { struct PrivatePool {
PrivatePool(MempoolId_t id) PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
: id(std::move(id)), : id(std::move(id)),
allocator_(allocator),
large_blocks(/*small=*/false, this), large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {} small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete; PrivatePool(const PrivatePool&) = delete;
@ -855,8 +856,14 @@ struct PrivatePool {
// distinguish private blocks by adding a "pool id" check above the stream // distinguish private blocks by adding a "pool id" check above the stream
// check in BlockComparator. BlockComparator is performance-critical though, // check in BlockComparator. BlockComparator is performance-critical though,
// I'd rather not add more logic to it. // I'd rather not add more logic to it.
CUDAAllocator* allocator_;
BlockPool large_blocks; BlockPool large_blocks;
BlockPool small_blocks; BlockPool small_blocks;
public:
CUDAAllocator* allocator() {
return allocator_;
}
}; };
MempoolId_t BlockPool::owner_MempoolId() const { MempoolId_t BlockPool::owner_MempoolId() const {
@ -905,9 +912,8 @@ struct MempoolIdHash {
}; };
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) { cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
auto active_pool = MemPoolContext::getActiveMemPool(); if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
if (active_pool && active_pool->allocator() && p.pool->owner_PrivatePool) { *ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
*ptr = active_pool->allocator()->raw_alloc(size);
return *ptr ? cudaSuccess : cudaErrorMemoryAllocation; return *ptr ? cudaSuccess : cudaErrorMemoryAllocation;
} else { } else {
return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size)); return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size));
@ -1277,14 +1283,14 @@ class DeviceCachingAllocator {
alloc_block(params, false, context, lock)) alloc_block(params, false, context, lock))
// Free all non-split cached blocks and retry alloc. // Free all non-split cached blocks and retry alloc.
|| (C10_LIKELY(captures_underway.empty()) && || (C10_LIKELY(captures_underway.empty()) &&
release_cached_blocks(context) && release_cached_blocks(context, {0, 0}) &&
alloc_block(params, true, context, lock)); alloc_block(params, true, context, lock));
} }
// we are about to oom, try to use existing mempools as a last resort // we are about to oom, try to use existing mempools as a last resort
if (!block_found && params.err == cudaErrorMemoryAllocation) { if (!block_found && params.err == cudaErrorMemoryAllocation) {
// if already trying to use a mempool, then just oom // if already trying to use a mempool, then just oom
auto active_pool = MemPoolContext::getActiveMemPool(); bool active_pool = params.pool->owner_PrivatePool;
if (!active_pool) { if (!active_pool) {
for (MempoolId_t mempool_id : use_on_oom_pools) { for (MempoolId_t mempool_id : use_on_oom_pools) {
auto tid = std::this_thread::get_id(); auto tid = std::this_thread::get_id();
@ -1671,10 +1677,10 @@ class DeviceCachingAllocator {
} }
/** returns cached blocks to the system allocator **/ /** returns cached blocks to the system allocator **/
void emptyCache() { void emptyCache(MempoolId_t mempool_id) {
auto context = maybeGatherContext(RecordContext::ALL); auto context = maybeGatherContext(RecordContext::ALL);
std::lock_guard<std::recursive_mutex> lock(mutex); std::lock_guard<std::recursive_mutex> lock(mutex);
release_cached_blocks(context); release_cached_blocks(context, mempool_id);
} }
/** Retrieves size of largest unused block held by the memory cache **/ /** Retrieves size of largest unused block held by the memory cache **/
@ -1992,16 +1998,10 @@ class DeviceCachingAllocator {
/** Dump a complete snapshot of the memory held by the allocator. Potentially /** Dump a complete snapshot of the memory held by the allocator. Potentially
* VERY expensive. **/ * VERY expensive. **/
std::vector<SegmentInfo> snapshot() { std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex); std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<Block*> all_blocks; std::vector<Block*> all_blocks;
MempoolId_t mempool_id = {0, 0};
auto active_mempool = MemPoolContext::getActiveMemPool();
if (active_mempool) {
mempool_id = active_mempool->id();
}
if (mempool_id.first != 0 || mempool_id.second != 0) { if (mempool_id.first != 0 || mempool_id.second != 0) {
// If there is an active mempool, we find the corresponding PrivatePool // If there is an active mempool, we find the corresponding PrivatePool
@ -2011,7 +2011,7 @@ class DeviceCachingAllocator {
all_blocks = get_private_pool_head_blocks(pool->second.get()); all_blocks = get_private_pool_head_blocks(pool->second.get());
} }
} else { } else {
// When snapshot is called outside a MemPoolContext, we return // When snapshot is called with the default mempool_id ({0, 0}), we return
// all the blocks in the CUDACachingAllocator (as returned by // all the blocks in the CUDACachingAllocator (as returned by
// get_all_blocks). // get_all_blocks).
all_blocks = get_all_blocks(); all_blocks = get_all_blocks();
@ -2130,11 +2130,11 @@ class DeviceCachingAllocator {
} }
} }
void ensureExistsAndIncrefPool(MempoolId_t mempool_id) { void createOrIncrefPool(MempoolId_t mempool_id, CUDAAllocator* allocator) {
// Create a PrivatePool object if it does not exist yet // Create a PrivatePool object if it does not exist yet
// and increment its use_count // and increment its use_count
std::lock_guard<std::recursive_mutex> lock(mutex); std::lock_guard<std::recursive_mutex> lock(mutex);
ensure_exists_and_incref_pool(mempool_id); create_or_incref_pool(mempool_id, allocator);
} }
void setUseOnOOM(MempoolId_t mempool_id) { void setUseOnOOM(MempoolId_t mempool_id) {
@ -2150,7 +2150,7 @@ class DeviceCachingAllocator {
MempoolId_t mempool_id, MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) { std::function<bool(cudaStream_t)> filter) {
std::lock_guard<std::recursive_mutex> lock(mutex); std::lock_guard<std::recursive_mutex> lock(mutex);
ensure_exists_and_incref_pool(mempool_id); create_or_incref_pool(mempool_id);
for (auto it2 = captures_underway.begin(); it2 != captures_underway.end(); for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
++it2) { ++it2) {
TORCH_CHECK( TORCH_CHECK(
@ -2272,21 +2272,24 @@ class DeviceCachingAllocator {
return blocks; return blocks;
} }
void ensure_exists_and_incref_pool(MempoolId_t mempool_id) { void create_or_incref_pool(
MempoolId_t mempool_id,
CUDAAllocator* allocator = nullptr) {
auto it = graph_pools.find(mempool_id); auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) { if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool. // mempool_id does not reference an existing pool.
// Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool // Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
// usage. use_count is initially 1, which means the pool is // usage. use_count is initially 1, which means the pool is
// being used since somebody called ensureExistsAndIncrefPool. // being used since somebody called createOrIncrefPool.
graph_pools.emplace( graph_pools.emplace(
mempool_id, std::make_unique<PrivatePool>(mempool_id)); mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
} else { } else {
// mempool_id references an existing pool, which the current CUDAGraph // mempool_id references an existing pool, which the current CUDAGraph
// capture or torch.cuda.use_mem_pool will // capture or torch.cuda.use_mem_pool will
// share. Check this pool is live (at least one other capture already // share. Check this pool is live (at least one other capture already
// references it). Increment it to establish the usage. // references it). Increment it to establish the usage.
TORCH_INTERNAL_ASSERT(it->second->use_count > 0); TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
TORCH_INTERNAL_ASSERT(allocator == nullptr);
it->second->use_count++; it->second->use_count++;
} }
} }
@ -2776,7 +2779,8 @@ class DeviceCachingAllocator {
bool in_fbcode = false; bool in_fbcode = false;
#endif #endif
auto active_pool = MemPoolContext::getActiveMemPool(); bool active_pool =
p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
if (set_fraction && if (set_fraction &&
total_allocated_memory + size > allowed_memory_maximum) { total_allocated_memory + size > allowed_memory_maximum) {
p.err = cudaErrorMemoryAllocation; p.err = cudaErrorMemoryAllocation;
@ -2801,12 +2805,6 @@ class DeviceCachingAllocator {
} }
return bool(p.block); return bool(p.block);
} else { } else {
if (active_pool && active_pool->allocator() &&
p.pool->owner_PrivatePool) {
// Ensure that active_pool and p.pool are the same
auto pp = get_private_pool(active_pool->id());
TORCH_INTERNAL_ASSERT(pp == p.pool->owner_PrivatePool);
}
if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) { if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
// At scope exit, acquire the lock again. This provides safety against // At scope exit, acquire the lock again. This provides safety against
// any potential exceptions in the cudaMallocMaybeCapturing function. // any potential exceptions in the cudaMallocMaybeCapturing function.
@ -2926,13 +2924,9 @@ class DeviceCachingAllocator {
return true; return true;
} }
bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) { bool release_cached_blocks(
MempoolId_t mempool_id = {0, 0}; const std::shared_ptr<GatheredContext>& context,
auto active_mempool = MemPoolContext::getActiveMemPool(); MempoolId_t mempool_id) {
if (active_mempool) {
mempool_id = active_mempool->id();
}
if (mempool_id.first == 0 && mempool_id.second == 0) { if (mempool_id.first == 0 && mempool_id.second == 0) {
// If there is no active mempool, we work on releasing *all* blocks. // If there is no active mempool, we work on releasing *all* blocks.
@ -3005,15 +2999,10 @@ class DeviceCachingAllocator {
context ? context : block->context_when_segment_allocated); context ? context : block->context_when_segment_allocated);
auto* pool = block->pool; auto* pool = block->pool;
auto active_pool = MemPoolContext::getActiveMemPool(); if (pool->owner_PrivatePool && pool->owner_PrivatePool->allocator()) {
if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
// Ensure that active_pool and pool are the same
auto pp = get_private_pool(active_pool->id());
TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
// If there is an active mempool with a given allocator, // If there is an active mempool with a given allocator,
// we use the given allocator's delete function. // we use the given allocator's delete function.
active_pool->allocator()->raw_delete((void*)block->ptr); pool->owner_PrivatePool->allocator()->raw_delete((void*)block->ptr);
} else { } else {
C10_CUDA_CHECK(cudaFree((void*)block->ptr)); C10_CUDA_CHECK(cudaFree((void*)block->ptr));
} }
@ -3589,9 +3578,9 @@ class NativeCachingAllocator : public CUDAAllocator {
} }
} }
void emptyCache() override { void emptyCache(MempoolId_t mempool_id) override {
for (auto& da : device_allocator) for (auto& da : device_allocator)
da->emptyCache(); da->emptyCache(mempool_id);
} }
void enable(bool value) override { void enable(bool value) override {
@ -3639,7 +3628,7 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[block->device]->recordStream(block, stream); device_allocator[block->device]->recordStream(block, stream);
} }
SnapshotInfo snapshot() override { SnapshotInfo snapshot(MempoolId_t mempool_id) override {
// Set-up converter to convert timestamps from tsc to microseconds. // Set-up converter to convert timestamps from tsc to microseconds.
auto tsc_to_ns = clock_converter.makeConverter(); auto tsc_to_ns = clock_converter.makeConverter();
auto tsc_to_us = [=](approx_time_t t_approx) { auto tsc_to_us = [=](approx_time_t t_approx) {
@ -3657,7 +3646,7 @@ class NativeCachingAllocator : public CUDAAllocator {
// Get the device_traces' TraceEntry lists. // Get the device_traces' TraceEntry lists.
for (auto& da : device_allocator) { for (auto& da : device_allocator) {
result.device_traces.emplace_back(da->trace(tsc_to_us)); result.device_traces.emplace_back(da->trace(tsc_to_us));
auto snap = da->snapshot(); auto snap = da->snapshot(mempool_id);
result.segments.insert(result.segments.end(), snap.begin(), snap.end()); result.segments.insert(result.segments.end(), snap.begin(), snap.end());
} }
@ -3785,11 +3774,13 @@ class NativeCachingAllocator : public CUDAAllocator {
device_allocator[device]->resetPeakStats(); device_allocator[device]->resetPeakStats();
} }
void ensureExistsAndIncrefPool( void createOrIncrefPool(
c10::DeviceIndex device, c10::DeviceIndex device,
MempoolId_t mempool_id) override { MempoolId_t mempool_id,
CUDAAllocator* allocator) override {
assertValidDevice(device); assertValidDevice(device);
device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id)); device_allocator[device]->createOrIncrefPool(
std::move(mempool_id), allocator);
} }
void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override { void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
@ -4134,7 +4125,7 @@ MemPool::MemPool(
id_ = {uuid_++, 0}; id_ = {uuid_++, 0};
} }
device_ = c10::cuda::current_device(); device_ = c10::cuda::current_device();
CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_); CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
if (use_on_oom) { if (use_on_oom) {
CUDACachingAllocator::setUseOnOOM(device_, id_); CUDACachingAllocator::setUseOnOOM(device_, id_);
} }
@ -4143,8 +4134,7 @@ MemPool::MemPool(
MemPool::~MemPool() { MemPool::~MemPool() {
TORCH_INTERNAL_ASSERT(use_count() == 1); TORCH_INTERNAL_ASSERT(use_count() == 1);
CUDACachingAllocator::releasePool(device_, id_); CUDACachingAllocator::releasePool(device_, id_);
auto ctx = MemPoolContext(this); c10::cuda::CUDACachingAllocator::emptyCache(id_);
c10::cuda::CUDACachingAllocator::emptyCache();
} }
MempoolId_t MemPool::id() { MempoolId_t MemPool::id() {
@ -4170,23 +4160,4 @@ MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
return {uuid_++, 0}; return {uuid_++, 0};
} }
// Note that active_mempool_ is a global variable here
// and not inside MemPoolContext class, because in windows we
// can't use __declspec(dllexport) and __declspec(thread)
// together: https://stackoverflow.com/a/50967977
static thread_local MemPool* active_mempool_ = nullptr;
MemPoolContext::MemPoolContext(MemPool* mempool)
: prev_mempool_(active_mempool_) {
active_mempool_ = mempool;
}
MemPoolContext::~MemPoolContext() {
active_mempool_ = prev_mempool_;
}
MemPool* MemPoolContext::getActiveMemPool() {
return active_mempool_;
}
} // namespace c10::cuda } // namespace c10::cuda

View File

@ -211,7 +211,7 @@ class CUDAAllocator : public Allocator {
virtual bool initialized() = 0; virtual bool initialized() = 0;
virtual double getMemoryFraction(c10::DeviceIndex device) = 0; virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
virtual void emptyCache() = 0; virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
virtual void enable(bool value) = 0; virtual void enable(bool value) = 0;
virtual bool isEnabled() const = 0; virtual bool isEnabled() const = 0;
virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
@ -221,7 +221,7 @@ class CUDAAllocator : public Allocator {
c10::DeviceIndex device) = 0; c10::DeviceIndex device) = 0;
virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
virtual void resetPeakStats(c10::DeviceIndex device) = 0; virtual void resetPeakStats(c10::DeviceIndex device) = 0;
virtual SnapshotInfo snapshot() = 0; virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0;
virtual void beginAllocateToPool( virtual void beginAllocateToPool(
c10::DeviceIndex device, c10::DeviceIndex device,
MempoolId_t mempool_id, MempoolId_t mempool_id,
@ -239,13 +239,14 @@ class CUDAAllocator : public Allocator {
" does not yet support getPoolUseCount. " " does not yet support getPoolUseCount. "
"If you need it, please file an issue describing your use case."); "If you need it, please file an issue describing your use case.");
} }
virtual void ensureExistsAndIncrefPool( virtual void createOrIncrefPool(
c10::DeviceIndex /*device*/, c10::DeviceIndex /*device*/,
MempoolId_t /*mempool_id*/) { MempoolId_t /*mempool_id*/,
CUDAAllocator* allocator = nullptr) {
TORCH_CHECK( TORCH_CHECK(
false, false,
name(), name(),
" does not yet support ensureExistsAndIncrefPool. " " does not yet support createOrIncrefPool. "
"If you need it, please file an issue describing your use case."); "If you need it, please file an issue describing your use case.");
} }
virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) { virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
@ -364,7 +365,7 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
return get()->setMemoryFraction(fraction, device); return get()->setMemoryFraction(fraction, device);
} }
inline void emptyCache() { inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
return get()->emptyCache(); return get()->emptyCache(mempool_id);
} }
@ -401,8 +402,8 @@ inline void resetPeakStats(c10::DeviceIndex device) {
return get()->resetPeakStats(device); return get()->resetPeakStats(device);
} }
inline SnapshotInfo snapshot() { inline SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
return get()->snapshot(); return get()->snapshot(mempool_id);
} }
inline std::shared_ptr<AllocatorState> getCheckpointState( inline std::shared_ptr<AllocatorState> getCheckpointState(
@ -475,10 +476,11 @@ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->releasePool(device, mempool_id); return get()->releasePool(device, mempool_id);
} }
inline void ensureExistsAndIncrefPool( inline void createOrIncrefPool(
c10::DeviceIndex device, c10::DeviceIndex device,
MempoolId_t mempool_id) { MempoolId_t mempool_id,
get()->ensureExistsAndIncrefPool(device, mempool_id); CUDAAllocator* allocator_ptr = nullptr) {
get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
} }
inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) { inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->setUseOnOOM(device, mempool_id); get()->setUseOnOOM(device, mempool_id);
@ -555,26 +557,4 @@ struct C10_CUDA_API MemPool {
c10::DeviceIndex device_; c10::DeviceIndex device_;
}; };
// MemPoolContext holds the currently active pool and stashes the previous
// pool. On deletion it makes the previous pool active.
struct C10_CUDA_API MemPoolContext {
MemPoolContext(MemPool* mempool);
~MemPoolContext();
// getActiveMemPool() can be used to get the currently active pool.
// For instance: in CUDACachingAllocator, we can route allocations
// to a user provided allocator, by doing:
//
// auto active_pool = MemPoolContext::getActiveMemPool();
// if (active_pool && active_pool->allocator()) {
// ptr = active_pool->allocator()->raw_alloc(size);
// }
//
static MemPool* getActiveMemPool();
private:
MemPool* prev_mempool_;
};
} // namespace c10::cuda } // namespace c10::cuda

View File

@ -496,7 +496,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
// introduces performance nondeterminism. // introduces performance nondeterminism.
} }
void emptyCache() override { void emptyCache(/*unused*/ MempoolId_t mempool_id) override {
std::lock_guard<std::mutex> lk(general_mutex); std::lock_guard<std::mutex> lk(general_mutex);
for (int dev = 0; dev < device_count; dev++) { for (int dev = 0; dev < device_count; dev++) {
@ -778,7 +778,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero)); cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
} }
SnapshotInfo snapshot() override { SnapshotInfo snapshot(MempoolId_t mempool_id) override {
TORCH_CHECK( TORCH_CHECK(
false, false,
"Calling snapshot with backend:cudaMallocAsync is not meaningful. " "Calling snapshot with backend:cudaMallocAsync is not meaningful. "

View File

@ -2282,7 +2282,6 @@ coverage_ignore_classes = [
"UnsynchronizedAccessError", "UnsynchronizedAccessError",
# torch.cuda.memory # torch.cuda.memory
"MemPool", "MemPool",
"MemPoolContext",
# torch.distributed.elastic.multiprocessing.errors # torch.distributed.elastic.multiprocessing.errors
"ChildFailedError", "ChildFailedError",
"ProcessFailure", "ProcessFailure",

View File

@ -128,7 +128,6 @@ Memory management
CUDAPluggableAllocator CUDAPluggableAllocator
change_current_allocator change_current_allocator
MemPool MemPool
MemPoolContext
.. currentmodule:: torch.cuda.memory .. currentmodule:: torch.cuda.memory

View File

@ -5049,41 +5049,57 @@ class TestMemPool(TestCase):
# increments the id # increments the id
self.assertTrue(abs(pool2[1] - pool1[1]) > 0) self.assertTrue(abs(pool2[1] - pool1[1]) > 0)
def test_mempool_with_allocator(self): def get_dummy_allocator(self, check_vars):
pool = torch.cuda.MemPool() dummy_allocator_source_vars = """
# MemPool doesn't have an allocator by default
self.assertEqual(pool.allocator, None)
from torch.utils.cpp_extension import load_inline
dummy_allocator_source = """
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
extern "C" { extern "C" {
C10_EXPORT int called_dummy_alloc = 0; C10_EXPORT int called_dummy_alloc = 0;
C10_EXPORT int called_dummy_free = 0; C10_EXPORT int called_dummy_free = 0;
// Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
called_dummy_alloc = 123;
void* ptr;
C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
return ptr;
}
C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
called_dummy_free = 321;
C10_CUDA_CHECK(cudaFree(ptr));
}
}
"""
dummy_allocator_source_no_vars = """
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_runtime_api.h>
extern "C" {
// Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865 // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
called_dummy_alloc = 123;
void* ptr; void* ptr;
C10_CUDA_CHECK(cudaMallocManaged(&ptr, size)); C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
return ptr; return ptr;
} }
C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) { C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
called_dummy_free = 321;
C10_CUDA_CHECK(cudaFree(ptr)); C10_CUDA_CHECK(cudaFree(ptr));
} }
} }
""" """
from torch.utils.cpp_extension import load_inline
dummy_allocator_libname = "dummy_allocator" dummy_allocator_libname = "dummy_allocator"
dummy_allocator = load_inline( dummy_allocator = load_inline(
name=dummy_allocator_libname, name=dummy_allocator_libname,
cpp_sources=dummy_allocator_source, cpp_sources=dummy_allocator_source_vars
if check_vars
else dummy_allocator_source_no_vars,
is_python_module=False, is_python_module=False,
keep_intermediates=False, keep_intermediates=False,
verbose=True, verbose=True,
@ -5094,6 +5110,15 @@ class TestMemPool(TestCase):
"dummy_alloc", "dummy_alloc",
"dummy_free", "dummy_free",
) )
return allocator, dummy_allocator
def test_mempool_with_allocator(self):
pool = torch.cuda.MemPool()
# MemPool doesn't have an allocator by default
self.assertEqual(pool.allocator, None)
allocator, dummy_allocator = self.get_dummy_allocator(check_vars=True)
pool = torch.cuda.MemPool(allocator.allocator()) pool = torch.cuda.MemPool(allocator.allocator())
# pool should point to the same allocator as the one passed into it # pool should point to the same allocator as the one passed into it
@ -5128,6 +5153,8 @@ class TestMemPool(TestCase):
# out tensor # out tensor
self.assertEqual(called_dummy_alloc.value, 123) self.assertEqual(called_dummy_alloc.value, 123)
out_non_pool = torch.empty(nelem_1mb, device="cuda")
with torch.cuda.use_mem_pool(pool): with torch.cuda.use_mem_pool(pool):
# pool should have 1 segment since we made a small allocation (1 MB) # pool should have 1 segment since we made a small allocation (1 MB)
# above and so the CUDACachingAllocator packed it into a 2 MB buffer # above and so the CUDACachingAllocator packed it into a 2 MB buffer
@ -5145,6 +5172,8 @@ class TestMemPool(TestCase):
# to make a new 2 MB buffer to accommodate out_2 # to make a new 2 MB buffer to accommodate out_2
self.assertEqual(len(pool.snapshot()), 2) self.assertEqual(len(pool.snapshot()), 2)
self.assertEqual(len(pool.snapshot()), 2)
del out_0, out_1, out_2 del out_0, out_1, out_2
# pool's destructor calls emptyCache() # pool's destructor calls emptyCache()
@ -5156,40 +5185,7 @@ class TestMemPool(TestCase):
@serialTest() @serialTest()
def test_mempool_limited_memory_with_allocator(self): def test_mempool_limited_memory_with_allocator(self):
from torch.utils.cpp_extension import load_inline allocator, _ = self.get_dummy_allocator(check_vars=False)
dummy_allocator_source = """
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_runtime_api.h>
extern "C" {
// Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
void* ptr;
C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
return ptr;
}
C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
C10_CUDA_CHECK(cudaFree(ptr));
}
}
"""
dummy_allocator_libname = "dummy_allocator"
dummy_allocator = load_inline(
name=dummy_allocator_libname,
cpp_sources=dummy_allocator_source,
is_python_module=False,
keep_intermediates=False,
verbose=True,
with_cuda=True,
)
allocator = torch.cuda.memory.CUDAPluggableAllocator(
dummy_allocator,
"dummy_alloc",
"dummy_free",
)
pool_do_not_use = torch.cuda.MemPool(allocator.allocator()) pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True) pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
@ -5258,38 +5254,13 @@ class TestMemPool(TestCase):
self._teardown_mempool_limited_memory_test() self._teardown_mempool_limited_memory_test()
def test_mempool_context(self):
active_pool = torch.cuda.MemPoolContext.active_pool()
# there is no active pool if none was made active
self.assertEqual(active_pool, None)
pool = torch.cuda.MemPool()
ctx = torch.cuda.MemPoolContext(pool)
active_pool = torch.cuda.MemPoolContext.active_pool()
# pool was made active
self.assertEqual(active_pool, pool)
del ctx
active_pool = torch.cuda.MemPoolContext.active_pool()
# ctx was deleted, so active pool is the previous one
self.assertEqual(active_pool, None)
def test_mempool_multithread(self): def test_mempool_multithread(self):
pool_ids = [] pool_ids = []
active_pool_ids = []
def create_mempool_and_make_active(): def create_mempool_and_make_active():
pool = torch.cuda.MemPool() pool = torch.cuda.MemPool()
pool_ids.extend([pool.id]) pool_ids.extend([pool.id])
ctx = torch.cuda.MemPoolContext(pool)
active_pool = torch.cuda.MemPoolContext.active_pool()
active_pool_ids.extend([active_pool.id])
del ctx
num_threads = 4 num_threads = 4
threads = [ threads = [
threading.Thread(target=create_mempool_and_make_active) threading.Thread(target=create_mempool_and_make_active)
@ -5304,14 +5275,12 @@ class TestMemPool(TestCase):
# mempool id creation is atomic # mempool id creation is atomic
self.assertEqual(len(set(pool_ids)), 4) self.assertEqual(len(set(pool_ids)), 4)
# each thread should have different active mempool, since
# the pointer to the mempool is thread local
self.assertEqual(len(set(active_pool_ids)), 4)
@skipIfRocm(msg="expandable_segments mode is not supported on ROCm") @skipIfRocm(msg="expandable_segments mode is not supported on ROCm")
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode")
def test_mempool_expandable(self): def test_mempool_expandable(self):
torch.cuda.memory._set_allocator_settings("expandable_segments:True") torch.cuda.memory._set_allocator_settings("expandable_segments:True")
pool = torch.cuda.MemPool() allocator, _ = self.get_dummy_allocator(check_vars=False)
pool = torch.cuda.MemPool(allocator.allocator())
# torch.cuda.MemPool doesn't work with expandable segments # torch.cuda.MemPool doesn't work with expandable segments
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):

View File

@ -2027,7 +2027,7 @@ def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_hostMemoryStats() -> dict[str, Any]: ... def _cuda_hostMemoryStats() -> dict[str, Any]: ...
def _cuda_resetAccumulatedHostMemoryStats() -> None: ... def _cuda_resetAccumulatedHostMemoryStats() -> None: ...
def _cuda_resetPeakHostMemoryStats() -> None: ... def _cuda_resetPeakHostMemoryStats() -> None: ...
def _cuda_memorySnapshot() -> dict[str, Any]: ... def _cuda_memorySnapshot(mempool_id: tuple[_int, _int] | None) -> dict[str, Any]: ...
def _cuda_record_memory_history_legacy( def _cuda_record_memory_history_legacy(
enabled: _bool, enabled: _bool,
record_context: _bool, record_context: _bool,
@ -2304,11 +2304,6 @@ class _MemPool:
def allocator(self) -> _cuda_CUDAAllocator | None: ... def allocator(self) -> _cuda_CUDAAllocator | None: ...
def use_count(self) -> _int: ... def use_count(self) -> _int: ...
class _MemPoolContext:
def __init__(self, pool: _MemPool) -> None: ...
@staticmethod
def active_pool() -> _MemPool | None: ...
def _cuda_isCurrentStreamCapturing() -> _bool: ... def _cuda_isCurrentStreamCapturing() -> _bool: ...
def _graph_pool_handle() -> tuple[_int, _int]: ... def _graph_pool_handle() -> tuple[_int, _int]: ...

View File

@ -184,7 +184,8 @@ void CUDAPluggableAllocator::setMemoryFraction(
} }
} }
void CUDAPluggableAllocator::emptyCache() { void CUDAPluggableAllocator::emptyCache(
/*unused*/ c10::cuda::MempoolId_t mempool_id) {
if (reset_fn_) { if (reset_fn_) {
return reset_fn_(); return reset_fn_();
} }
@ -237,8 +238,8 @@ void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
"If you need it, please file an issue describing your use case."); "If you need it, please file an issue describing your use case.");
} }
c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator:: c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::snapshot(
snapshot() { c10::cuda::MempoolId_t mempool_id) {
TORCH_CHECK( TORCH_CHECK(
false, false,
"CUDAPluggableAllocator does not yet support snapshot. " "CUDAPluggableAllocator does not yet support snapshot. "

View File

@ -114,7 +114,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
bool initialized() override; bool initialized() override;
double getMemoryFraction(c10::DeviceIndex device) override; double getMemoryFraction(c10::DeviceIndex device) override;
void setMemoryFraction(double fraction, c10::DeviceIndex device) override; void setMemoryFraction(double fraction, c10::DeviceIndex device) override;
void emptyCache() override; void emptyCache(c10::cuda::MempoolId_t mempool_id = {0, 0}) override;
void enable(bool) override {} void enable(bool) override {}
bool isEnabled() const override { bool isEnabled() const override {
return true; return true;
@ -128,7 +128,8 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
c10::DeviceIndex device) override; c10::DeviceIndex device) override;
void resetAccumulatedStats(c10::DeviceIndex device) override; void resetAccumulatedStats(c10::DeviceIndex device) override;
void resetPeakStats(c10::DeviceIndex device) override; void resetPeakStats(c10::DeviceIndex device) override;
c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override; c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot(
c10::cuda::MempoolId_t mempool) override;
void beginAllocateToPool( void beginAllocateToPool(
c10::DeviceIndex device, c10::DeviceIndex device,
c10::cuda::MempoolId_t mempool_id, c10::cuda::MempoolId_t mempool_id,

View File

@ -24,8 +24,4 @@ void THCPMemPool_init(PyObject* module) {
.def_property_readonly("id", &::c10::cuda::MemPool::id) .def_property_readonly("id", &::c10::cuda::MemPool::id)
.def_property_readonly("allocator", &::c10::cuda::MemPool::allocator) .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
.def("use_count", &::c10::cuda::MemPool::use_count); .def("use_count", &::c10::cuda::MemPool::use_count);
shared_ptr_class_<::c10::cuda::MemPoolContext>(torch_C_m, "_MemPoolContext")
.def(py::init<c10::cuda::MemPool*>())
.def_static(
"active_pool", &::c10::cuda::MemPoolContext::getActiveMemPool);
} }

View File

@ -721,8 +721,24 @@ CapturedTraceback* getFromContext(
"attempting to gather stack context from the wrong StackContext type."); "attempting to gather stack context from the wrong StackContext type.");
} }
PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) { PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS HANDLE_TH_ERRORS
// Optional pool filter parsed from the single Python argument. It stays at
// {0, 0} when the caller passes None; otherwise it must be a 2-tuple of ints.
c10::cuda::MempoolId_t mempool_id = {0, 0};
if (arg && arg != Py_None) {
  TORCH_CHECK(PyTuple_Check(arg), "mempool_id must be a tuple");
  Py_ssize_t size = PyTuple_Size(arg);
  TORCH_CHECK(size == 2, "mempool_id must be a tuple of 2 integers");
  // PyTuple_GetItem returns a *borrowed* reference. It must not be handed to
  // THPObjectPtr, which owns its pointee and would Py_DECREF it on scope
  // exit, dropping a reference that still belongs to the tuple. Keep plain
  // borrowed pointers instead; `arg` keeps both elements alive for the
  // duration of this call.
  PyObject* id1 = PyTuple_GetItem(arg, 0);
  PyObject* id2 = PyTuple_GetItem(arg, 1);
  TORCH_CHECK(
      THPUtils_checkLong(id1) && THPUtils_checkLong(id2),
      "mempool_id elements must be integers");
  mempool_id = c10::cuda::MempoolId_t(
      static_cast<int64_t>(THPUtils_unpackLong(id1)),
      static_cast<int64_t>(THPUtils_unpackLong(id2)));
}
using c10::cuda::CUDACachingAllocator::BlockInfo; using c10::cuda::CUDACachingAllocator::BlockInfo;
using c10::cuda::CUDACachingAllocator::SegmentInfo; using c10::cuda::CUDACachingAllocator::SegmentInfo;
@ -802,7 +818,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
return segmentDict; return segmentDict;
}; };
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(mempool_id);
py::list segments; py::list segments;
@ -2011,7 +2027,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_resetPeakMemoryStats, THCPModule_resetPeakMemoryStats,
METH_O, METH_O,
nullptr}, nullptr},
{"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr}, {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_O, nullptr},
{"_cuda_attach_out_of_memory_observer", {"_cuda_attach_out_of_memory_observer",
THCPModule_attachOutOfMemoryObserver, THCPModule_attachOutOfMemoryObserver,
METH_O, METH_O,

View File

@ -1186,8 +1186,7 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
// We must ensure we're listening for allocator trace events in order to // We must ensure we're listening for allocator trace events in order to
// register future segments allocated in this pool (this call is idempotent). // register future segments allocated in this pool (this call is idempotent).
attachAllocatorHooks(); attachAllocatorHooks();
auto ctx = c10::cuda::MemPoolContext(pool); auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
for (const auto& segmentInfo : snapshot.segments) { for (const auto& segmentInfo : snapshot.segments) {
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
segmentInfo.device == pool->device(), segmentInfo.device == pool->device(),
@ -1221,8 +1220,7 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
auto iter = ncclCommMemPoolMap.find(ncclComm); auto iter = ncclCommMemPoolMap.find(ncclComm);
iter->second.erase(pool->id()); iter->second.erase(pool->id());
} }
auto ctx = c10::cuda::MemPoolContext(pool); auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
for (const auto& segmentInfo : snapshot.segments) { for (const auto& segmentInfo : snapshot.segments) {
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
segmentInfo.device == pool->device(), segmentInfo.device == pool->device(),
@ -5572,7 +5570,6 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
} }
// Allocate tensor under this MemPool's context // Allocate tensor under this MemPool's context
auto ctx = c10::cuda::MemPoolContext(memPool_.get());
auto tid = std::this_thread::get_id(); auto tid = std::this_thread::get_id();
c10::cuda::CUDACachingAllocator::beginAllocateToPool( c10::cuda::CUDACachingAllocator::beginAllocateToPool(
memPool_->device(), memPool_->id(), [=](cudaStream_t) { memPool_->device(), memPool_->id(), [=](cudaStream_t) {

View File

@ -1872,7 +1872,6 @@ __all__ = [
"memory_summary", "memory_summary",
"memory_usage", "memory_usage",
"MemPool", "MemPool",
"MemPoolContext",
"use_mem_pool", "use_mem_pool",
"temperature", "temperature",
"power_draw", "power_draw",

View File

@ -60,7 +60,6 @@ __all__ = [
"CUDAPluggableAllocator", "CUDAPluggableAllocator",
"change_current_allocator", "change_current_allocator",
"MemPool", "MemPool",
"MemPoolContext",
"use_mem_pool", "use_mem_pool",
] ]
@ -73,7 +72,6 @@ if not hasattr(torch._C, "_cuda_CUDAAllocator"):
if not hasattr(torch._C, "_MemPool"): if not hasattr(torch._C, "_MemPool"):
# Define dummy base classes # Define dummy base classes
torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool") torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type( torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
"_cuda_beginAllocateToPool" "_cuda_beginAllocateToPool"
) )
@ -92,7 +90,6 @@ from torch._C import ( # noqa: F401
_cuda_endAllocateToPool, _cuda_endAllocateToPool,
_cuda_releasePool, _cuda_releasePool,
_MemPool, _MemPool,
_MemPoolContext,
) )
@ -617,7 +614,7 @@ def max_memory_cached(device: "Device" = None) -> int:
return max_memory_reserved(device=device) return max_memory_reserved(device=device)
def memory_snapshot(): def memory_snapshot(mempool_id=None):
r"""Return a snapshot of the CUDA memory allocator state across all devices. r"""Return a snapshot of the CUDA memory allocator state across all devices.
Interpreting the output of this function requires familiarity with the Interpreting the output of this function requires familiarity with the
@ -627,7 +624,7 @@ def memory_snapshot():
See :ref:`cuda-memory-management` for more details about GPU memory See :ref:`cuda-memory-management` for more details about GPU memory
management. management.
""" """
return torch._C._cuda_memorySnapshot()["segments"] return torch._C._cuda_memorySnapshot(mempool_id)["segments"]
def memory_summary(device: "Device" = None, abbreviated: bool = False) -> str: def memory_summary(device: "Device" = None, abbreviated: bool = False) -> str:
@ -998,7 +995,7 @@ def _snapshot(device: "Device" = None):
Returns: Returns:
The Snapshot dictionary object The Snapshot dictionary object
""" """
return _C._cuda_memorySnapshot() return _C._cuda_memorySnapshot(None)
def _dump_snapshot(filename="dump_snapshot.pickle"): def _dump_snapshot(filename="dump_snapshot.pickle"):
@ -1110,25 +1107,6 @@ def _get_current_allocator() -> _CUDAAllocator:
return _CUDAAllocator(torch._C._cuda_getAllocator()) return _CUDAAllocator(torch._C._cuda_getAllocator())
class MemPoolContext(_MemPoolContext):
r"""MemPoolContext holds the currently active pool and stashes the previous
pool. On deletion it makes the previous pool active.
Args:
pool(torch.cuda.MemPool): a MemPool object to be made active so that
allocations route to this pool.
"""
def __init__(self, pool: _MemPool):
super().__init__(pool)
@staticmethod
def active_pool() -> Optional[_MemPool]:
r"""Returns the active MemPool"""
return _MemPoolContext.active_pool()
class MemPool(_MemPool): class MemPool(_MemPool):
r"""MemPool represents a pool of memory in a caching allocator. Currently, r"""MemPool represents a pool of memory in a caching allocator. Currently,
it's just the ID of the pool object maintained in the CUDACachingAllocator. it's just the ID of the pool object maintained in the CUDACachingAllocator.
@ -1177,11 +1155,7 @@ class MemPool(_MemPool):
See :ref:`cuda-memory-management` for more details about GPU memory See :ref:`cuda-memory-management` for more details about GPU memory
management. management.
""" """
try: snapshot = torch.cuda.memory_snapshot(self.id)
ctx = MemPoolContext(self)
snapshot = torch.cuda.memory_snapshot()
finally:
del ctx
return snapshot return snapshot
@ -1202,7 +1176,6 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
(e.g. by calling backward) the allocations in that thread will not (e.g. by calling backward) the allocations in that thread will not
route to the given pool. route to the given pool.
""" """
ctx = MemPoolContext(pool)
device_index = ( device_index = (
torch.cuda.current_device() if device is None else _get_device_index(device) torch.cuda.current_device() if device is None else _get_device_index(device)
) )
@ -1212,4 +1185,3 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
finally: finally:
_cuda_endAllocateToPool(device_index, pool.id) _cuda_endAllocateToPool(device_index, pool.id)
_cuda_releasePool(device_index, pool.id) _cuda_releasePool(device_index, pool.id)
del ctx