Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Summary: Per title

Test Plan: Added tests + existing tests

Differential Revision: D75695030

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154746
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 932733e0e6
Commit: f01e628e3b
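The patch drops the thread-local MemPoolContext: a MemPool now carries its CUDAAllocator directly (stored on the PrivatePool), and emptyCache/snapshot are keyed by a mempool id instead of by whichever pool happens to be active. A minimal sketch of the reworked user-facing flow, where `allocator` is an assumed torch.cuda.memory.CUDAPluggableAllocator built elsewhere (for example with torch.utils.cpp_extension.load_inline, as the tests below do):

    import torch

    # Sketch only: `allocator` is an assumed CUDAPluggableAllocator instance.
    pool = torch.cuda.MemPool(allocator.allocator())  # allocator is bound at pool creation

    nelem_1mb = 1024 * 1024 // 4
    with torch.cuda.use_mem_pool(pool):
        out = torch.empty(nelem_1mb, device="cuda")   # served from the pool's allocator

    # Snapshots are scoped by mempool id; no MemPoolContext needs to be active.
    print(len(pool.snapshot()))                       # segments owned by this pool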
@@ -833,8 +833,9 @@ class EventPool {
 
 // CUDA graphs helper
 struct PrivatePool {
-  PrivatePool(MempoolId_t id)
+  PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
       : id(std::move(id)),
+        allocator_(allocator),
         large_blocks(/*small=*/false, this),
         small_blocks(/*small=*/true, this) {}
   PrivatePool(const PrivatePool&) = delete;
@@ -855,8 +856,14 @@ struct PrivatePool {
   // distinguish private blocks by adding a "pool id" check above the stream
   // check in BlockComparator. BlockComparator is performance- critical though,
   // I'd rather not add more logic to it.
+  CUDAAllocator* allocator_;
   BlockPool large_blocks;
   BlockPool small_blocks;
+
+ public:
+  CUDAAllocator* allocator() {
+    return allocator_;
+  }
 };
 
 MempoolId_t BlockPool::owner_MempoolId() const {
@@ -905,9 +912,8 @@ struct MempoolIdHash {
 };
 
 cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
-  auto active_pool = MemPoolContext::getActiveMemPool();
-  if (active_pool && active_pool->allocator() && p.pool->owner_PrivatePool) {
-    *ptr = active_pool->allocator()->raw_alloc(size);
+  if (p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator()) {
+    *ptr = p.pool->owner_PrivatePool->allocator()->raw_alloc(size);
     return *ptr ? cudaSuccess : cudaErrorMemoryAllocation;
   } else {
     return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size));
@@ -1277,14 +1283,14 @@ class DeviceCachingAllocator {
             alloc_block(params, false, context, lock))
         // Free all non-split cached blocks and retry alloc.
         || (C10_LIKELY(captures_underway.empty()) &&
-            release_cached_blocks(context) &&
+            release_cached_blocks(context, {0, 0}) &&
             alloc_block(params, true, context, lock));
     }
 
     // we are about to oom, try to use existing mempools as a last resort
     if (!block_found && params.err == cudaErrorMemoryAllocation) {
       // if already trying to use a mempool, then just oom
-      auto active_pool = MemPoolContext::getActiveMemPool();
+      bool active_pool = params.pool->owner_PrivatePool;
       if (!active_pool) {
         for (MempoolId_t mempool_id : use_on_oom_pools) {
           auto tid = std::this_thread::get_id();
@@ -1671,10 +1677,10 @@ class DeviceCachingAllocator {
   }
 
   /** returns cached blocks to the system allocator **/
-  void emptyCache() {
+  void emptyCache(MempoolId_t mempool_id) {
     auto context = maybeGatherContext(RecordContext::ALL);
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    release_cached_blocks(context);
+    release_cached_blocks(context, mempool_id);
   }
 
   /** Retrieves size of largest unused block held by the memory cache **/
@@ -1992,16 +1998,10 @@ class DeviceCachingAllocator {
 
   /** Dump a complete snapshot of the memory held by the allocator. Potentially
    * VERY expensive. **/
-  std::vector<SegmentInfo> snapshot() {
+  std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
     std::lock_guard<std::recursive_mutex> lock(mutex);
 
     std::vector<Block*> all_blocks;
-    MempoolId_t mempool_id = {0, 0};
-
-    auto active_mempool = MemPoolContext::getActiveMemPool();
-    if (active_mempool) {
-      mempool_id = active_mempool->id();
-    }
 
     if (mempool_id.first != 0 || mempool_id.second != 0) {
       // If there is an active mempool, we find the corresponding PrivatePool
@@ -2011,7 +2011,7 @@ class DeviceCachingAllocator {
         all_blocks = get_private_pool_head_blocks(pool->second.get());
       }
     } else {
-      // When snapshot is called outside a MemPoolContext, we return
+      // When snapshot is called with non-default mempool_id, we return
       // all the blocks in the CUDACachingAllocator (as returned by
       // get_all_blocks).
       all_blocks = get_all_blocks();
@@ -2130,11 +2130,11 @@ class DeviceCachingAllocator {
     }
   }
 
-  void ensureExistsAndIncrefPool(MempoolId_t mempool_id) {
+  void createOrIncrefPool(MempoolId_t mempool_id, CUDAAllocator* allocator) {
     // Create a PrivatePool object if it does not exist yet
     // and increment its use_count
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    ensure_exists_and_incref_pool(mempool_id);
+    create_or_incref_pool(mempool_id, allocator);
   }
 
   void setUseOnOOM(MempoolId_t mempool_id) {
@@ -2150,7 +2150,7 @@ class DeviceCachingAllocator {
       MempoolId_t mempool_id,
       std::function<bool(cudaStream_t)> filter) {
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    ensure_exists_and_incref_pool(mempool_id);
+    create_or_incref_pool(mempool_id);
     for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
          ++it2) {
       TORCH_CHECK(
@@ -2272,21 +2272,24 @@ class DeviceCachingAllocator {
     return blocks;
   }
 
-  void ensure_exists_and_incref_pool(MempoolId_t mempool_id) {
+  void create_or_incref_pool(
+      MempoolId_t mempool_id,
+      CUDAAllocator* allocator = nullptr) {
    auto it = graph_pools.find(mempool_id);
    if (it == graph_pools.end()) {
      // mempool_id does not reference an existing pool.
      // Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
      // usage. use_count is initially 1, which means the pool is
-      // being used since somebody called ensureExistsAndIncrefPool.
+      // being used since somebody called createOrIncrefPool.
      graph_pools.emplace(
-          mempool_id, std::make_unique<PrivatePool>(mempool_id));
+          mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
    } else {
      // mempool_id references an existing pool, which the current CUDAGraph
      // capture or torch.cuda.use_mem_pool will
      // share. Check this pool is live (at least one other capture already
      // references it). Increment it to establish the usage.
      TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
+      TORCH_INTERNAL_ASSERT(allocator == nullptr);
      it->second->use_count++;
    }
  }
@@ -2776,7 +2779,8 @@ class DeviceCachingAllocator {
     bool in_fbcode = false;
 #endif
 
-    auto active_pool = MemPoolContext::getActiveMemPool();
+    bool active_pool =
+        p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
     if (set_fraction &&
         total_allocated_memory + size > allowed_memory_maximum) {
       p.err = cudaErrorMemoryAllocation;
@@ -2801,12 +2805,6 @@ class DeviceCachingAllocator {
       }
       return bool(p.block);
     } else {
-      if (active_pool && active_pool->allocator() &&
-          p.pool->owner_PrivatePool) {
-        // Ensure that active_pool and p.pool are the same
-        auto pp = get_private_pool(active_pool->id());
-        TORCH_INTERNAL_ASSERT(pp == p.pool->owner_PrivatePool);
-      }
       if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
         // At scope exit, acquire the lock again. This provides safety against
         // any potential exceptions in the cudaMallocMaybeCapturing function.
@@ -2926,13 +2924,9 @@ class DeviceCachingAllocator {
     return true;
   }
 
-  bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
-    MempoolId_t mempool_id = {0, 0};
-    auto active_mempool = MemPoolContext::getActiveMemPool();
-    if (active_mempool) {
-      mempool_id = active_mempool->id();
-    }
-
+  bool release_cached_blocks(
+      const std::shared_ptr<GatheredContext>& context,
+      MempoolId_t mempool_id) {
     if (mempool_id.first == 0 && mempool_id.second == 0) {
       // If there is no active mempool, we work on releasing *all* blocks.
 
@@ -3005,15 +2999,10 @@ class DeviceCachingAllocator {
         context ? context : block->context_when_segment_allocated);
 
     auto* pool = block->pool;
-    auto active_pool = MemPoolContext::getActiveMemPool();
-    if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
-      // Ensure that active_pool and pool are the same
-      auto pp = get_private_pool(active_pool->id());
-      TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
-
+    if (pool->owner_PrivatePool && pool->owner_PrivatePool->allocator()) {
       // If there is an active mempool with a given allocator,
       // we use the given allocator's delete function.
-      active_pool->allocator()->raw_delete((void*)block->ptr);
+      pool->owner_PrivatePool->allocator()->raw_delete((void*)block->ptr);
     } else {
       C10_CUDA_CHECK(cudaFree((void*)block->ptr));
     }
@@ -3589,9 +3578,9 @@ class NativeCachingAllocator : public CUDAAllocator {
     }
   }
 
-  void emptyCache() override {
+  void emptyCache(MempoolId_t mempool_id) override {
     for (auto& da : device_allocator)
-      da->emptyCache();
+      da->emptyCache(mempool_id);
   }
 
   void enable(bool value) override {
@@ -3639,7 +3628,7 @@ class NativeCachingAllocator : public CUDAAllocator {
     device_allocator[block->device]->recordStream(block, stream);
   }
 
-  SnapshotInfo snapshot() override {
+  SnapshotInfo snapshot(MempoolId_t mempool_id) override {
     // Set-up converter to convert timestamps from tsc to microseconds.
     auto tsc_to_ns = clock_converter.makeConverter();
     auto tsc_to_us = [=](approx_time_t t_approx) {
@@ -3657,7 +3646,7 @@ class NativeCachingAllocator : public CUDAAllocator {
     // Get the device_traces' TraceEntry lists.
     for (auto& da : device_allocator) {
       result.device_traces.emplace_back(da->trace(tsc_to_us));
-      auto snap = da->snapshot();
+      auto snap = da->snapshot(mempool_id);
       result.segments.insert(result.segments.end(), snap.begin(), snap.end());
     }
 
@@ -3785,11 +3774,13 @@ class NativeCachingAllocator : public CUDAAllocator {
     device_allocator[device]->resetPeakStats();
   }
 
-  void ensureExistsAndIncrefPool(
+  void createOrIncrefPool(
       c10::DeviceIndex device,
-      MempoolId_t mempool_id) override {
+      MempoolId_t mempool_id,
+      CUDAAllocator* allocator) override {
     assertValidDevice(device);
-    device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
+    device_allocator[device]->createOrIncrefPool(
+        std::move(mempool_id), allocator);
   }
 
   void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
@@ -4134,7 +4125,7 @@ MemPool::MemPool(
     id_ = {uuid_++, 0};
   }
   device_ = c10::cuda::current_device();
-  CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
+  CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
  if (use_on_oom) {
    CUDACachingAllocator::setUseOnOOM(device_, id_);
  }
@@ -4143,8 +4134,7 @@ MemPool::MemPool(
 MemPool::~MemPool() {
   TORCH_INTERNAL_ASSERT(use_count() == 1);
   CUDACachingAllocator::releasePool(device_, id_);
-  auto ctx = MemPoolContext(this);
-  c10::cuda::CUDACachingAllocator::emptyCache();
+  c10::cuda::CUDACachingAllocator::emptyCache(id_);
 }
 
 MempoolId_t MemPool::id() {
@@ -4170,23 +4160,4 @@ MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
   return {uuid_++, 0};
 }
 
-// Note that active_mempool_ is a global variable here
-// and not inside MemPoolContext class, because in windows we
-// can't use __declspec(dllexport) and __declspec(thread)
-// together: https://stackoverflow.com/a/50967977
-static thread_local MemPool* active_mempool_ = nullptr;
-
-MemPoolContext::MemPoolContext(MemPool* mempool)
-    : prev_mempool_(active_mempool_) {
-  active_mempool_ = mempool;
-}
-
-MemPoolContext::~MemPoolContext() {
-  active_mempool_ = prev_mempool_;
-}
-
-MemPool* MemPoolContext::getActiveMemPool() {
-  return active_mempool_;
-}
-
 } // namespace c10::cuda
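Because the owning PrivatePool now records its allocator, block allocation and release pick the pluggable raw_alloc/raw_delete from the pool that owns the block, and MemPool's destructor triggers a pool-scoped cache release. A short sketch of the observable effect, under the same assumed `allocator` as above:

    import torch

    pool = torch.cuda.MemPool(allocator.allocator())  # assumed CUDAPluggableAllocator
    with torch.cuda.use_mem_pool(pool):
        t = torch.empty(1024, device="cuda")          # raw_alloc comes from the pool's allocator
    del t

    # Destroying the pool calls emptyCache with this pool's id, so only this
    # pool's cached segments are returned through the pluggable free function;
    # the default pool's cache is left untouched.
    del pool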
@@ -211,7 +211,7 @@ class CUDAAllocator : public Allocator {
   virtual bool initialized() = 0;
   virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
   virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
-  virtual void emptyCache() = 0;
+  virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
   virtual void enable(bool value) = 0;
   virtual bool isEnabled() const = 0;
   virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
@@ -221,7 +221,7 @@ class CUDAAllocator : public Allocator {
       c10::DeviceIndex device) = 0;
   virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
   virtual void resetPeakStats(c10::DeviceIndex device) = 0;
-  virtual SnapshotInfo snapshot() = 0;
+  virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0;
   virtual void beginAllocateToPool(
       c10::DeviceIndex device,
       MempoolId_t mempool_id,
@@ -239,13 +239,14 @@ class CUDAAllocator : public Allocator {
         " does not yet support getPoolUseCount. "
         "If you need it, please file an issue describing your use case.");
   }
-  virtual void ensureExistsAndIncrefPool(
+  virtual void createOrIncrefPool(
       c10::DeviceIndex /*device*/,
-      MempoolId_t /*mempool_id*/) {
+      MempoolId_t /*mempool_id*/,
+      CUDAAllocator* allocator = nullptr) {
     TORCH_CHECK(
         false,
         name(),
-        " does not yet support ensureExistsAndIncrefPool. "
+        " does not yet support createOrIncrefPool. "
         "If you need it, please file an issue describing your use case.");
   }
   virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
@@ -364,7 +365,7 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
   return get()->setMemoryFraction(fraction, device);
 }
 
-inline void emptyCache() {
+inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
   return get()->emptyCache();
 }
 
@@ -401,8 +402,8 @@ inline void resetPeakStats(c10::DeviceIndex device) {
   return get()->resetPeakStats(device);
 }
 
-inline SnapshotInfo snapshot() {
-  return get()->snapshot();
+inline SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
+  return get()->snapshot(mempool_id);
 }
 
 inline std::shared_ptr<AllocatorState> getCheckpointState(
@@ -475,10 +476,11 @@ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
 inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
   return get()->releasePool(device, mempool_id);
 }
-inline void ensureExistsAndIncrefPool(
+inline void createOrIncrefPool(
     c10::DeviceIndex device,
-    MempoolId_t mempool_id) {
-  get()->ensureExistsAndIncrefPool(device, mempool_id);
+    MempoolId_t mempool_id,
+    CUDAAllocator* allocator_ptr = nullptr) {
+  get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
 }
 inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
   get()->setUseOnOOM(device, mempool_id);
@@ -555,26 +557,4 @@ struct C10_CUDA_API MemPool {
   c10::DeviceIndex device_;
 };
 
-// MemPoolContext holds the currently active pool and stashes the previous
-// pool. On deletion it makes the previous pool active.
-struct C10_CUDA_API MemPoolContext {
-  MemPoolContext(MemPool* mempool);
-
-  ~MemPoolContext();
-
-  // getActiveMemPool() can be used to get the currently active pool.
-  // For instance: in CUDACachingAllocator, we can route allocations
-  // to a user provided allocator, by doing:
-  //
-  // auto active_pool = MemPoolContext::getActiveMemPool();
-  // if (active_pool && active_pool->allocator()) {
-  //   ptr = active_pool->allocator()->raw_alloc(size);
-  // }
-  //
-  static MemPool* getActiveMemPool();
-
- private:
-  MemPool* prev_mempool_;
-};
-
 } // namespace c10::cuda
@@ -496,7 +496,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
     // introduces performance nondeterminism.
   }
 
-  void emptyCache() override {
+  void emptyCache(/*unused*/ MempoolId_t mempool_id) override {
     std::lock_guard<std::mutex> lk(general_mutex);
 
     for (int dev = 0; dev < device_count; dev++) {
@@ -778,7 +778,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
         cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
   }
 
-  SnapshotInfo snapshot() override {
+  SnapshotInfo snapshot(MempoolId_t mempool_id) override {
     TORCH_CHECK(
         false,
         "Calling snapshot with backend:cudaMallocAsync is not meaningful. "
@@ -2282,7 +2282,6 @@ coverage_ignore_classes = [
     "UnsynchronizedAccessError",
     # torch.cuda.memory
     "MemPool",
-    "MemPoolContext",
     # torch.distributed.elastic.multiprocessing.errors
     "ChildFailedError",
     "ProcessFailure",
@@ -128,7 +128,6 @@ Memory management
     CUDAPluggableAllocator
     change_current_allocator
     MemPool
-    MemPoolContext
 
 .. currentmodule:: torch.cuda.memory
 
@@ -5049,41 +5049,57 @@ class TestMemPool(TestCase):
         # increments the id
         self.assertTrue(abs(pool2[1] - pool1[1]) > 0)
 
-    def test_mempool_with_allocator(self):
-        pool = torch.cuda.MemPool()
-
-        # MemPool doesn't have an allocator by default
-        self.assertEqual(pool.allocator, None)
-
-        from torch.utils.cpp_extension import load_inline
-
-        dummy_allocator_source = """
+    def get_dummy_allocator(self, check_vars):
+        dummy_allocator_source_vars = """
         #include <torch/extension.h>
         #include <ATen/cuda/Exceptions.h>
         #include <cuda_runtime_api.h>
 
         extern "C" {
           C10_EXPORT int called_dummy_alloc = 0;
           C10_EXPORT int called_dummy_free = 0;
 
+          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
+          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
+            called_dummy_alloc = 123;
+            void* ptr;
+            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
+            return ptr;
+          }
+
+          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
+            called_dummy_free = 321;
+            C10_CUDA_CHECK(cudaFree(ptr));
+          }
+        }
+        """
+        dummy_allocator_source_no_vars = """
+        #include <torch/extension.h>
+        #include <ATen/cuda/Exceptions.h>
+        #include <cuda_runtime_api.h>
+
+        extern "C" {
           // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
           C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
-            called_dummy_alloc = 123;
             void* ptr;
             C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
             return ptr;
           }
 
           C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
-            called_dummy_free = 321;
             C10_CUDA_CHECK(cudaFree(ptr));
           }
         }
         """
 
+        from torch.utils.cpp_extension import load_inline
+
         dummy_allocator_libname = "dummy_allocator"
         dummy_allocator = load_inline(
             name=dummy_allocator_libname,
-            cpp_sources=dummy_allocator_source,
+            cpp_sources=dummy_allocator_source_vars
+            if check_vars
+            else dummy_allocator_source_no_vars,
             is_python_module=False,
             keep_intermediates=False,
             verbose=True,
@@ -5094,6 +5110,15 @@ class TestMemPool(TestCase):
             "dummy_alloc",
             "dummy_free",
         )
+        return allocator, dummy_allocator
+
+    def test_mempool_with_allocator(self):
+        pool = torch.cuda.MemPool()
+
+        # MemPool doesn't have an allocator by default
+        self.assertEqual(pool.allocator, None)
+        allocator, dummy_allocator = self.get_dummy_allocator(check_vars=True)
+
         pool = torch.cuda.MemPool(allocator.allocator())
 
         # pool should point to the same allocator as the one passed into it
@@ -5128,6 +5153,8 @@ class TestMemPool(TestCase):
         # out tensor
         self.assertEqual(called_dummy_alloc.value, 123)
 
+        out_non_pool = torch.empty(nelem_1mb, device="cuda")
+
         with torch.cuda.use_mem_pool(pool):
             # pool should have 1 segment since we made a small allocation (1 MB)
             # above and so the CUDACachingAllocator packed it into a 2 MB buffer
@@ -5145,6 +5172,8 @@ class TestMemPool(TestCase):
             # to make a new 2 MB buffer to accomodate out_2
             self.assertEqual(len(pool.snapshot()), 2)
 
+        self.assertEqual(len(pool.snapshot()), 2)
+
         del out_0, out_1, out_2
 
         # pool's destructor calls emptyCache()
@@ -5156,40 +5185,7 @@ class TestMemPool(TestCase):
 
     @serialTest()
     def test_mempool_limited_memory_with_allocator(self):
-        from torch.utils.cpp_extension import load_inline
-
-        dummy_allocator_source = """
-        #include <torch/extension.h>
-        #include <ATen/cuda/Exceptions.h>
-        #include <cuda_runtime_api.h>
-
-        extern "C" {
-          // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
-          C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) {
-            void* ptr;
-            C10_CUDA_CHECK(cudaMallocManaged(&ptr, size));
-            return ptr;
-          }
-
-          C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) {
-            C10_CUDA_CHECK(cudaFree(ptr));
-          }
-        }
-        """
-        dummy_allocator_libname = "dummy_allocator"
-        dummy_allocator = load_inline(
-            name=dummy_allocator_libname,
-            cpp_sources=dummy_allocator_source,
-            is_python_module=False,
-            keep_intermediates=False,
-            verbose=True,
-            with_cuda=True,
-        )
-        allocator = torch.cuda.memory.CUDAPluggableAllocator(
-            dummy_allocator,
-            "dummy_alloc",
-            "dummy_free",
-        )
+        allocator, _ = self.get_dummy_allocator(check_vars=False)
         pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
         pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
 
@@ -5258,38 +5254,13 @@ class TestMemPool(TestCase):
 
         self._teardown_mempool_limited_memory_test()
 
-    def test_mempool_context(self):
-        active_pool = torch.cuda.MemPoolContext.active_pool()
-
-        # there is no active pool if none was made active
-        self.assertEqual(active_pool, None)
-
-        pool = torch.cuda.MemPool()
-        ctx = torch.cuda.MemPoolContext(pool)
-        active_pool = torch.cuda.MemPoolContext.active_pool()
-
-        # pool was made active
-        self.assertEqual(active_pool, pool)
-
-        del ctx
-        active_pool = torch.cuda.MemPoolContext.active_pool()
-
-        # ctx was deleted, so active pool is the previous one
-        self.assertEqual(active_pool, None)
-
     def test_mempool_multithread(self):
         pool_ids = []
-        active_pool_ids = []
 
         def create_mempool_and_make_active():
             pool = torch.cuda.MemPool()
             pool_ids.extend([pool.id])
-
-            ctx = torch.cuda.MemPoolContext(pool)
-            active_pool = torch.cuda.MemPoolContext.active_pool()
-            active_pool_ids.extend([active_pool.id])
-            del ctx
 
         num_threads = 4
         threads = [
             threading.Thread(target=create_mempool_and_make_active)
@@ -5304,14 +5275,12 @@ class TestMemPool(TestCase):
         # mempool id creation is atomic
         self.assertEqual(len(set(pool_ids)), 4)
 
-        # each thread should have different active mempool, since
-        # the pointer to the mempool is thread local
-        self.assertEqual(len(set(active_pool_ids)), 4)
-
     @skipIfRocm(msg="expandable_segments mode is not supported on ROCm")
+    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode")
     def test_mempool_expandable(self):
         torch.cuda.memory._set_allocator_settings("expandable_segments:True")
-        pool = torch.cuda.MemPool()
+        allocator, _ = self.get_dummy_allocator(check_vars=False)
+        pool = torch.cuda.MemPool(allocator.allocator())
 
         # torch.cuda.MemPool doesn't work with expandable segments
         with self.assertRaises(RuntimeError):
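The tests also exercise the use_on_oom flag that pairs with setUseOnOOM in the allocator: pools created with it may be drained as a last resort when a normal allocation would otherwise fail. A rough sketch, again assuming `allocator` is the dummy CUDAPluggableAllocator from get_dummy_allocator:

    import torch

    pool_do_not_use = torch.cuda.MemPool(allocator.allocator())
    pool_use = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
    # Only pool_use is eligible to satisfy allocations once the device would
    # otherwise report cudaErrorMemoryAllocation; pool_do_not_use never is.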
@@ -2027,7 +2027,7 @@ def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
 def _cuda_hostMemoryStats() -> dict[str, Any]: ...
 def _cuda_resetAccumulatedHostMemoryStats() -> None: ...
 def _cuda_resetPeakHostMemoryStats() -> None: ...
-def _cuda_memorySnapshot() -> dict[str, Any]: ...
+def _cuda_memorySnapshot(mempool_id: tuple[_int, _int] | None) -> dict[str, Any]: ...
 def _cuda_record_memory_history_legacy(
     enabled: _bool,
     record_context: _bool,
@@ -2304,11 +2304,6 @@ class _MemPool:
     def allocator(self) -> _cuda_CUDAAllocator | None: ...
     def use_count(self) -> _int: ...
 
-class _MemPoolContext:
-    def __init__(self, pool: _MemPool) -> None: ...
-    @staticmethod
-    def active_pool() -> _MemPool | None: ...
-
 def _cuda_isCurrentStreamCapturing() -> _bool: ...
 def _graph_pool_handle() -> tuple[_int, _int]: ...
 
@@ -184,7 +184,8 @@ void CUDAPluggableAllocator::setMemoryFraction(
   }
 }
 
-void CUDAPluggableAllocator::emptyCache() {
+void CUDAPluggableAllocator::emptyCache(
+    /*unused*/ c10::cuda::MempoolId_t mempool_id) {
   if (reset_fn_) {
     return reset_fn_();
   }
@@ -237,8 +238,8 @@ void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
       "If you need it, please file an issue describing your use case.");
 }
 
-c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
-    snapshot() {
+c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::snapshot(
+    c10::cuda::MempoolId_t mempool_id) {
   TORCH_CHECK(
       false,
       "CUDAPluggableAllocator does not yet support snapshot. "
@@ -114,7 +114,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
   bool initialized() override;
   double getMemoryFraction(c10::DeviceIndex device) override;
   void setMemoryFraction(double fraction, c10::DeviceIndex device) override;
-  void emptyCache() override;
+  void emptyCache(c10::cuda::MempoolId_t mempool_id = {0, 0}) override;
   void enable(bool) override {}
   bool isEnabled() const override {
     return true;
@@ -128,7 +128,8 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
       c10::DeviceIndex device) override;
   void resetAccumulatedStats(c10::DeviceIndex device) override;
   void resetPeakStats(c10::DeviceIndex device) override;
-  c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override;
+  c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot(
+      c10::cuda::MempoolId_t mempool) override;
   void beginAllocateToPool(
       c10::DeviceIndex device,
       c10::cuda::MempoolId_t mempool_id,
@@ -24,8 +24,4 @@ void THCPMemPool_init(PyObject* module) {
       .def_property_readonly("id", &::c10::cuda::MemPool::id)
       .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
       .def("use_count", &::c10::cuda::MemPool::use_count);
-  shared_ptr_class_<::c10::cuda::MemPoolContext>(torch_C_m, "_MemPoolContext")
-      .def(py::init<c10::cuda::MemPool*>())
-      .def_static(
-          "active_pool", &::c10::cuda::MemPoolContext::getActiveMemPool);
 }
@@ -721,8 +721,24 @@ CapturedTraceback* getFromContext(
       "attempting to gather stack context from the wrong StackContext type.");
 }
 
-PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
+PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
+  c10::cuda::MempoolId_t mempool_id = {0, 0};
+  if (arg && arg != Py_None) {
+    TORCH_CHECK(PyTuple_Check(arg), "mempool_id must be a tuple");
+    Py_ssize_t size = PyTuple_Size(arg);
+    TORCH_CHECK(size == 2, "mempool_id must be a tuple of 2 integers");
+
+    auto id1 = THPObjectPtr(PyTuple_GetItem(arg, 0));
+    auto id2 = THPObjectPtr(PyTuple_GetItem(arg, 1));
+    TORCH_CHECK(
+        THPUtils_checkLong(id1) && THPUtils_checkLong(id2),
+        "mempool_id elements must be integers");
+
+    mempool_id = c10::cuda::MempoolId_t(
+        static_cast<int64_t>(THPUtils_unpackLong(id1)),
+        static_cast<int64_t>(THPUtils_unpackLong(id2)));
+  }
+
   using c10::cuda::CUDACachingAllocator::BlockInfo;
   using c10::cuda::CUDACachingAllocator::SegmentInfo;
@@ -802,7 +818,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
     return segmentDict;
   };
 
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(mempool_id);
 
   py::list segments;
 
@@ -2011,7 +2027,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_resetPeakMemoryStats,
      METH_O,
      nullptr},
-    {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_NOARGS, nullptr},
+    {"_cuda_memorySnapshot", THCPModule_memorySnapshot, METH_O, nullptr},
     {"_cuda_attach_out_of_memory_observer",
      THCPModule_attachOutOfMemoryObserver,
      METH_O,
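With the METH_NOARGS to METH_O switch, the private binding takes exactly one positional argument: None for a full snapshot, or a two-integer mempool id tuple. A quick sketch of both call shapes, assuming `pool` is a torch.cuda.MemPool created as above:

    import torch

    full = torch._C._cuda_memorySnapshot(None)       # every pool, as before
    scoped = torch._C._cuda_memorySnapshot(pool.id)  # pool.id is an (int, int) tuple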
@@ -1186,8 +1186,7 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
   // We must ensure we're listening for allocator trace events in order to
   // register future segments allocated in this pool (this call is idempotent).
   attachAllocatorHooks();
-  auto ctx = c10::cuda::MemPoolContext(pool);
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
   for (const auto& segmentInfo : snapshot.segments) {
     TORCH_INTERNAL_ASSERT(
         segmentInfo.device == pool->device(),
@@ -1221,8 +1220,7 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
     auto iter = ncclCommMemPoolMap.find(ncclComm);
     iter->second.erase(pool->id());
   }
-  auto ctx = c10::cuda::MemPoolContext(pool);
-  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
+  auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
   for (const auto& segmentInfo : snapshot.segments) {
     TORCH_INTERNAL_ASSERT(
         segmentInfo.device == pool->device(),
@@ -5572,7 +5570,6 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
   }
 
   // Allocate tensor under this MemPool's context
-  auto ctx = c10::cuda::MemPoolContext(memPool_.get());
   auto tid = std::this_thread::get_id();
   c10::cuda::CUDACachingAllocator::beginAllocateToPool(
       memPool_->device(), memPool_->id(), [=](cudaStream_t) {
@@ -1872,7 +1872,6 @@ __all__ = [
     "memory_summary",
     "memory_usage",
     "MemPool",
-    "MemPoolContext",
     "use_mem_pool",
     "temperature",
     "power_draw",
@@ -60,7 +60,6 @@ __all__ = [
     "CUDAPluggableAllocator",
     "change_current_allocator",
     "MemPool",
-    "MemPoolContext",
     "use_mem_pool",
 ]
 
@@ -73,7 +72,6 @@ if not hasattr(torch._C, "_cuda_CUDAAllocator"):
 if not hasattr(torch._C, "_MemPool"):
     # Define dummy base classes
     torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
-    torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext")
     torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
         "_cuda_beginAllocateToPool"
     )
@@ -92,7 +90,6 @@ from torch._C import ( # noqa: F401
     _cuda_endAllocateToPool,
     _cuda_releasePool,
     _MemPool,
-    _MemPoolContext,
 )
 
 
@@ -617,7 +614,7 @@ def max_memory_cached(device: "Device" = None) -> int:
     return max_memory_reserved(device=device)
 
 
-def memory_snapshot():
+def memory_snapshot(mempool_id=None):
     r"""Return a snapshot of the CUDA memory allocator state across all devices.
 
     Interpreting the output of this function requires familiarity with the
@@ -627,7 +624,7 @@ def memory_snapshot():
     See :ref:`cuda-memory-management` for more details about GPU memory
     management.
     """
-    return torch._C._cuda_memorySnapshot()["segments"]
+    return torch._C._cuda_memorySnapshot(mempool_id)["segments"]
 
 
 def memory_summary(device: "Device" = None, abbreviated: bool = False) -> str:
@@ -998,7 +995,7 @@ def _snapshot(device: "Device" = None):
     Returns:
         The Snapshot dictionary object
     """
-    return _C._cuda_memorySnapshot()
+    return _C._cuda_memorySnapshot(None)
 
 
 def _dump_snapshot(filename="dump_snapshot.pickle"):
@@ -1110,25 +1107,6 @@ def _get_current_allocator() -> _CUDAAllocator:
     return _CUDAAllocator(torch._C._cuda_getAllocator())
 
 
-class MemPoolContext(_MemPoolContext):
-    r"""MemPoolContext holds the currently active pool and stashes the previous
-    pool. On deletion it makes the previous pool active.
-
-    Args:
-        pool(torch.cuda.MemPool): a MemPool object to be made active so that
-            allocations route to this pool.
-
-    """
-
-    def __init__(self, pool: _MemPool):
-        super().__init__(pool)
-
-    @staticmethod
-    def active_pool() -> Optional[_MemPool]:
-        r"""Returns the active MemPool"""
-        return _MemPoolContext.active_pool()
-
-
 class MemPool(_MemPool):
     r"""MemPool represents a pool of memory in a caching allocator. Currently,
     it's just the ID of the pool object maintained in the CUDACachingAllocator.
@@ -1177,11 +1155,7 @@ class MemPool(_MemPool):
         See :ref:`cuda-memory-management` for more details about GPU memory
         management.
         """
-        try:
-            ctx = MemPoolContext(self)
-            snapshot = torch.cuda.memory_snapshot()
-        finally:
-            del ctx
+        snapshot = torch.cuda.memory_snapshot(self.id)
         return snapshot
 
 
@@ -1202,7 +1176,6 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
     (e.g. by calling backward) the allocations in that thread will not
     route to the given pool.
     """
-    ctx = MemPoolContext(pool)
     device_index = (
         torch.cuda.current_device() if device is None else _get_device_index(device)
     )
@@ -1212,4 +1185,3 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
     finally:
         _cuda_endAllocateToPool(device_index, pool.id)
         _cuda_releasePool(device_index, pool.id)
-        del ctx
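At the Python layer the same id threads through memory_snapshot, so MemPool.snapshot() no longer needs to activate a context before reading its own segments. A short equivalence sketch, with `pool` as above:

    import torch

    all_segments = torch.cuda.memory_snapshot()           # mempool_id defaults to None
    pool_segments = torch.cuda.memory_snapshot(pool.id)   # only this pool's segments
    # MemPool.snapshot() is now just the second call made on the pool's own id.
    assert pool_segments == pool.snapshot()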