Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:07:10 +08:00)
Add DeviceAllocator as the base device allocator (#138222)
# Motivation

In line with the RFC [A device-agnostic Python device memory related API design for stream-based accelerators](https://github.com/pytorch/pytorch/issues/134978), several memory-related APIs are widely used in popular repositories; HuggingFace Accelerate, for example, carries [many if/else conditional code paths](https://github.com/search?q=repo%3Ahuggingface%2Faccelerate%20torch.cuda.empty_cache&type=code) just to call the right backend's variant. We would like to introduce a generic API set under the `torch.accelerator` namespace to generalize these use cases.

| Device-specific memory APIs (`torch.xxx.foo`) | Device-agnostic memory APIs (`torch.accelerator.foo`) |
| --- | --- |
| `torch.xxx.empty_cache` | `torch.accelerator.empty_cache` |
| `torch.xxx.reset_peak_memory_stats` | `torch.accelerator.reset_peak_memory_stats` |
| `torch.xxx.reset_accumulated_memory_stats` | `torch.accelerator.reset_accumulated_memory_stats` |
| `torch.xxx.memory_stats` | `torch.accelerator.memory_stats` |
| `torch.xxx.memory_allocated` | `torch.accelerator.memory_allocated` |
| `torch.xxx.max_memory_allocated` | `torch.accelerator.max_memory_allocated` |
| `torch.xxx.memory_reserved` | `torch.accelerator.memory_reserved` |
| `torch.xxx.max_memory_reserved` | `torch.accelerator.max_memory_reserved` |

# Solution

This design follows a pattern similar to `HostAllocator`. We introduce a base class `DeviceAllocator`, from which `CUDAAllocator` and `XPUAllocator` inherit. This allows us to provide a unified call path such as `torch.accelerator.empty_cache()` -> `GetDeviceAllocator(allocator)->empty_cache()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138222
Approved by: https://github.com/albanD, https://github.com/Camyll
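To make the unified call path concrete, here is a minimal standalone sketch of the pattern, not the actual c10 implementation: an abstract `DeviceAllocator` base class, a per-device-type registry that backends register into, and a device-agnostic entry point standing in for `torch.accelerator.empty_cache()` that dispatches through `GetDeviceAllocator`. The names `DeviceAllocator`, `DeviceStats`, `emptyCache`, and `GetDeviceAllocator` mirror the commit message and diff; `ToyXPUAllocator`, `allocator_registry`, and `accelerator_empty_cache` are hypothetical names used only for illustration.

```cpp
// Sketch of a device-agnostic allocator dispatch (simplified, not c10).
#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>

enum class DeviceType : uint8_t { CUDA = 0, XPU = 1, COUNT = 2 };

struct DeviceStats {
  int64_t allocated_bytes = 0;
  int64_t reserved_bytes = 0;
};

// Base class every stream-based backend allocator derives from.
class DeviceAllocator {
 public:
  virtual ~DeviceAllocator() = default;
  virtual bool initialized() = 0;
  virtual void emptyCache() = 0;
  virtual DeviceStats getDeviceStats(int device) = 0;
  virtual void resetPeakStats(int device) = 0;
};

// One slot per device type; backends register themselves at startup.
std::array<DeviceAllocator*, static_cast<std::size_t>(DeviceType::COUNT)>&
allocator_registry() {
  static std::array<DeviceAllocator*, static_cast<std::size_t>(DeviceType::COUNT)>
      registry{};
  return registry;
}

DeviceAllocator* GetDeviceAllocator(DeviceType t) {
  return allocator_registry()[static_cast<std::size_t>(t)];
}

// Toy backend standing in for XPUAllocator.
class ToyXPUAllocator final : public DeviceAllocator {
 public:
  bool initialized() override { return true; }
  void emptyCache() override { std::cout << "XPU: cached blocks released\n"; }
  DeviceStats getDeviceStats(int /*device*/) override { return {}; }
  void resetPeakStats(int /*device*/) override {}
};

// Device-agnostic entry point, analogous to torch.accelerator.empty_cache().
void accelerator_empty_cache(DeviceType current) {
  if (DeviceAllocator* alloc = GetDeviceAllocator(current);
      alloc != nullptr && alloc->initialized()) {
    alloc->emptyCache();  // virtual dispatch into the backend allocator
  }
}

int main() {
  static ToyXPUAllocator xpu_allocator;
  allocator_registry()[static_cast<std::size_t>(DeviceType::XPU)] = &xpu_allocator;
  accelerator_empty_cache(DeviceType::XPU);  // prints "XPU: cached blocks released"
}
```

The key point of the design is that device-agnostic code only ever touches the virtual base interface; each backend keeps its own caching logic behind the overrides.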
Commit 1179e33323 (parent f6d138807f), committed by PyTorch MergeBot.
```diff
@@ -540,7 +540,7 @@ class DeviceCachingAllocator {
 
 static void local_raw_delete(void* ptr);
 
-class XPUAllocator : public Allocator {
+class XPUAllocator : public DeviceAllocator {
  private:
   std::mutex mutex;
   ska::flat_hash_map<void*, Block*> allocated_blocks;
@@ -576,6 +576,10 @@ class XPUAllocator : public Allocator {
     }
   }
 
+  bool initialized() override {
+    return !device_allocators.empty();
+  }
+
   void malloc(
       void** devPtr,
       DeviceIndex device,
@@ -610,13 +614,13 @@ class XPUAllocator : public Allocator {
     }
   }
 
-  void emptyCache() {
+  void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override {
     for (auto& da : device_allocators) {
       da->emptyCache();
     }
   }
 
-  void recordStream(const DataPtr& ptr, XPUStream stream) {
+  void recordStream(const DataPtr& ptr, c10::Stream stream) override {
     if (!ptr.get()) {
       return;
     }
@@ -626,7 +630,8 @@ class XPUAllocator : public Allocator {
 
     Block* block = get_allocated_block(ptr.get());
     TORCH_CHECK(block, "No allocated block can be found.");
-    device_allocators[block->device]->recordStream(block, stream);
+    c10::xpu::XPUStream xpu_stream{stream};
+    device_allocators[block->device]->recordStream(block, xpu_stream);
   }
 
   DataPtr allocate(size_t size) override {
@@ -679,17 +684,17 @@ class XPUAllocator : public Allocator {
         ": did you call init?");
   }
 
-  DeviceStats getDeviceStats(DeviceIndex device) {
+  DeviceStats getDeviceStats(DeviceIndex device) override {
     assertValidDevice(device);
     return device_allocators[device]->getStats();
   }
 
-  void resetPeakStats(DeviceIndex device) {
+  void resetPeakStats(DeviceIndex device) override {
     assertValidDevice(device);
     device_allocators[device]->resetPeakStats();
   }
 
-  void resetAccumulatedStats(DeviceIndex device) {
+  void resetAccumulatedStats(DeviceIndex device) override {
     assertValidDevice(device);
     device_allocators[device]->resetAccumulatedStats();
   }
```
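One detail worth calling out from the hunks above: `recordStream` now accepts a generic `c10::Stream` instead of an `XPUStream`, and the override converts it back to the backend stream type internally, so device-agnostic callers never need to name the backend. Below is a self-contained toy sketch of that shape; `GenericStream`, `ToyXPUStream`, `DeviceAllocatorBase`, and `ToyXPUAllocator` are invented stand-ins, not the real c10 types.

```cpp
// Toy sketch of a generic-stream override (simplified, not the c10 API).
#include <cstdint>
#include <iostream>

// Stand-in for c10::Stream: a backend-agnostic stream handle.
struct GenericStream {
  int64_t id = 0;
};

// Stand-in for c10::xpu::XPUStream, constructible from the generic handle.
struct ToyXPUStream {
  explicit ToyXPUStream(GenericStream s) : id(s.id) {}
  int64_t id;
};

class DeviceAllocatorBase {
 public:
  virtual ~DeviceAllocatorBase() = default;
  // Generic signature: callers pass a backend-agnostic stream handle.
  virtual void recordStream(void* ptr, GenericStream stream) = 0;
};

class ToyXPUAllocator final : public DeviceAllocatorBase {
 public:
  void recordStream(void* ptr, GenericStream stream) override {
    // Mirror of the diff: unwrap the generic stream into the backend's
    // stream type before using it internally.
    ToyXPUStream xpu_stream{stream};
    std::cout << "recorded " << ptr << " on XPU stream " << xpu_stream.id << '\n';
  }
};

int main() {
  ToyXPUAllocator allocator;
  int buffer = 0;
  allocator.recordStream(&buffer, GenericStream{42});
}
```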