mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Small expandable segments refactor. (#130889)
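Makes the follow-up PRs that will export/import segment handles easier to write: the cuMemMap/setAccess sequence is factored out into mapAndSetAccess(), and the address-space size is computed by the caller and passed to the ExpandableSegment constructor alongside the segment size. Pull Request resolved: https://github.com/pytorch/pytorch/pull/130889 Approved by: https://github.com/dsjohns2 ghstack dependencies: #130888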
This commit is contained in:
committed by
PyTorch MergeBot
parent
d8fed480ef
commit
4d9f2a6d56
@ -383,12 +383,14 @@ struct ExpandableSegment {
|
||||
ExpandableSegment(
|
||||
c10::DeviceIndex device,
|
||||
cudaStream_t stream,
|
||||
size_t size,
|
||||
size_t address_space_size,
|
||||
size_t segment_size,
|
||||
std::vector<c10::DeviceIndex> peers)
|
||||
: device_(device),
|
||||
stream_(stream),
|
||||
// 2MB for small pool, 20MB for large pool
|
||||
segment_size_(size),
|
||||
segment_size_(segment_size),
|
||||
max_handles_(numSegments(address_space_size)),
|
||||
peers_(std::move(peers)) {
|
||||
cudaDeviceProp prop{};
|
||||
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
|
||||
@ -435,19 +437,7 @@ struct ExpandableSegment {
|
||||
C10_CUDA_DRIVER_CHECK(status);
|
||||
handles_.at(i) = handle;
|
||||
}
|
||||
for (auto i : c10::irange(begin, end)) {
|
||||
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemMap_(
|
||||
ptr_ + i * segment_size_,
|
||||
segment_size_,
|
||||
0,
|
||||
handles_.at(i).value(),
|
||||
0ULL));
|
||||
}
|
||||
|
||||
setAccess(device_, begin, end);
|
||||
for (auto p : peers_) {
|
||||
setAccess(p, begin, end);
|
||||
}
|
||||
mapAndSetAccess(begin, end);
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
@ -496,6 +486,21 @@ struct ExpandableSegment {
|
||||
ptr_ + begin * segment_size_, (end - begin) * segment_size_, &desc, 1));
|
||||
}
|
||||
|
||||
// Maps the handles for segment indices [begin, end) into this object's
// address range and then enables access to the newly mapped range.
//
// begin/end are segment indices, not byte offsets: segment i occupies the
// segment_size_-byte slot starting at ptr_ + i * segment_size_.
// Every index in the range must already have a handle stored in handles_;
// handles_.at(i).value() throws if the index is out of bounds or the
// optional at that index is empty.
void mapAndSetAccess(size_t begin, size_t end) {
|
||||
// Step 1: map each segment's handle at its slot, one cuMemMap_ call per
// segment; each call's result is passed to C10_CUDA_DRIVER_CHECK.
for (auto i : c10::irange(begin, end)) {
|
||||
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemMap_(
|
||||
ptr_ + i * segment_size_,
|
||||
segment_size_,
|
||||
// NOTE(review): per the cuMemMap signature this 0 is the offset argument
// and the trailing 0ULL is the flags argument -- confirm against the
// driver API reference.
0,
|
||||
handles_.at(i).value(),
|
||||
0ULL));
|
||||
}
|
||||
// Step 2: only after all segments in the range are mapped, set access for
// the owning device (device_) ...
setAccess(device_, begin, end);
|
||||
// ... and then for every device listed in peers_.
for (auto p : peers_) {
|
||||
setAccess(p, begin, end);
|
||||
}
|
||||
}
|
||||
|
||||
void unmapHandles(size_t begin, size_t end) {
|
||||
// note: unlike cudaFree, MemUnmap and MemRelease do
|
||||
// not appear to synchronize in all cases, so we have to wait for the
|
||||
@ -548,8 +553,8 @@ struct ExpandableSegment {
|
||||
c10::DeviceIndex device_;
|
||||
cudaStream_t stream_;
|
||||
CUdeviceptr ptr_{};
|
||||
size_t max_handles_{0};
|
||||
size_t segment_size_;
|
||||
size_t max_handles_;
|
||||
std::vector<std::optional<CUmemGenericAllocationHandle>> handles_;
|
||||
// devices on which this memory should be mapped in addition
|
||||
// to the device where the physical memory lives (device_).
|
||||
@ -560,8 +565,9 @@ struct ExpandableSegment {
|
||||
ExpandableSegment(
|
||||
c10::DeviceIndex device,
|
||||
cudaStream_t stream,
|
||||
size_t size,
|
||||
const std::vector<c10::DeviceIndex>& peers) {
|
||||
size_t address_space_size,
|
||||
size_t segment_size,
|
||||
std::vector<c10::DeviceIndex> peers) {
|
||||
TORCH_INTERNAL_ASSERT(false, "expandable segment not supported");
|
||||
}
|
||||
SegmentRange map(SegmentRange range) {
|
||||
@ -2069,8 +2075,19 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
|
||||
cudaDeviceProp prop{};
|
||||
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
|
||||
// we allocate enough address space for 1 1/8 times the total memory on the GPU.
|
||||
// This allows for some cases where we have to unmap pages earlier in the
|
||||
// segment to put them at the end.
|
||||
size_t address_space_size = prop.totalGlobalMem + prop.totalGlobalMem / 8;
|
||||
|
||||
expandable_segments_.emplace_back(new ExpandableSegment(
|
||||
device, stream, segment_size, devices_with_peer_access_));
|
||||
device,
|
||||
stream,
|
||||
address_space_size,
|
||||
segment_size,
|
||||
devices_with_peer_access_));
|
||||
|
||||
ExpandableSegment* es = expandable_segments_.back();
|
||||
Block* candidate = new Block(device, stream, es->size(), pool, es->ptr());
|
||||
|
Reference in New Issue
Block a user