Add DeviceAllocator as the base device allocator (#138222)

# Motivation
In line with [RFC] [A device-agnostic Python device memory related API design for stream-based accelerators](https://github.com/pytorch/pytorch/issues/134978), some memory-related APIs are widely used in popular repositories, such as HuggingFace [so many if-else conditional code](https://github.com/search?q=repo%3Ahuggingface%2Faccelerate%20torch.cuda.empty_cache&type=code). We would like to introduce a generic API set under torch.accelerator namespace to generalize these user cases.

<div align="center">
<table>
<tr>
<td> Device-specific memory APIs torch.xxx.foo</td> <td> Device-agnostic memory APIs torch.accelerator.foo</td>
</tr>
<tr>
<td>

```python
torch.xxx.empty_cache
```

</td>
<td>

```python
torch.accelerator.empty_cache
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.reset_peak_memory_stats
```

</td>
<td>

```python
torch.accelerator.reset_peak_memory_stats
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.reset_accumulated_memory_stats
```

</td>
<td>

```python
torch.accelerator.reset_accumulated_memory_stats
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.memory_stats
```

</td>
<td>

```python
torch.accelerator.memory_stats
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.memory_allocated
```

</td>
<td>

```python
torch.accelerator.memory_allocated
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.max_memory_allocated
```

</td>
<td>

```python
torch.accelerator.max_memory_allocated
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.memory_reserved
```

</td>
<td>

```python
torch.accelerator.memory_reserved
```

</td>
</tr>

<tr>
<td>

```python
torch.xxx.max_memory_reserved
```

</td>
<td>

```python
torch.accelerator.max_memory_reserved
```

</td>
</tr>

</table>
</div>

# Solution
This design follows a similar pattern to `HostAllocator`. We're introducing a base class `DeviceAllocator`, from which `CUDAAllocator` and `XPUAllocator` will inherit. This allows us to provide a unified call path like: `torch.accelerator.empty_cache()` -> `GetDeviceAllocator(allocator)->empty_cache()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138222
Approved by: https://github.com/albanD, https://github.com/Camyll
This commit is contained in:
Yu, Guangye
2025-07-11 00:15:04 +00:00
committed by PyTorch MergeBot
parent f6d138807f
commit 1179e33323
10 changed files with 116 additions and 27 deletions

View File

@ -540,7 +540,7 @@ class DeviceCachingAllocator {
static void local_raw_delete(void* ptr);
class XPUAllocator : public Allocator {
class XPUAllocator : public DeviceAllocator {
private:
std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
@ -576,6 +576,10 @@ class XPUAllocator : public Allocator {
}
}
bool initialized() override {
return !device_allocators.empty();
}
void malloc(
void** devPtr,
DeviceIndex device,
@ -610,13 +614,13 @@ class XPUAllocator : public Allocator {
}
}
void emptyCache() {
void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override {
for (auto& da : device_allocators) {
da->emptyCache();
}
}
void recordStream(const DataPtr& ptr, XPUStream stream) {
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
if (!ptr.get()) {
return;
}
@ -626,7 +630,8 @@ class XPUAllocator : public Allocator {
Block* block = get_allocated_block(ptr.get());
TORCH_CHECK(block, "No allocated block can be found.");
device_allocators[block->device]->recordStream(block, stream);
c10::xpu::XPUStream xpu_stream{stream};
device_allocators[block->device]->recordStream(block, xpu_stream);
}
DataPtr allocate(size_t size) override {
@ -679,17 +684,17 @@ class XPUAllocator : public Allocator {
": did you call init?");
}
DeviceStats getDeviceStats(DeviceIndex device) {
DeviceStats getDeviceStats(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getStats();
}
void resetPeakStats(DeviceIndex device) {
void resetPeakStats(DeviceIndex device) override {
assertValidDevice(device);
device_allocators[device]->resetPeakStats();
}
void resetAccumulatedStats(DeviceIndex device) {
void resetAccumulatedStats(DeviceIndex device) override {
assertValidDevice(device);
device_allocators[device]->resetAccumulatedStats();
}