Compare commits

...

11 Commits

SHA1        Message                                   Date
f8699ce082  Update [ghstack-poisoned]                 2025-10-31 17:57:24 +00:00
7c6ae5c6dc  Update [ghstack-poisoned]                 2025-10-31 14:00:01 +00:00
fcd2c20847  Update [ghstack-poisoned]                 2025-10-31 09:41:34 +00:00
a648209e23  Update (base update) [ghstack-poisoned]   2025-10-31 09:41:34 +00:00
7dde3f269a  Update [ghstack-poisoned]                 2025-10-31 01:13:24 +00:00
5f978247ff  Update [ghstack-poisoned]                 2025-10-31 00:45:31 +00:00
01f34b3736  Update [ghstack-poisoned]                 2025-10-30 02:06:57 +00:00
c70e057b4c  Update (base update) [ghstack-poisoned]   2025-10-29 12:49:28 +00:00
01e522cdd7  Update [ghstack-poisoned]                 2025-10-29 12:49:28 +00:00
1f05519131  Update (base update) [ghstack-poisoned]   2025-10-27 19:19:09 +00:00
831a02a237  Update [ghstack-poisoned]                 2025-10-27 19:19:09 +00:00


@@ -80,6 +80,212 @@ bool BlockComparatorSize(const Block* a, const Block* b) {
      reinterpret_cast<uintptr_t>(b->ptr);
}

// Represents a contiguous virtual memory segment mapped for allocation.
struct SegmentRange {
  SegmentRange(void* addr, size_t bytes)
      : ptr(static_cast<char*>(addr)), size(bytes) {}
  char* ptr; // Starting address of the mapped range.
  size_t size; // Size in bytes of the mapped range.
};

struct ExpandableSegment {
  ExpandableSegment(
      c10::DeviceIndex device,
      std::optional<sycl::queue*> queue,
      size_t segment_size,
      std::vector<c10::DeviceIndex> peers)
      : device_(device),
        queue_(queue),
        // 2MB for small pool, 20MB for large pool
        segment_size_(segment_size),
        peers_(std::move(peers)) {
    const auto device_total =
        c10::xpu::get_raw_device(device)
            .get_info<sycl::info::device::global_mem_size>();
    // The extra 1/8 allows flexibility for remapping or moving pages within
    // the segment when unmapping earlier regions.
    constexpr float kVirtualMemOversubscriptFactor = 1.125f; // 1 + 1/8
    max_handles_ = numSegments(device_total * kVirtualMemOversubscriptFactor);
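    // For example (illustrative numbers only): on a hypothetical 64 GiB
    // device with 20 MiB segments, max_handles_ = ceil(64 GiB * 1.125 /
    // 20 MiB) = 3687, so roughly 72 GiB of virtual address space is
    // reserved below.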
    ptr_ = sycl::ext::oneapi::experimental::reserve_virtual_mem(
        segment_size_ * max_handles_, xpu::get_device_context());
  }

  C10_DISABLE_COPY_AND_ASSIGN(ExpandableSegment);
  ExpandableSegment(ExpandableSegment&&) = delete;
  ExpandableSegment& operator=(ExpandableSegment&&) = delete;

  // Maps a virtual memory range to physical memory.
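  // The start of `range` must be segment-aligned (asserted below); the end
  // is rounded up to a full segment. E.g., with 2 MiB segments, mapping
  // [ptr(), ptr() + 3 MiB) backs segments 0 and 1 and returns a 4 MiB range.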
  SegmentRange map(SegmentRange range) {
    auto begin = segmentLeft(range.ptr);
    auto end = segmentRight(range.ptr + range.size);
    TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
    if (begin == end) {
      return rangeFromHandles(begin, end);
    }
    // Ensure the handles_ vector is large enough to hold all segments.
    if (end > handles_.size()) {
      handles_.resize(end, std::nullopt);
    }
    // Allocate and map physical memory for each segment.
    for (const auto i : c10::irange(begin, end)) {
      TORCH_INTERNAL_ASSERT(!handles_.at(i));
      try {
        // Allocate physical memory for the segment. Construct the
        // physical_mem in-place to avoid copies.
        handles_.at(i).emplace(
            xpu::get_raw_device(device_),
            xpu::get_device_context(),
            segment_size_);
        // Map the allocated physical memory into the virtual address space.
        handles_.at(i).value().map(
            ptr_ + i * segment_size_,
            segment_size_,
            sycl::ext::oneapi::experimental::address_access_mode::read_write);
      } catch (const sycl::exception& e) {
        // Allocation failure: typically sycl::errc::memory_allocation.
        // Mapping failure: typically sycl::errc::runtime (e.g., OOM due to
        // over-subscription).
        // Note: constructing a physical_mem may over-subscribe device memory
        // without immediately triggering OOM; the actual OOM can surface
        // later, during map().
        // Roll back all segments allocated or mapped in this operation.
        handles_.at(i) = std::nullopt;
        for (const auto j : c10::irange(begin, i)) {
          sycl::ext::oneapi::experimental::unmap(
              reinterpret_cast<void*>(ptr_ + segment_size_ * j),
              segment_size_,
              xpu::get_device_context());
          handles_.at(j) = std::nullopt;
        }
        trimHandles();
        return rangeFromHandles(begin, begin);
      }
    }
    return rangeFromHandles(begin, end);
  }

  // Unmaps the physical memory backing a virtual memory range.
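  // Rounding is inward: only segments entirely covered by `range` are
  // released. E.g., with 2 MiB segments, unmapping [ptr() + 1 MiB,
  // ptr() + 7 MiB) releases segments 1 and 2 and returns the 4 MiB range
  // [ptr() + 2 MiB, ptr() + 6 MiB).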
  SegmentRange unmap(SegmentRange range) {
    auto begin = segmentRight(range.ptr);
    auto end = segmentLeft(range.ptr + range.size);
    if (begin >= end) {
      return SegmentRange{range.ptr, 0};
    }
    unmapHandles(begin, end);
    return rangeFromHandles(begin, end);
  }

  // Returns the base pointer of the virtual memory segment.
  char* ptr() const {
    // NOLINTNEXTLINE(performance-no-int-to-ptr)
    return reinterpret_cast<char*>(ptr_);
  }

  // Returns the total size of the virtual memory segment.
  size_t size() const {
    return max_handles_ * segment_size_;
  }

  ~ExpandableSegment() {
    forEachAllocatedRange(
        [&](size_t begin, size_t end) { unmapHandles(begin, end); });
    sycl::ext::oneapi::experimental::free_virtual_mem(
        ptr_, segment_size_ * max_handles_, xpu::get_device_context());
  }

 private:
  // Unmaps the physical memory handles in the range [begin, end) from the
  // segment.
  void unmapHandles(size_t begin, size_t end) {
    // Currently, we don't support IPC shared memory with expandable segments.
    TORCH_INTERNAL_ASSERT(queue_);
    // As explained in Note [Safe to Free Blocks on BlockPool], additional
    // synchronization is unnecessary here because the memory is already safe
    // to release.
    for (const auto i : c10::irange(begin, end)) {
      // Note: physical_mem's destructor does NOT automatically unmap any
      // mapped ranges. Users must explicitly call unmap on all ranges before
      // destroying the physical_mem object.
      sycl::ext::oneapi::experimental::unmap(
          reinterpret_cast<void*>(ptr_ + segment_size_ * i),
          segment_size_,
          xpu::get_device_context());
      // Resetting the optional destroys the physical_mem object, releasing
      // the physical allocation.
      handles_.at(i) = std::nullopt;
    }
    trimHandles();
  }

  // Removes trailing unused handles from handles_.
  void trimHandles() {
    while (!handles_.empty() && !handles_.back()) {
      handles_.pop_back();
    }
  }

  // Iterates over all contiguous ranges of allocated segments in `handles_`,
  // invoking `fn(start, end)` for each range. Each range is a half-open
  // interval [start, end).
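  // For example, if handles_ holds {mapped, mapped, empty, mapped}, fn is
  // called as fn(0, 2) and fn(3, 4).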
  void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
    size_t start = 0;
    for (const auto i : c10::irange(handles_.size())) {
      if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
        start = i;
      }
      if (handles_.at(i) &&
          (i + 1 == handles_.size() || !handles_.at(i + 1))) {
        fn(start, i + 1);
      }
    }
  }

  // Returns the number of full segments required to cover `size` bytes,
  // rounding up so that partial segments are counted.
  size_t numSegments(size_t size) const {
    return (size + segment_size_ - 1) / segment_size_;
  }

  // Returns the index of the segment that contains the pointer `p`,
  // relative to the base pointer `ptr_`. This is the *inclusive* lower bound
  // of the segment that includes `p`.
  size_t segmentLeft(char* p) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
    size_t offset = p - ptr();
    return offset / segment_size_;
  }

  // Returns the index of the segment just *past* the one containing pointer
  // `p`, relative to the base pointer `ptr_`. This is the *exclusive* upper
  // bound, useful for [begin, end) style ranges.
  // If `p` lies exactly on a segment boundary, this equals segmentLeft(p);
  // otherwise it rounds up and returns segmentLeft(p) + 1.
  size_t segmentRight(char* p) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
    size_t offset = p - ptr();
    return numSegments(offset);
  }
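  // For example, with 2 MiB segments: for p = ptr() + 3 MiB, segmentLeft(p)
  // == 1 and segmentRight(p) == 2; for p = ptr() + 4 MiB (a segment
  // boundary), both return 2.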

  // Constructs a SegmentRange spanning segment indices [begin, end).
  SegmentRange rangeFromHandles(size_t begin, size_t end) {
    return SegmentRange(
        ptr() + segment_size_ * begin, segment_size_ * (end - begin));
  }

  c10::DeviceIndex device_{-1};
  std::optional<sycl::queue*> queue_;
  // Base address of the reserved virtual memory region.
  uintptr_t ptr_{0};
  // Size of each segment in bytes.
  size_t segment_size_{0};
  // Maximum number of physical memory handles (i.e., segments) this
  // ExpandableSegment can hold.
  size_t max_handles_{0};
  // Physical memory handles backing the segments.
  std::vector<std::optional<sycl::ext::oneapi::experimental::physical_mem>>
      handles_{};
  // Peer devices on which this memory may be accessible; reserved for future
  // use.
  std::vector<c10::DeviceIndex> peers_{};
};
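
// Illustrative usage sketch (not part of this change; the `seg`, `device`,
// and `queue` objects are hypothetical). It walks through the rounding
// behavior of map() and unmap() with a 2 MiB segment granularity:
//
//   ExpandableSegment seg(
//       device, &queue, /*segment_size=*/2 * 1024 * 1024, /*peers=*/{});
//   // map() requires a segment-aligned start and rounds the end up, so a
//   // 3 MiB request backs two full segments (4 MiB).
//   SegmentRange mapped = seg.map(SegmentRange(seg.ptr(), 3 * 1024 * 1024));
//   // mapped.size == 4 MiB (segments 0 and 1).
//   // unmap() rounds inward and releases only fully covered segments, so a
//   // 2 MiB request at the base releases segment 0 only.
//   SegmentRange freed = seg.unmap(SegmentRange(seg.ptr(), 2 * 1024 * 1024));
//   // freed.size == 2 MiB; segment 1 remains mapped.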

struct AllocParams {
  AllocParams(
      DeviceIndex device,