Mirror of https://github.com/pytorch/pytorch.git
[PyTorch CCA] Add an API to get expandable segment sizes (#163771)
Summary: This diff adds an API to query the expandable segment size for each stream, so that we can use this info to warm up the segments in advance and avoid any performance penalty from new CUDA memory allocations during steady-state inference. Differential Revision: D76447308 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163771 Approved by: https://github.com/bbus
Committed by: PyTorch MergeBot
Parent: ad7e3c93b1
Commit: c4bbc6433e
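The new query is exposed through the c10::cuda::CUDACachingAllocator namespace (see the inline wrapper in the header hunk below). A minimal sketch of the intended flow, assuming raw_alloc_with_stream/raw_delete are used to replay the recorded sizes and that the caller maps recorded streams onto the corresponding streams of the new process; recordSegmentSizes and prewarmSegments are hypothetical helpers, not part of this change:

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>

namespace cca = c10::cuda::CUDACachingAllocator;

// After a representative inference run, capture how large each stream's
// expandable segments grew on the given device.
std::vector<cca::StreamSegmentSize> recordSegmentSizes(c10::DeviceIndex device) {
  return cca::getExpandableSegmentSizes(device);
}

// Before serving steady-state traffic, replay the recorded sizes so the
// allocator maps the segments up front. Allocating and immediately freeing a
// block of the recorded size on the matching stream is one way to trigger the
// expansion; the freed block stays cached for later requests. A real warm-up
// would also honor is_small_pool, since small allocations come from a
// separate pool.
void prewarmSegments(
    c10::DeviceIndex device,
    const std::vector<cca::StreamSegmentSize>& sizes) {
  c10::cuda::CUDAGuard guard(device);
  for (const auto& s : sizes) {
    void* p = cca::raw_alloc_with_stream(s.total_size, s.stream);
    cca::raw_delete(p);
  }
}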
@@ -90,6 +90,10 @@ public:
    allocator_->setMemoryFraction(fraction, device);
  }

  std::vector<HIPCachingAllocator::StreamSegmentSize> getExpandableSegmentSizes(c10::DeviceIndex device) override {
    return allocator_->getExpandableSegmentSizes(device);
  }

  void enable(bool value) override {
    allocator_->enable(value);
  }

@@ -382,6 +382,7 @@ struct ExpandableSegment {
        peers_(std::move(peers)) {
    cudaDeviceProp prop{};
    C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
    mapped_size_ = 0;
    // we allocate enough address space for 1 1/8 the total memory on the GPU.
    // This allows for some cases where we have to unmap pages earlier in the
    // segment to put them at the end.

@@ -493,6 +494,7 @@ struct ExpandableSegment {
      return SegmentRange{range.ptr, 0};
    }
    unmapHandles(begin, end);
    mapped_size_ -= (end - begin) * segment_size_;
    return rangeFromHandles(begin, end);
  }

@@ -632,6 +634,18 @@ struct ExpandableSegment {
    return max_handles_ * segment_size_;
  }

  cudaStream_t getStream() {
    return *stream_;
  }

  size_t getMappedSize() {
    return mapped_size_;
  }

  size_t getSegmentSize() {
    return segment_size_;
  }

  void addPeer(c10::DeviceIndex device) {
    peers_.push_back(device);
    forEachAllocatedRange(

@@ -666,6 +680,7 @@ struct ExpandableSegment {
          handles_.at(i).value().handle,
          0ULL));
    }
    mapped_size_ += (end - begin) * segment_size_;
    setAccess(device_, begin, end);
    for (auto p : peers_) {
      setAccess(p, begin, end);

@@ -734,6 +749,7 @@ struct ExpandableSegment {
  std::optional<cudaStream_t> stream_;
  CUdeviceptr ptr_{};
  size_t segment_size_;
  size_t mapped_size_;
  size_t max_handles_;
  struct Handle {
    CUmemGenericAllocationHandle handle;

@@ -779,6 +795,17 @@ struct ExpandableSegment {
  size_t size() const {
    return 0;
  }
  cudaStream_t getStream() {
    return nullptr;
  }

  size_t getMappedSize() {
    return 0;
  }

  size_t getSegmentSize() {
    return 0;
  }
  void addPeer(c10::DeviceIndex device) {}
};
#endif

@@ -2011,6 +2038,22 @@ class DeviceCachingAllocator {
    set_fraction = true;
  }

  /** get expandable segment size for all the streams on device **/
  std::vector<StreamSegmentSize> getExpandableSegmentSizes() {
    std::lock_guard<std::recursive_mutex> lock(mutex);
    std::vector<StreamSegmentSize> sizes;
    for (auto& segment : expandable_segments_) {
      if (!segment->getStream()) {
        continue;
      }
      sizes.emplace_back(
          segment->getStream(),
          segment->getSegmentSize() == kSmallBuffer,
          segment->getMappedSize());
    }
    return sizes;
  }

  /** returns cached blocks to the system allocator **/
  void emptyCache(MempoolId_t mempool_id) {
    auto context = maybeGatherContext(RecordContext::ALL);

@@ -3838,6 +3881,16 @@ class NativeCachingAllocator : public CUDAAllocator {
    device_allocator[device]->setMemoryFraction(fraction);
  }

  std::vector<StreamSegmentSize> getExpandableSegmentSizes(
      c10::DeviceIndex device) override {
    TORCH_INTERNAL_ASSERT(
        0 <= device && static_cast<size_t>(device) < device_allocator.size(),
        "Allocator not initialized for device ",
        device,
        ": did you call init?");
    return device_allocator[device]->getExpandableSegmentSizes();
  }

  void recordHistory(
      bool enabled,
      CreateContextFn context_recorder,

@@ -203,6 +203,14 @@ struct ShareableHandle {
  std::string handle;
};

struct StreamSegmentSize {
  StreamSegmentSize(cudaStream_t s, bool small, size_t sz)
      : stream(s), is_small_pool(small), total_size(sz) {}
  cudaStream_t stream;
  bool is_small_pool;
  size_t total_size;
};

class CUDAAllocator : public DeviceAllocator {
 public:
  virtual void* raw_alloc(size_t nbytes) = 0;

@@ -211,6 +219,8 @@ class CUDAAllocator : public DeviceAllocator {
  virtual void init(int device_count) = 0;
  virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
  virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
  virtual std::vector<StreamSegmentSize> getExpandableSegmentSizes(
      c10::DeviceIndex device) = 0;
  virtual void enable(bool value) = 0;
  virtual bool isEnabled() const = 0;
  virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;

@@ -365,6 +375,11 @@ inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
  return get()->setMemoryFraction(fraction, device);
}

inline std::vector<StreamSegmentSize> getExpandableSegmentSizes(
    c10::DeviceIndex device) {
  return get()->getExpandableSegmentSizes(device);
}

inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
  return get()->emptyCache(mempool_id);
}

@@ -495,6 +495,13 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
    // introduces performance nondeterminism.
  }

  std::vector<StreamSegmentSize> getExpandableSegmentSizes(
      c10::DeviceIndex device) override {
    TORCH_CHECK(
        false,
        "CUDAMallocAsyncAllocator does not yet support getExpandableSegmentSizes.");
  }

  void emptyCache(/*unused*/ MempoolId_t mempool_id) override {
    std::lock_guard<std::mutex> lk(general_mutex);

@@ -165,6 +165,13 @@ void CUDAPluggableAllocator::setMemoryFraction(
  }
}

std::vector<c10::cuda::CUDACachingAllocator::StreamSegmentSize>
CUDAPluggableAllocator::getExpandableSegmentSizes(c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support getExpandableSegmentSizes.");
}

void CUDAPluggableAllocator::emptyCache(
    /*unused*/ c10::cuda::MempoolId_t mempool_id) {
  if (reset_fn_) {

@@ -88,6 +88,8 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
  bool initialized() override;
  double getMemoryFraction(c10::DeviceIndex device) override;
  void setMemoryFraction(double fraction, c10::DeviceIndex device) override;
  std::vector<c10::cuda::CUDACachingAllocator::StreamSegmentSize>
  getExpandableSegmentSizes(c10::DeviceIndex device) override;
  void emptyCache(c10::cuda::MempoolId_t mempool_id = {0, 0}) override;
  void enable(bool) override {}
  bool isEnabled() const override {