Compare commits

...

10 Commits

Author SHA1 Message Date
7516fc7aa2 Merge remote-tracking branch 'upstream/main' into majing/mempool-1 2025-11-14 06:22:42 +00:00
f9164a33d7 Merge remote-tracking branch 'upstream/main' into majing/mempool-1 2025-11-13 07:57:33 +00:00
c8322740c1 revert emptyCache API change
Signed-off-by: Ma, Jing1 <jing1.ma@intel.com>
2025-11-13 07:56:23 +00:00
ddf5fc67c7 Fix compile warning
Signed-off-by: Ma, Jing1 <jing1.ma@intel.com>
2025-11-12 08:03:09 +00:00
854f07f123 Fixed bugs in rebase
Signed-off-by: Ma, Jing1 <jing1.ma@intel.com>
2025-11-12 05:14:27 +00:00
bcffd08f20 rebase
Signed-off-by: Ma, Jing1 <jing1.ma@intel.com>
2025-11-12 02:50:09 +00:00
83029ff207 Fix review comments
Signed-off-by: Ma, Jing1 <jing1.ma@intel.com>
2025-11-07 08:20:41 +00:00
02611f69a5 Apply suggestion from @EikanWang
Co-authored-by: Eikan Wang <eikan.wang@intel.com>
2025-11-07 14:28:41 +08:00
67367e2cd4 Fixed format
Signed-off-by: Ma, Jing1 <jing1.ma@intel.com>
2025-11-03 04:10:08 +00:00
3c35b0938f [1/N][XPU] Add MemPool for XPU
Signed-off-by: Ma, Jing1 <jing1.ma@intel.com>
2025-11-03 02:58:48 +00:00
2 changed files with 154 additions and 20 deletions

View File

@ -15,6 +15,8 @@ using namespace c10::CachingDeviceAllocator;
// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
class XPUAllocator;
namespace {
using stream_set = ska::flat_hash_set<xpu::XPUStream>;
@ -23,14 +25,19 @@ typedef bool (*Comparison)(const Block*, const Block*);
bool BlockComparatorSize(const Block* a, const Block* b);
bool BlockComparatorAddress(const Block* a, const Block* b);
struct PrivatePool;
struct BlockPool {
BlockPool(bool small)
BlockPool(bool small, PrivatePool* private_pool = nullptr)
: blocks(BlockComparatorSize),
unmapped(BlockComparatorAddress),
is_small(small) {}
is_small(small),
owner_PrivatePool(private_pool) {}
std::set<Block*, Comparison> blocks;
std::set<Block*, Comparison> unmapped;
const bool is_small;
PrivatePool* owner_PrivatePool;
};
struct ExpandableSegment;
@ -349,6 +356,43 @@ struct AllocParams {
StatTypes stat_types = {};
};
// Internal implementation that manages actual memory blocks.
// The high-level MemPool interface wraps a PrivatePool via its MempoolId_t.
struct PrivatePool {
PrivatePool(MempoolId_t id, XPUAllocator* allocator = nullptr)
: id(std::move(id)),
allocator_(allocator),
large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
PrivatePool(PrivatePool&&) = delete;
PrivatePool& operator=(const PrivatePool&) = delete;
PrivatePool& operator=(PrivatePool&&) = delete;
~PrivatePool() = default;
// Default MemPool when no MemPool is specified
MempoolId_t id{0, 0};
// Number of live graphs using this pool
int use_count{1};
// Number of unfreed allocations made for this pool. When use_count and
// allocation_count drop to zero, we can delete this PrivatePool from
// graph_pools.
int allocation_count{0};
XPUAllocator* allocator_;
BlockPool large_blocks;
BlockPool small_blocks;
public:
XPUAllocator* allocator() {
return allocator_;
}
};
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
} // anonymous namespace
class DeviceCachingAllocator {
@ -365,6 +409,13 @@ class DeviceCachingAllocator {
bool set_fraction = false;
std::vector<ExpandableSegment*> expandable_segments;
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
std::vector<std::pair<MempoolId_t, std::function<bool(sycl::queue*)>>>
captures_underway;
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
graph_pools;
// Pools no longer referenced by any graph.
ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
graph_pools_freeable;
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated || src->event_count > 0 ||
@ -463,7 +514,22 @@ class DeviceCachingAllocator {
}
}
BlockPool& get_pool(size_t size) {
BlockPool& get_pool(size_t size, sycl::queue* queue) {
if (C10_UNLIKELY(!captures_underway.empty())) {
for (auto& entry : captures_underway) {
// Look up the mempool id matching the current capture graph
if (entry.second(queue)) {
auto it1 = graph_pools.find(entry.first);
// Look up the corresponding mempool
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
if (size <= kSmallSize) {
return it1->second->small_blocks;
} else {
return it1->second->large_blocks;
}
}
}
}
if (size < kSmallSize) {
return small_blocks;
} else {
@ -669,6 +735,10 @@ class DeviceCachingAllocator {
if (!ptr) {
return false;
}
if (p.pool->owner_PrivatePool) {
p.pool->owner_PrivatePool->allocation_count++;
}
p.block = new Block(device, p.queue(), size, p.pool, ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(size);
@ -677,11 +747,14 @@ class DeviceCachingAllocator {
return true;
}
void synchronize_and_free_events() {
void synchronize_and_free_events(PrivatePool* pool = nullptr) {
for (auto& xe : xpu_events) {
for (auto& e : xe.second) {
auto event = e.first;
auto* block = e.second;
if (pool && block->pool->owner_PrivatePool != pool) {
continue;
}
event.wait();
block->event_count--;
if (block->event_count == 0) {
@ -785,6 +858,13 @@ class DeviceCachingAllocator {
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].decrease(unmapped.size);
});
if (block->pool->owner_PrivatePool) {
// The freed block belonged to an XPU graph's PrivatePool.
TORCH_INTERNAL_ASSERT(
block->pool->owner_PrivatePool->allocation_count > 0);
block->pool->owner_PrivatePool->allocation_count--;
}
}
void release_blocks(BlockPool& pool) {
@ -812,13 +892,41 @@ class DeviceCachingAllocator {
}
}
bool release_cached_blocks() {
synchronize_and_free_events();
// See Note [Safe to Free Blocks on BlockPool]
c10::xpu::syncStreamsOnDevice(device_index);
bool release_cached_blocks(MempoolId_t mempool_id) {
if (mempool_id.first == 0 && mempool_id.second == 0 &&
captures_underway.empty()) {
synchronize_and_free_events();
// See Note [Safe to Free Blocks on BlockPool]
c10::xpu::syncStreamsOnDevice(device_index);
release_blocks(large_blocks);
release_blocks(small_blocks);
release_blocks(large_blocks);
release_blocks(small_blocks);
}
for (auto it = graph_pools_freeable.begin();
it != graph_pools_freeable.end();) {
if (mempool_id.first != 0 || mempool_id.second != 0) {
if (it->first == mempool_id) {
// If there is an active mempool, we sync only the events
// associated with the pool
synchronize_and_free_events(it->second);
} else {
// otherwise we move on
++it;
continue;
}
}
TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
release_blocks(it->second->small_blocks);
release_blocks(it->second->large_blocks);
if (it->second->allocation_count == 0) {
auto erase_count = graph_pools.erase(it->first);
TORCH_INTERNAL_ASSERT(erase_count == 1);
it = graph_pools_freeable.erase(it);
} else {
++it;
}
}
return true;
}
@ -903,6 +1011,30 @@ class DeviceCachingAllocator {
}
}
void create_or_incref_pool(
MempoolId_t mempool_id,
XPUAllocator* allocator = nullptr) {
auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool.
// Make a new pool for XPU graph capture or memory pool usage.
graph_pools.emplace(
mempool_id, std::make_unique<PrivatePool>(mempool_id, allocator));
} else {
// mempool_id references an existing pool, which the current XPU graph
// capture will share.
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
TORCH_INTERNAL_ASSERT(allocator == nullptr);
it->second->use_count++;
}
}
PrivatePool* get_private_pool(MempoolId_t mempool_id) {
auto it = graph_pools.find(mempool_id);
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
return it->second.get();
}
public:
DeviceCachingAllocator(DeviceIndex device_index)
: large_blocks(/* small */ false),
@ -911,9 +1043,11 @@ class DeviceCachingAllocator {
Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
std::scoped_lock<std::recursive_mutex> lock(mutex);
process_events();
if (C10_LIKELY(captures_underway.empty())) {
process_events();
}
size_t size = round_size(orig_size);
auto& pool = get_pool(size);
auto& pool = get_pool(size, &queue);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, &queue, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool);
@ -923,7 +1057,7 @@ class DeviceCachingAllocator {
// Can't reuse an existing block, try to get a new one.
if (!block_found) {
block_found = alloc_block(params, false) ||
(release_cached_blocks() && alloc_block(params, true));
(release_cached_blocks({0, 0}) && alloc_block(params, true));
}
if (!block_found) {
const auto& raw_device = c10::xpu::get_raw_device(device);
@ -1016,9 +1150,9 @@ class DeviceCachingAllocator {
block->stream_uses.insert(stream);
}
void emptyCache() {
void emptyCache(MempoolId_t mempool_id) {
std::scoped_lock<std::recursive_mutex> lock(mutex);
release_cached_blocks();
release_cached_blocks(mempool_id);
}
DeviceStats getStats() {
@ -1172,9 +1306,9 @@ class XPUAllocator : public DeviceAllocator {
}
}
void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override {
void emptyCache(MempoolId_t mempool_id) override {
for (auto& da : device_allocators) {
da->emptyCache();
da->emptyCache(mempool_id);
}
}
@ -1290,8 +1424,8 @@ void init(DeviceIndex device_count) {
return allocator.init(device_count);
}
void emptyCache() {
return allocator.emptyCache();
void emptyCache(MempoolId_t mempool_id) {
return allocator.emptyCache(mempool_id);
}
void resetPeakStats(DeviceIndex device) {
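
The routing added to get_pool() is the heart of this change: while a capture is underway, an allocation on a queue that matches a registered filter in captures_underway is served from that capture's PrivatePool, and everything else falls through to the default small/large pools. Below is a self-contained toy model of that routing decision (stand-in types only, not the allocator's real API; registering the queue filter is assumed to arrive in a later patch of this series):

#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Toy stand-ins for the real types; this only models the decision in
// DeviceCachingAllocator::get_pool(), it is not PyTorch code.
using MempoolId = std::pair<unsigned long long, unsigned long long>;
using Queue = int; // placeholder for sycl::queue*

struct Router {
  // Mirrors captures_underway: (pool id, predicate that recognizes the
  // queues participating in that capture).
  std::vector<std::pair<MempoolId, std::function<bool(Queue)>>> captures_underway;

  // Returns which pool an allocation on `q` would be served from.
  const char* get_pool(Queue q) const {
    for (const auto& entry : captures_underway) {
      if (entry.second(q)) {
        return "private pool"; // real code: graph_pools[entry.first]
      }
    }
    return "default pool"; // real code: small_blocks / large_blocks
  }
};

int main() {
  Router r;
  Queue capture_queue = 7;
  // Registering the filter is assumed to happen in a capture-begin path
  // that is not part of this diff.
  r.captures_underway.emplace_back(
      MempoolId{1, 0},
      [capture_queue](Queue q) { return q == capture_queue; });
  std::cout << r.get_pool(7) << '\n'; // private pool
  std::cout << r.get_pool(3) << '\n'; // default pool
}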

View File

@ -10,7 +10,7 @@ C10_XPU_API Allocator* get();
C10_XPU_API void init(DeviceIndex device_count);
C10_XPU_API void emptyCache();
C10_XPU_API void emptyCache(MempoolId_t mempool_id = {0, 0});
C10_XPU_API void resetPeakStats(DeviceIndex device);
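
The header change is source-compatible: mempool_id defaults to {0, 0}, so existing emptyCache() call sites keep releasing the default pools, while a caller that holds a MempoolId_t can release only that pool's cached blocks. A minimal usage sketch (the include path and the c10::xpu::XPUCachingAllocator namespace are assumed from the existing header; the pool id {1, 1} is a hypothetical example):

#include <c10/xpu/XPUCachingAllocator.h>

void free_cached_xpu_memory() {
  // Same behavior as the old zero-argument call: mempool_id defaults to
  // {0, 0} and the default small/large pools are released.
  c10::xpu::XPUCachingAllocator::emptyCache();

  // Target a single private pool: only its cached blocks are released.
  c10::xpu::XPUCachingAllocator::emptyCache({1, 1});
}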