Revert "Initial implementation of host memory stats (#147660)"

This reverts commit 945e359fc1afe6c0bb6129ed9607b237fa19cd98.

Reverted https://github.com/pytorch/pytorch/pull/147660 on behalf of https://github.com/mradmila due to There is an issue with ambiguous definition of Stat structure when different C++ tools are used. Backing out for now. ([comment](https://github.com/pytorch/pytorch/pull/147660#issuecomment-2692346379))
This commit is contained in:
PyTorch MergeBot
2025-03-01 18:05:45 +00:00
parent d23051f29b
commit a983b2b11a
18 changed files with 86 additions and 794 deletions

View File

@@ -10,8 +10,6 @@
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {
using namespace c10::CachingAllocator;
/**
* HostBlock is typically a fundamental memory block used in pinned memory. It
* is likely related to Event and Stream of device runtime. It is probably a
@@ -44,60 +42,6 @@ namespace {
constexpr size_t MAX_SIZE_INDEX = 64;
}
// Struct containing memory allocator summary statistics for host.
struct HostStats {
// COUNT: allocations requested by client code. Note that active
// count can be extracted by looking at current allocations
Stat allocation;
// COUNT: number of allocated segments from host memory allocation.
Stat segment;
// SUM: bytes allocated by this memory allocator. Note that active bytes
// can be extracted by looking at current bytes allocated
Stat allocated_bytes;
// SUM: bytes reserved by this memory allocator (both free and used)
Stat reserved_bytes;
// SUM: time spent in cudaHostAlloc/cudaHostRegister in microseconds
DurationStat host_alloc_time;
// SUM: time spent in cudaFreeHost/cudaHostUnregister in microseconds
DurationStat host_free_time;
// COUNT: number of times cudaHostAlloc/cudaHostRegister was called because
// the request could not be satisfied from existing free blocks.
int64_t num_host_alloc = 0; // This is derived from segment or timing
// COUNT: number of times cudaFreeHost/cudaHostUnregister was called.
int64_t num_host_free = 0; // This is derived from segment or timing
};
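As a rough illustration of how these counters are meant to be consumed, the sketch below reads one snapshot through the at::cuda::CachingHostAllocator_getStats() accessor removed later in this diff (a sketch only; it assumes CUDA is available and the reverted API is present):

#include <iostream>
// #include <ATen/cuda/CachingHostAllocator.h>  // header that declared the accessor in this PR

void log_host_allocator_stats() {
  at::HostStats stats = at::cuda::CachingHostAllocator_getStats();
  std::cout << "pinned allocations (current/peak): "
            << stats.allocation.current << "/" << stats.allocation.peak << "\n"
            << "allocated bytes (current): " << stats.allocated_bytes.current << "\n"
            << "reserved bytes (current): " << stats.reserved_bytes.current << "\n"
            << "cudaHostAlloc/cudaHostRegister calls: " << stats.num_host_alloc << "\n"
            << "total alloc time (us): " << stats.host_alloc_time.total << "\n";
}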
// Struct containing memory allocator summary statistics for host, as they
// are staged for reporting. This intermediate struct is used to avoid
// adding new locks to the allocator while collecting stats.
struct alignas(64) HostStatsStaged {
std::mutex timing_mutex_;
// COUNT: allocations requested by client code resulting in a new segment/block allocation
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
Stat allocation;
// SUM: bytes within active memory blocks, including blocks that are
// currently in the free list.
// LOCK: access to this stat is protected by the allocator's blocks_mutex_
Stat allocated_bytes;
// COUNT: number of allocations per bucket
// LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_
std::vector<Stat> allocation_bucket_stats = std::vector<Stat>(MAX_SIZE_INDEX);
// SUM: bytes of allocation per bucket
// LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_
std::vector<Stat> allocated_bytes_bucket_stats = std::vector<Stat>(MAX_SIZE_INDEX);
// SUM: time spent in cudaHostAlloc/cudaHostRegister
// LOCK: access to this stat is protected by the timing_mutex_
DurationStat host_alloc_time;
// SUM: time spent in cudaFreeHost/cudaHostUnregister
// LOCK: access to this stat is protected by the timing_mutex_
DurationStat host_free_time;
};
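The two *_bucket_stats vectors above are indexed by power-of-two size class, one entry per bucket up to MAX_SIZE_INDEX. Roughly, the bucket index is the log2 of the rounded-up allocation size; the hypothetical helper below sketches the idea (the real size_index lives elsewhere in this file and is not shown in this hunk):

// Illustration only: map an allocation size to its power-of-two bucket index.
size_t size_index_sketch(size_t size) {
  size_t index = 0;
  size_t bucket = 1; // bucket i holds allocations rounded up to 2^i bytes
  while (bucket < size && index < MAX_SIZE_INDEX - 1) {
    bucket <<= 1;
    ++index;
  }
  return index; // e.g. size 100 -> bucket 7 (128 bytes)
}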
/**
* Note [HostAllocator design]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -159,13 +103,6 @@ struct alignas(64) HostStatsStaged {
*
* Note that this caching host allocator does not split larger allocations into
* smaller blocks, unlike the caching device allocator.
*
* In order to gather statistics about the caching host allocator while minimally
* impacting performance, we use a HostStatsStaged struct to stage the stats
* before reporting them. This is done to avoid adding new locks to the allocator.
* Collecting stats is carefully done under existing locks, and then the staged
* stats are converted to the final stats when getStats is called. At that time
* we hold the same locks as empty_cache, to ensure the fidelity of the stats.
*/
template <
@@ -262,8 +199,6 @@ struct CachingHostAllocatorImpl {
auto index = size_index(block->size_);
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
free_list_[index].list_.push_back(block);
stats_.allocation_bucket_stats[index].decrease(1);
stats_.allocated_bytes_bucket_stats[index].decrease(block->size_);
} else {
// restore these events that were recorded by the used streams.
std::lock_guard<std::mutex> g(events_mutex_);
@@ -318,12 +253,9 @@ struct CachingHostAllocatorImpl {
std::vector<B*> blocks_to_remove(free_list_[i].list_.begin(), free_list_[i].list_.end());
free_list_[i].list_.clear();
for (auto* block : blocks_to_remove) {
blocks_.erase(block);
ptr_to_block_.erase(block->ptr_);
stats_.allocation.decrease(1);
stats_.allocated_bytes.decrease(block->size_);
free_block(block);
delete block;
}
@@ -342,125 +274,11 @@ struct CachingHostAllocatorImpl {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data");
}
HostStats getStats() {
HostStats stats;
// To keep getStats lightweight we do *not* flush any available blocks
// into the free_list. This may skew the stats a bit.
auto add_bucket_stats = [](Stat& accumulator, const Stat& other) {
accumulator.allocated += other.allocated;
accumulator.current += other.current;
accumulator.freed += other.freed;
// Since peaks are measured per bucket independently, we add them up
// to estimate the total peak. This is not strictly correct, but it is
// the best approximation we can get after the fact (buckets may peak at
// different times, so the sum can only over-estimate the combined peak).
accumulator.peak += other.peak;
};
// Accurate reading of memory stats requires concurrently holding both the
// free list mutexes and the blocks mutex. Previously, this was only done in
// the empty_cache function.
for (size_t i = 0; i < free_list_.size(); ++i) {
std::lock(free_list_[i].mutex_, blocks_mutex_);
std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
// We collect the slow-path stats only once, since they are not collected
// per bucket (we pick index 0 arbitrarily). These cover all host
// allocations, without taking caching and the free lists into account.
if (i == 0) {
stats.segment = stats_.allocation;
stats.reserved_bytes = stats_.allocated_bytes;
stats.num_host_alloc = stats.segment.allocated;
stats.num_host_free = stats.segment.freed;
}
// Bucket stats need to be merged with the slow-path stats. We do this in
// a best effort manner, since we can't really replay the cached events per bucket.
add_bucket_stats(stats.allocation, stats_.allocation_bucket_stats[i]);
add_bucket_stats(stats.allocated_bytes, stats_.allocated_bytes_bucket_stats[i]);
}
// Get the timing stats
{
std::lock_guard<std::mutex> g(stats_.timing_mutex_);
stats.host_alloc_time = stats_.host_alloc_time;
stats.host_free_time = stats_.host_free_time;
}
return stats;
}
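The std::lock / std::adopt_lock pattern above is what lets getStats (and the reset functions below) hold a free-list mutex and blocks_mutex_ at the same time without risking deadlock. Reduced to its essentials, with stand-in mutexes rather than the allocator's actual members:

#include <mutex>

std::mutex bucket_mutex;
std::mutex blocks_mutex;

void read_both_stat_domains() {
  // Acquire both mutexes atomically, with deadlock avoidance...
  std::lock(bucket_mutex, blocks_mutex);
  // ...then hand ownership to guards so they unlock on scope exit.
  std::lock_guard<std::mutex> g1(bucket_mutex, std::adopt_lock);
  std::lock_guard<std::mutex> g2(blocks_mutex, std::adopt_lock);
  // Stats protected by either mutex can be read consistently here.
}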
void resetAccumulatedStats() {
// Resetting accumulated memory stats requires concurrently holding both the
// free list mutexes and the blocks mutex. Previously, this was only done in
// the empty_cache function.
for (size_t i = 0; i < free_list_.size(); ++i) {
std::lock(free_list_[i].mutex_, blocks_mutex_);
std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
if (i == 0) {
stats_.allocation.reset_accumulated();
stats_.allocated_bytes.reset_accumulated();
}
stats_.allocation_bucket_stats[i].reset_accumulated();
stats_.allocated_bytes_bucket_stats[i].reset_accumulated();
}
// Also reset timing stats
{
std::lock_guard<std::mutex> g(stats_.timing_mutex_);
stats_.host_alloc_time.reset_accumulated();
stats_.host_free_time.reset_accumulated();
}
}
void resetPeakStats() {
// Resetting peak memory stats requires concurrently holding both the
// free list mutexes and the blocks mutex. Previously, this was only done in
// the empty_cache function.
for (size_t i = 0; i < free_list_.size(); ++i) {
std::lock(free_list_[i].mutex_, blocks_mutex_);
std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
if (i == 0) {
stats_.allocation.reset_peak();
stats_.allocated_bytes.reset_peak();
}
stats_.allocation_bucket_stats[i].reset_peak();
stats_.allocated_bytes_bucket_stats[i].reset_peak();
}
// Also reset timing stats
{
std::lock_guard<std::mutex> g(stats_.timing_mutex_);
stats_.host_alloc_time.reset_peak();
stats_.host_free_time.reset_peak();
}
}
private:
virtual void add_allocated_block(B* block) {
std::lock_guard<std::mutex> g(blocks_mutex_);
blocks_.insert(block);
stats_.allocation.increase(1);
stats_.allocated_bytes.increase(block->size_);
ptr_to_block_.insert({block->ptr_, block});
// Unfortunately, on the slow path we have to briefly lock the bucket
// to record the allocation. This should be a rare event once the cache
// is warmed up.
auto size = block->size_;
auto index = size_index(size);
{
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
stats_.allocation_bucket_stats[index].increase(1);
stats_.allocated_bytes_bucket_stats[index].increase(size);
}
}
virtual B* get_free_block(size_t size) {
@@ -470,8 +288,6 @@ struct CachingHostAllocatorImpl {
B* block = free_list_[index].list_.back();
free_list_[index].list_.pop_back();
block->allocated_ = true;
stats_.allocation_bucket_stats[index].increase(1);
stats_.allocated_bytes_bucket_stats[index].increase(size);
return block;
}
return nullptr;
@@ -565,8 +381,6 @@ struct CachingHostAllocatorImpl {
auto index = size_index(block->size_);
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
free_list_[index].list_.push_back(block);
stats_.allocation_bucket_stats[index].decrease(1);
stats_.allocated_bytes_bucket_stats[index].decrease(size);
if (size != -1) {
return;
}
@@ -579,45 +393,42 @@ struct CachingHostAllocatorImpl {
return pool;
}
/* The following functions are runtime-related. */
// Allocate page-locked memory on the host.
virtual void allocate_host_memory(size_t size, void** ptr) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "Not implemented for allocate_host_memory");
}
// Free block and release the pointer contained in block.
virtual void free_block(B* block) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block");
}
// Record an event on stream and store event into events.
virtual void record_stream(std::optional<std::vector<E>>& events, S stream) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream");
}
// Query event if it is completed.
virtual bool query_event(E& event) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
}
alignas(64) std::mutex blocks_mutex_;
ska::flat_hash_set<B*> blocks_; // block list
ska::flat_hash_map<void*, B*> ptr_to_block_;
// We keep free list as a vector of free lists, one for each power of two
// size. This allows us to quickly find a free block of the right size.
// We use deque to store per size free list and guard the list with its own
// mutex.
alignas(64) std::vector<FreeBlockList<B>> free_list_ =
std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);
alignas(64) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
protected:
alignas(64) HostStatsStaged stats_;
};
template <typename T>
struct CachingHostAllocatorInterface : public at::Allocator {
@@ -645,18 +456,6 @@ struct CachingHostAllocatorInterface : public at::Allocator {
impl_->copy_data(dest, src, count);
}
HostStats getStats() {
return impl_->getStats();
}
void resetAccumulatedStats() {
impl_->resetAccumulatedStats();
}
void resetPeakStats() {
impl_->resetPeakStats();
}
std::unique_ptr<T> impl_;
};

View File

@@ -76,10 +76,6 @@ struct CUDACachingHostAllocatorImpl
// any other device, regardless of the current device at the time of
// allocation, since we assume unified addressing. So we grab any existing
// primary context, if available. See pytorch/pytorch#21081.
// This can be a large performance hit if we cross NUMA nodes by allocating
// and pinning memory on one NUMA node and then using it from the other.
// Thankfully, we use one process per GPU, so we don't run into this issue.
at::OptionalDeviceGuard device_guard;
auto primary_ctx_device_index =
c10::cuda::getDeviceIndexWithPrimaryContext();
@@ -88,7 +84,6 @@ struct CUDACachingHostAllocatorImpl
at::Device(at::DeviceType::CUDA, *primary_ctx_device_index));
}
auto start = std::chrono::system_clock::now();
if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
pinned_use_cuda_host_register()) {
allocWithCudaHostRegister(ptr, size);
@@ -96,18 +91,9 @@ struct CUDACachingHostAllocatorImpl
// Use cudaHostAlloc for allocating pinned memory (global lock in driver)
C10_CUDA_CHECK(cudaHostAlloc(ptr, size, cudaHostAllocDefault));
}
auto end = std::chrono::system_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
// Update the statistics on the time spent on cudaHostAlloc/hostRegister
{
std::lock_guard<std::mutex> g(stats_.timing_mutex_);
stats_.host_alloc_time.increase(duration.count());
}
}
void free_block(Block* block) override {
auto start = std::chrono::system_clock::now();
if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
pinned_use_cuda_host_register()) {
void* ptr = block->ptr_;
@@ -117,14 +103,6 @@ struct CUDACachingHostAllocatorImpl
} else {
AT_CUDA_CHECK(cudaFreeHost(block->ptr_));
}
auto end = std::chrono::system_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
// Update the statistics on the time spent on cudaFreeHost/hostUnregister
{
std::lock_guard<std::mutex> g(stats_.timing_mutex_);
stats_.host_free_time.increase(duration.count());
}
}
void record_stream(
@@ -295,16 +273,4 @@ at::Allocator* getCachingHostAllocator() {
return &getCUDACachingHostAllocator();
}
at::HostStats CachingHostAllocator_getStats() {
return getCUDACachingHostAllocator().getStats();
}
void CachingHostAllocator_resetAccumulatedStats() {
return getCUDACachingHostAllocator().resetAccumulatedStats();
}
void CachingHostAllocator_resetPeakStats() {
return getCUDACachingHostAllocator().resetPeakStats();
}
} // namespace at::cuda

View File

@@ -34,9 +34,4 @@ inline TORCH_CUDA_CPP_API at::DataPtr HostAlloc(size_t size) {
return getCachingHostAllocator()->allocate(size);
}
TORCH_CUDA_CPP_API at::HostStats CachingHostAllocator_getStats();
TORCH_CUDA_CPP_API void CachingHostAllocator_resetAccumulatedStats();
TORCH_CUDA_CPP_API void CachingHostAllocator_resetPeakStats();
} // namespace at::cuda

View File

@@ -9,109 +9,6 @@
constexpr int64_t N = 100;
// NOTE: please leave this as the first test to ensure that
// the allocator is not used and stats are zero.
TEST(CachingHostAllocatorTest, check_stats) {
if (!at::cuda::is_available()) {
return;
}
// Clear the stats and ensure they are zero.
size_t round_size = c10::llvm::PowerOf2Ceil(N);
auto stats = at::cuda::CachingHostAllocator_getStats();
ASSERT_EQ(stats.allocation.current, 0);
ASSERT_EQ(stats.allocation.peak, 0);
ASSERT_EQ(stats.allocation.allocated, 0);
ASSERT_EQ(stats.allocation.freed, 0);
void* ptr{nullptr};
void* ctx{nullptr};
{
auto pinned_tensor = at::empty(
{N}, at::TensorOptions().dtype(at::kByte).pinned_memory(true));
ptr = pinned_tensor.data_ptr();
ctx = pinned_tensor.storage().data_ptr().get_context();
auto stats = at::cuda::CachingHostAllocator_getStats();
ASSERT_EQ(stats.allocation.current, 1);
ASSERT_EQ(stats.allocation.peak, 1);
ASSERT_EQ(stats.allocation.allocated, 1);
ASSERT_EQ(stats.allocation.freed, 0);
ASSERT_EQ(stats.segment.allocated, 1);
ASSERT_EQ(stats.segment.freed, 0);
ASSERT_EQ(stats.reserved_bytes.current, round_size);
ASSERT_EQ(stats.allocated_bytes.current, round_size);
ASSERT_EQ(stats.host_alloc_time.max, stats.host_alloc_time.min);
ASSERT_EQ(stats.host_free_time.total, 0);
}
// Ensure we reuse the allocation.
{
auto pinned_tensor = at::empty(
{N}, at::TensorOptions().dtype(at::kByte).pinned_memory(true));
auto stats = at::cuda::CachingHostAllocator_getStats();
ASSERT_EQ(ptr, pinned_tensor.data_ptr());
ASSERT_EQ(ctx, pinned_tensor.storage().data_ptr().get_context());
ASSERT_EQ(stats.allocation.current, 1);
ASSERT_EQ(stats.allocation.peak, 1);
ASSERT_EQ(stats.allocation.allocated, 2);
ASSERT_EQ(stats.allocation.freed, 1);
ASSERT_EQ(stats.segment.allocated, 1);
ASSERT_EQ(stats.segment.freed, 0);
ASSERT_EQ(stats.reserved_bytes.current, round_size);
ASSERT_EQ(stats.allocated_bytes.current, round_size);
}
// Ensure we don't reuse the allocation, due to size mismatch.
{
int64_t new_size = N*2;
size_t new_round_size = c10::llvm::PowerOf2Ceil(new_size);
auto pinned_tensor = at::empty(
{new_size}, at::TensorOptions().dtype(at::kByte).pinned_memory(true));
auto stats = at::cuda::CachingHostAllocator_getStats();
ASSERT_NE(ptr, pinned_tensor.data_ptr());
ASSERT_NE(ctx, pinned_tensor.storage().data_ptr().get_context());
ASSERT_EQ(stats.allocation.current, 1);
ASSERT_EQ(stats.allocation.peak, 2);
ASSERT_EQ(stats.allocation.allocated, 3);
ASSERT_EQ(stats.allocation.freed, 2);
ASSERT_EQ(stats.segment.allocated, 2);
ASSERT_EQ(stats.segment.freed, 0);
ASSERT_EQ(stats.reserved_bytes.current, round_size + new_round_size);
ASSERT_EQ(stats.allocated_bytes.current, new_round_size);
ASSERT_NE(stats.host_alloc_time.total, stats.host_alloc_time.min);
}
// Test the empty cache.
{
at::cuda::CachingHostAllocator_emptyCache();
auto stats = at::cuda::CachingHostAllocator_getStats();
ASSERT_EQ(stats.allocation.current, 0);
ASSERT_EQ(stats.allocated_bytes.current, 0);
ASSERT_EQ(stats.allocation.peak, 2);
ASSERT_EQ(stats.allocation.allocated, 3);
ASSERT_EQ(stats.allocation.freed, 3);
ASSERT_EQ(stats.segment.allocated, 2);
ASSERT_EQ(stats.segment.freed, 2);
ASSERT_EQ(stats.num_host_alloc, 2);
ASSERT_EQ(stats.num_host_free, 2);
ASSERT_NE(stats.host_free_time.total, stats.host_free_time.min);
}
// Test the reset stats.
{
at::cuda::CachingHostAllocator_resetAccumulatedStats();
at::cuda::CachingHostAllocator_resetPeakStats();
auto stats = at::cuda::CachingHostAllocator_getStats();
ASSERT_EQ(stats.allocation.peak, 0);
ASSERT_EQ(stats.allocation.allocated, 0);
ASSERT_EQ(stats.allocation.freed, 0);
ASSERT_EQ(stats.allocated_bytes.peak, 0);
ASSERT_EQ(stats.num_host_alloc, 0);
ASSERT_EQ(stats.num_host_free, 0);
}
// At this point, the allocator should be empty, and stats should be zero,
// leaving the test harness in a clean state for the next test.
}
TEST(CachingHostAllocatorTest, pinned_alias_slice) {
if (!at::cuda::is_available()) {
return;

View File

@@ -1,6 +1,5 @@
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
@@ -14,7 +13,6 @@
#include <c10/util/Exception.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/irange.h>
namespace c10 {
@@ -331,83 +329,4 @@ struct GatheredContext {
virtual ~GatheredContext() = default;
};
namespace CachingAllocator {
struct Stat {
void increase(size_t amount) {
current += static_cast<int64_t>(amount);
peak = std::max(current, peak);
allocated += static_cast<int64_t>(amount);
}
void decrease(size_t amount) {
current -= static_cast<int64_t>(amount);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
current >= 0,
"Negative tracked stat in device allocator (likely logic error).");
freed += static_cast<int64_t>(amount);
}
void reset_accumulated() {
allocated = 0;
freed = 0;
}
void reset_peak() {
peak = current;
}
int64_t current = 0;
int64_t peak = 0;
int64_t allocated = 0;
int64_t freed = 0;
};
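For reference, the intended semantics of these counters (current = live, peak = high-water mark, allocated/freed = lifetime totals) play out as in this small hand-checked sketch:

// Sketch: Stat bookkeeping over a few events.
Stat s;
s.increase(3);          // current=3, peak=3, allocated=3, freed=0
s.decrease(1);          // current=2, peak=3, allocated=3, freed=1
s.reset_peak();         // peak=2 (collapses to the current value)
s.reset_accumulated();  // allocated=0, freed=0; current and peak untouched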
enum struct StatType : uint64_t {
AGGREGATE = 0,
SMALL_POOL = 1,
LARGE_POOL = 2,
NUM_TYPES = 3 // remember to update this whenever a new stat type is added
};
using StatArray = std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)>;
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
template <typename Func>
void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
for (const auto stat_type : c10::irange(stat_types.size())) {
if (stat_types[stat_type]) {
f(stat_type);
}
}
}
// Structure for keeping timing information
struct DurationStat {
void increase(int64_t amount) {
total += amount;
count += 1;
max = std::max(amount, max);
if (min == 0) {
min = amount;
} else {
min = std::min(amount, min);
}
}
void reset_accumulated() {
total = 0;
count = 0;
}
void reset_peak() {
min = 0;
max = 0;
}
int64_t total = 0;
int64_t max = 0;
int64_t min = 0;
int64_t count = 0;
};
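DurationStat is fed by timing the slow host calls, as in the CUDA host allocator changes earlier in this diff. A minimal sketch of that pattern, with a stand-in function rather than the allocator itself:

#include <chrono>

DurationStat host_alloc_time; // assumes the struct defined above

void timed_host_alloc(void** ptr, size_t size) {
  auto start = std::chrono::system_clock::now();
  // ... cudaHostAlloc(ptr, size, cudaHostAllocDefault) would go here ...
  auto end = std::chrono::system_clock::now();
  auto us = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
  host_alloc_time.increase(us.count()); // updates total, count, min and max
}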
} // namespace CachingAllocator
} // namespace c10

View File

@@ -1,10 +1,60 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/util/irange.h>
#include <array>
namespace c10::CachingDeviceAllocator {
using namespace c10::CachingAllocator;
struct Stat {
void increase(size_t amount) {
current += static_cast<int64_t>(amount);
peak = std::max(current, peak);
allocated += static_cast<int64_t>(amount);
}
void decrease(size_t amount) {
current -= static_cast<int64_t>(amount);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
current >= 0,
"Negative tracked stat in device allocator (likely logic error).");
freed += static_cast<int64_t>(amount);
}
void reset_accumulated() {
allocated = 0;
freed = 0;
}
void reset_peak() {
peak = current;
}
int64_t current = 0;
int64_t peak = 0;
int64_t allocated = 0;
int64_t freed = 0;
};
enum struct StatType : uint64_t {
AGGREGATE = 0,
SMALL_POOL = 1,
LARGE_POOL = 2,
NUM_TYPES = 3 // remember to update this whenever a new stat type is added
};
using StatArray = std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)>;
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
template <typename Func>
void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
for (const auto stat_type : c10::irange(stat_types.size())) {
if (stat_types[stat_type]) {
f(stat_type);
}
}
}
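for_each_selected_stat_type is how the device allocator updates only the stat buckets that apply to a given allocation (typically AGGREGATE plus the matching pool). A small usage sketch, assuming the types defined above:

StatTypes selected{};
selected[static_cast<size_t>(StatType::AGGREGATE)] = true;
selected[static_cast<size_t>(StatType::SMALL_POOL)] = true;

StatArray allocation{}; // one Stat per StatType
for_each_selected_stat_type(selected, [&](size_t t) {
  allocation[t].increase(1); // bumps the aggregate and small-pool counters only
});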
// Struct containing memory allocator summary statistics for a device.
struct DeviceStats {

View File

@@ -45,7 +45,6 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
namespace cuda::CUDACachingAllocator {
using namespace c10::CachingAllocator;
using namespace c10::CachingDeviceAllocator;
// Included here as this is externally used in CUDAAllocatorConfig

View File

@@ -11,7 +11,6 @@
namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync {
using namespace c10::CachingAllocator;
using namespace c10::CachingDeviceAllocator;
#if CUDA_VERSION >= 11040

View File

@@ -9,7 +9,6 @@
namespace c10::xpu::XPUCachingAllocator {
using namespace c10::CachingAllocator;
using namespace c10::CachingDeviceAllocator;
// newly allocated memory with 512-byte alignment.

View File

@@ -424,15 +424,11 @@ coverage_ignore_functions = [
"memory_snapshot",
"memory_stats",
"memory_stats_as_nested_dict",
"host_memory_stats",
"host_memory_stats_as_nested_dict",
"memory_summary",
"reset_accumulated_memory_stats",
"reset_accumulated_host_memory_stats",
"reset_max_memory_allocated",
"reset_max_memory_cached",
"reset_peak_memory_stats",
"reset_peak_host_memory_stats",
"set_per_process_memory_fraction",
# torch.cuda.nccl
"all_gather",

View File

@@ -107,7 +107,6 @@ Memory management
list_gpu_processes
mem_get_info
memory_stats
host_memory_stats
memory_summary
memory_snapshot
memory_allocated
@@ -120,7 +119,6 @@ Memory management
max_memory_cached
reset_max_memory_cached
reset_peak_memory_stats
reset_peak_host_memory_stats
caching_allocator_alloc
caching_allocator_delete
get_allocator_backend

View File

@@ -156,146 +156,6 @@ class TestCuda(TestCase):
for thread in threads:
thread.join()
def test_host_memory_stats(self):
# Helper functions
def empty_stats():
return {
"allocated_bytes.allocated": 0,
"allocated_bytes.current": 0,
"allocated_bytes.freed": 0,
"allocated_bytes.peak": 0,
"allocation.allocated": 0,
"allocation.current": 0,
"allocation.freed": 0,
"allocation.peak": 0,
"host_alloc_time.count": 0,
"host_free_time.count": 0,
"num_host_alloc": 0,
"num_host_free": 0,
"reserved_bytes.allocated": 0,
"reserved_bytes.current": 0,
"reserved_bytes.freed": 0,
"reserved_bytes.peak": 0,
"segment.allocated": 0,
"segment.current": 0,
"segment.freed": 0,
"segment.peak": 0,
}
def check_stats(expected):
stats = torch.cuda.host_memory_stats()
for k, v in expected.items():
self.assertEqual(v, stats[k])
# Setup the test cleanly
alloc1 = 10
alloc1_aligned = 16
alloc2 = 20
alloc2_aligned = 32
expected = empty_stats()
# Reset any lingering state
gc.collect()
torch._C._host_emptyCache()
# Check that stats are empty
check_stats(expected)
# Make first allocation and check stats
t1 = torch.ones(alloc1 * 1024, pin_memory=True)
self.assertTrue(t1.is_pinned())
for prefix in ["segment", "allocation"]:
for suffix in ["allocated", "current", "peak"]:
expected[prefix + "." + suffix] += 1
allocation_size1 = alloc1_aligned * 1024 * 4
for prefix in ["allocated_bytes", "reserved_bytes"]:
for suffix in ["allocated", "current", "peak"]:
expected[prefix + "." + suffix] += allocation_size1
expected["num_host_alloc"] += 1
expected["host_alloc_time.count"] += 1
check_stats(expected)
# Remove first allocation and check stats
del t1
expected["allocation.current"] -= 1
expected["allocation.freed"] += 1
expected["allocated_bytes.current"] -= allocation_size1
expected["allocated_bytes.freed"] += allocation_size1
check_stats(expected)
# Make first allocation again and check reuse
t1 = torch.ones(alloc1 * 1024, pin_memory=True)
self.assertTrue(t1.is_pinned())
for suffix in ["allocated", "current"]:
expected["allocation" + "." + suffix] += 1
allocation_size1 = alloc1_aligned * 1024 * 4
for suffix in ["allocated", "current"]:
expected["allocated_bytes" + "." + suffix] += allocation_size1
check_stats(expected)
# Make second allocation and check stats
t2 = torch.ones(alloc2 * 1024, pin_memory=True)
self.assertTrue(t2.is_pinned())
for prefix in ["segment", "allocation"]:
for suffix in ["allocated", "current", "peak"]:
expected[prefix + "." + suffix] += 1
allocation_size2 = alloc2_aligned * 1024 * 4
for prefix in ["allocated_bytes", "reserved_bytes"]:
for suffix in ["allocated", "current", "peak"]:
expected[prefix + "." + suffix] += allocation_size2
expected["num_host_alloc"] += 1
expected["host_alloc_time.count"] += 1
check_stats(expected)
# Remove first allocation and check stats
del t1
expected["allocation.current"] -= 1
expected["allocation.freed"] += 1
expected["allocated_bytes.current"] -= allocation_size1
expected["allocated_bytes.freed"] += allocation_size1
check_stats(expected)
# Remove second allocation and check stats
del t2
expected["allocation.current"] -= 1
expected["allocation.freed"] += 1
expected["allocated_bytes.current"] -= allocation_size2
expected["allocated_bytes.freed"] += allocation_size2
check_stats(expected)
# Empty cache and check stats
torch._C._host_emptyCache()
expected["segment.freed"] += expected["segment.current"]
expected["segment.current"] = 0
expected["reserved_bytes.freed"] += expected["reserved_bytes.current"]
expected["reserved_bytes.current"] = 0
expected["num_host_free"] = expected["num_host_alloc"]
expected["host_free_time.count"] += expected["host_alloc_time.count"]
check_stats(expected)
# Finally, check the reset of peak and accumulated stats
torch.cuda.reset_peak_host_memory_stats()
torch.cuda.reset_accumulated_host_memory_stats()
expected = empty_stats()
check_stats(expected)
def test_pinned_memory_empty_cache(self):
try:
for alloc_settings in (True, False):

View File

@@ -1889,9 +1889,6 @@ def _cuda_emptyCache() -> None: ...
def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_hostMemoryStats() -> Dict[str, Any]: ...
def _cuda_resetAccumulatedHostMemoryStats() -> None: ...
def _cuda_resetPeakHostMemoryStats() -> None: ...
def _cuda_memorySnapshot() -> Dict[str, Any]: ...
def _cuda_record_memory_history_legacy(
enabled: _bool,

View File

@@ -468,7 +468,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._cuda_getDevice",
"torch._C._cuda_getDeviceCount",
"torch._C._cuda_hasPrimaryContext",
"torch._C._cuda_hostMemoryStats",
"torch._C._cuda_init",
"torch._C._cuda_ipc_collect",
"torch._C._cuda_isCurrentStreamCapturing",
@@ -482,9 +481,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._cuda_record_memory_history_legacy",
"torch._C._cuda_record_memory_history",
"torch._C._cuda_releasePool",
"torch._C._cuda_resetAccumulatedHostMemoryStats",
"torch._C._cuda_resetAccumulatedMemoryStats",
"torch._C._cuda_resetPeakHostMemoryStats",
"torch._C._cuda_resetPeakMemoryStats",
"torch._C._cuda_set_cudnn_benchmark_limit",
"torch._C._cuda_set_sync_debug_mode",
@@ -2549,8 +2546,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.cuda.memory.empty_cache",
"torch.cuda.memory.get_allocator_backend",
"torch.cuda.memory.get_per_process_memory_fraction",
"torch.cuda.memory.host_memory_stats_as_nested_dict",
"torch.cuda.memory.host_memory_stats",
"torch.cuda.memory.list_gpu_processes",
"torch.cuda.memory.max_memory_allocated",
"torch.cuda.memory.max_memory_cached",
@@ -2563,11 +2558,9 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.cuda.memory.memory_stats_as_nested_dict",
"torch.cuda.memory.memory_stats",
"torch.cuda.memory.memory_summary",
"torch.cuda.memory.reset_accumulated_host_memory_stats",
"torch.cuda.memory.reset_accumulated_memory_stats",
"torch.cuda.memory.reset_max_memory_allocated",
"torch.cuda.memory.reset_max_memory_cached",
"torch.cuda.memory.reset_peak_host_memory_stats",
"torch.cuda.memory.reset_peak_memory_stats",
"torch.cuda.memory.set_per_process_memory_fraction",
"torch.cuda.nccl._check_sequence_type",

View File

@@ -593,10 +593,10 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) {
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated");
const auto device_index = THPUtils_unpackDeviceIndex(arg);
using c10::CachingAllocator::Stat;
using c10::CachingAllocator::StatArray;
using c10::CachingAllocator::StatType;
using c10::CachingDeviceAllocator::DeviceStats;
using c10::CachingDeviceAllocator::Stat;
using c10::CachingDeviceAllocator::StatArray;
using c10::CachingDeviceAllocator::StatType;
const auto statToDict = [](const Stat& stat) {
py::dict dict;
@@ -667,70 +667,6 @@ PyObject* THCPModule_resetPeakMemoryStats(PyObject* _unused, PyObject* arg) {
Py_RETURN_NONE;
}
PyObject* THCPModule_hostMemoryStats(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
using at::HostStats;
using c10::CachingAllocator::DurationStat;
using c10::CachingAllocator::Stat;
using c10::CachingAllocator::StatArray;
using c10::CachingAllocator::StatType;
const auto statToDict = [](const Stat& stat) {
py::dict dict;
dict["current"] = stat.current;
dict["peak"] = stat.peak;
dict["allocated"] = stat.allocated;
dict["freed"] = stat.freed;
return dict;
};
const auto durationStatToDict = [](const DurationStat& stat) {
py::dict dict;
dict["total"] = stat.total;
dict["max"] = stat.max;
dict["min"] = stat.min;
dict["count"] = stat.count;
dict["avg"] = stat.count == 0 ? 0 : stat.total / stat.count;
return dict;
};
const HostStats stats = at::cuda::CachingHostAllocator_getStats();
py::dict result;
result["num_host_alloc"] = stats.num_host_alloc;
result["num_host_free"] = stats.num_host_free;
result["allocation"] = statToDict(stats.allocation);
result["segment"] = statToDict(stats.segment);
result["allocated_bytes"] = statToDict(stats.allocated_bytes);
result["reserved_bytes"] = statToDict(stats.reserved_bytes);
result["host_alloc_time"] = durationStatToDict(stats.host_alloc_time);
result["host_free_time"] = durationStatToDict(stats.host_free_time);
return result.release().ptr();
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_resetAccumulatedHostMemoryStats(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
at::cuda::CachingHostAllocator_resetAccumulatedStats();
END_HANDLE_TH_ERRORS
Py_RETURN_NONE;
}
PyObject* THCPModule_resetPeakHostMemoryStats(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
at::cuda::CachingHostAllocator_resetPeakStats();
END_HANDLE_TH_ERRORS
Py_RETURN_NONE;
}
CapturedTraceback* getFromContext(
const std::shared_ptr<c10::GatheredContext>& x) {
if (CapturedTraceback* sc = dynamic_cast<CapturedTraceback*>(x.get())) {
@@ -2021,15 +1957,6 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_attachOutOfMemoryObserver,
METH_O,
nullptr},
{"_cuda_hostMemoryStats", THCPModule_hostMemoryStats, METH_NOARGS, nullptr},
{"_cuda_resetAccumulatedHostMemoryStats",
THCPModule_resetAccumulatedHostMemoryStats,
METH_NOARGS,
nullptr},
{"_cuda_resetPeakHostMemoryStats",
THCPModule_resetPeakHostMemoryStats,
METH_NOARGS,
nullptr},
{"_cuda_cudaHostAllocator",
THCPModule_cudaHostAllocator,
METH_NOARGS,

View File

@@ -212,10 +212,10 @@ PyObject* THXPModule_memoryStats(PyObject* self, PyObject* arg) {
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_stats");
const auto device_index = THPUtils_unpackDeviceIndex(arg);
using c10::CachingAllocator::Stat;
using c10::CachingAllocator::StatArray;
using c10::CachingAllocator::StatType;
using c10::CachingDeviceAllocator::DeviceStats;
using c10::CachingDeviceAllocator::Stat;
using c10::CachingDeviceAllocator::StatArray;
using c10::CachingDeviceAllocator::StatType;
const auto statToDict = [](const Stat& stat) {
py::dict dict;

View File

@@ -1760,8 +1760,6 @@ __all__ = [
"graphs",
"has_half",
"has_magma",
"host_memory_stats",
"host_memory_stats_as_nested_dict",
"init",
"initial_seed",
"ipc_collect",
@@ -1798,11 +1796,9 @@ __all__ = [
"nvtx",
"profiler",
"random",
"reset_accumulated_host_memory_stats",
"reset_accumulated_memory_stats",
"reset_max_memory_allocated",
"reset_max_memory_cached",
"reset_peak_host_memory_stats",
"reset_peak_memory_stats",
"seed",
"seed_all",

View File

@@ -39,10 +39,6 @@ __all__ = [
"reset_peak_memory_stats",
"reset_max_memory_allocated",
"reset_max_memory_cached",
"host_memory_stats",
"host_memory_stats_as_nested_dict",
"reset_accumulated_host_memory_stats",
"reset_peak_host_memory_stats",
"memory_allocated",
"max_memory_allocated",
"memory_reserved",
@@ -374,100 +370,6 @@ def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
return torch._C._cuda_resetPeakMemoryStats(device)
def host_memory_stats() -> dict[str, Any]:
r"""Return a dictionary of CUDA memory allocator statistics for a given device.
The return value of this function is a dictionary of statistics, each of
which is a non-negative integer.
Core statistics:
- ``"allocated.{current,peak,allocated,freed}"``:
number of allocation requests received by the memory allocator.
- ``"allocated_bytes.{current,peak,allocated,freed}"``:
amount of allocated memory.
- ``"segment.{current,peak,allocated,freed}"``:
number of reserved segments from ``cudaHostAlloc()`` or ``cudaHostRegister()``.
- ``"reserved_bytes.{current,peak,allocated,freed}"``:
amount of reserved memory.
For these core statistics, values are broken down as follows.
Metric type:
- ``current``: current value of this metric.
- ``peak``: maximum value of this metric.
- ``allocated``: historical total increase in this metric.
- ``freed``: historical total decrease in this metric.
In addition to the core statistics, we also provide some simple event
counters:
- ``"num_host_alloc"``: number of CUDA allocation calls. This includes both
cudaHostAlloc and cudaHostRegister.
- ``"num_host_free"``: number of CUDA free calls. This includes both cudaHostFree
and cudaHostUnregister.
Finally, we also provide some simple timing counters:
- ``"host_alloc_time.{total,max,min,count,avg}"``:
timing of allocation requests going through CUDA calls.
- ``"host_free_time.{total,max,min,count,avg}"``:
timing of free requests going through CUDA calls.
For these timing statistics, values are broken down as follows.
Metric type:
- ``total``: total time spent.
- ``max``: maximum value per call.
- ``min``: minimum value per call.
- ``count``: number of times it was called.
- ``avg``: average time per call.
"""
result = []
def _recurse_add_to_result(prefix, obj):
if isinstance(obj, dict):
if len(prefix) > 0:
prefix += "."
for k, v in obj.items():
_recurse_add_to_result(prefix + k, v)
else:
result.append((prefix, obj))
stats = host_memory_stats_as_nested_dict()
_recurse_add_to_result("", stats)
result.sort()
return collections.OrderedDict(result)
def host_memory_stats_as_nested_dict() -> dict[str, Any]:
r"""Return the result of :func:`~torch.cuda.host_memory_stats` as a nested dictionary."""
if not is_initialized():
return {}
return torch._C._cuda_hostMemoryStats()
def reset_accumulated_host_memory_stats() -> None:
r"""Reset the "accumulated" (historical) stats tracked by the host memory allocator.
See :func:`~torch.cuda.host_memory_stats` for details. Accumulated stats correspond to
the `"allocated"` and `"freed"` keys in each individual stat dict.
"""
return torch._C._cuda_resetAccumulatedHostMemoryStats()
def reset_peak_host_memory_stats() -> None:
r"""Reset the "peak" stats tracked by the host memory allocator.
See :func:`~torch.cuda.host_memory_stats` for details. Peak stats correspond to the
`"peak"` key in each individual stat dict.
"""
return torch._C._cuda_resetPeakHostMemoryStats()
def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.