Revert "Add DeviceAllocator as the base device allocator (#138222)"

This reverts commit 92409b6c89fbfbd3caa79c81b1e3d9e7917d3bc7.

Reverted https://github.com/pytorch/pytorch/pull/138222 on behalf of https://github.com/Camyll due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3002206756))
This commit is contained in:
PyTorch MergeBot
2025-06-25 00:11:35 +00:00
parent 6459a5c7a9
commit 3dd872e6d5
13 changed files with 33 additions and 119 deletions

View File

@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <cstddef>

View File

@ -2,7 +2,6 @@
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>

View File

@ -1,6 +1,6 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/core/DeviceType.h>
// Use of c10::hip namespace here makes hipification easier, because
@ -10,10 +10,10 @@ namespace c10::hip {
// Takes a valid HIPAllocator (of any sort) and turns it into
// an allocator pretending to be a CUDA allocator. See
// Note [Masquerading as CUDA]
class HIPAllocatorMasqueradingAsCUDA final : public DeviceAllocator {
DeviceAllocator* allocator_;
class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
Allocator* allocator_;
public:
explicit HIPAllocatorMasqueradingAsCUDA(DeviceAllocator* allocator)
explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator)
: allocator_(allocator) {}
DataPtr allocate(size_t size) override {
DataPtr r = allocator_->allocate(size);
@ -26,24 +26,6 @@ public:
void copy_data(void* dest, const void* src, std::size_t count) const final {
allocator_->copy_data(dest, src, count);
}
bool initialized() override {
return allocator_->initialized();
}
void emptyCache(MempoolId_t mempool_id = {0, 0}) {
allocator_->emptyCache(mempool_id);
}
void recordStream(const DataPtr& ptr, c10::Stream stream) {
allocator_->recordStream(ptr, stream);
}
CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) {
return allocator_->getDeviceStats(device);
}
void resetAccumulatedStats(c10::DeviceIndex device) {
allocator_->resetAccumulatedStats(device);
}
void resetPeakStats(c10::DeviceIndex device) {
allocator_->resetPeakStats(device);
}
};
} // namespace c10::hip

View File

@ -4,9 +4,8 @@
namespace c10 { namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
Allocator* get() {
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
return &allocator;
}
@ -14,9 +13,5 @@ void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsC
HIPCachingAllocator::recordStream(ptr, stream.hip_stream());
}
// Register this HIP allocator as CUDA allocator to enable access through both
// c10::GetAllocator(kCUDA) and c10::getDeviceAllocator(kCUDA) APIs
REGISTER_ALLOCATOR(kCUDA, &allocator)
} // namespace HIPCachingAllocatorMasqueradingAsCUDA
}} // namespace c10::hip

View File

@ -1,10 +0,0 @@
#include <c10/core/CachingDeviceAllocator.h>
namespace c10 {
// Ensures proper DLL export of this pure virtual base class on Windows,
// since it's mainly used in other DLLs outside c10.dll.
DeviceAllocator::DeviceAllocator() = default;
DeviceAllocator::~DeviceAllocator() = default;
} // namespace c10

View File

@ -1,7 +1,6 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/core/Stream.h>
namespace c10::CachingDeviceAllocator {
@ -60,55 +59,3 @@ struct DeviceStats {
};
} // namespace c10::CachingDeviceAllocator
namespace c10 {
using CaptureId_t = unsigned long long;
// first is set if the instance is created by Graph mode capture_begin.
// second is set if the instance is created by Graph mode graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
struct C10_API DeviceAllocator : public c10::Allocator {
DeviceAllocator();
~DeviceAllocator() override;
// Returns true if the allocator has been properly initialized and is ready
// for use
virtual bool initialized() = 0;
// Releases all cached device memory from the specified memory pool back to
// the system
virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
// Associates a memory allocation with a stream to establish dependency
// tracking. Prevents memory reuse until all operations on the specified
// stream complete
virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0;
// Retrieves comprehensive memory statistics for the specified device,
// including allocation patterns, usage metrics
virtual CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) = 0;
// Resets cumulative allocation statistics for the specified device to zero
virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
// Resets peak memory usage statistics for the specified device
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
};
// This function is used to get the DeviceAllocator for a specific device type
// and keep backward compatibility with c10::GetAllocator.
C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) {
TORCH_CHECK(
t != DeviceType::CPU,
"getDeviceAllocator is not supported for CPU device type.");
auto* allocator = c10::GetAllocator(t);
auto* device_allocator = dynamic_cast<DeviceAllocator*>(allocator);
TORCH_INTERNAL_ASSERT(
device_allocator, "Allocator for ", t, " is not a DeviceAllocator.");
return device_allocator;
}
} // namespace c10

View File

@ -3695,7 +3695,7 @@ class NativeCachingAllocator : public CUDAAllocator {
return device_allocator[block->device]->shareIpcHandle(block);
}
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override {
// Empty tensor's storage().data() might be a null ptr. As there is no
// blocks associated with those tensors, it is fine to do nothing here.
if (!ptr.get()) {
@ -3713,8 +3713,7 @@ class NativeCachingAllocator : public CUDAAllocator {
Block* block = get_allocated_block(ptr.get());
// block must not be null reaching here
TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found");
c10::cuda::CUDAStream cuda_stream{stream};
device_allocator[block->device]->recordStream(block, cuda_stream);
device_allocator[block->device]->recordStream(block, stream);
}
SnapshotInfo snapshot(MempoolId_t mempool_id) override {
@ -4179,7 +4178,6 @@ struct BackendStaticInitializer {
BackendStaticInitializer() {
auto r = parseEnvForBackend();
at::SetAllocator(kCUDA, r, 0);
allocator.store(r);
}
};

View File

@ -202,18 +202,25 @@ struct ShareableHandle {
std::string handle;
};
class CUDAAllocator : public DeviceAllocator {
class CUDAAllocator : public Allocator {
public:
virtual void* raw_alloc(size_t nbytes) = 0;
virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0;
virtual void raw_delete(void* ptr) = 0;
virtual void init(int device_count) = 0;
virtual bool initialized() = 0;
virtual double getMemoryFraction(c10::DeviceIndex device) = 0;
virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
virtual void enable(bool value) = 0;
virtual bool isEnabled() const = 0;
virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
virtual void* getBaseAllocation(void* ptr, size_t* size) = 0;
virtual void recordStream(const DataPtr&, CUDAStream stream) = 0;
virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) = 0;
virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0;
virtual void beginAllocateToPool(
c10::DeviceIndex device,
@ -518,10 +525,6 @@ inline void enablePeerAccess(
namespace c10::cuda {
// Keep BC only
using c10::CaptureId_t;
using c10::MempoolId_t;
// MemPool represents a pool of memory in a caching allocator. Currently,
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
//

View File

@ -9,6 +9,12 @@
namespace c10::cuda {
using CaptureId_t = unsigned long long;
// first is set if the instance is created by CUDAGraph::capture_begin.
// second is set if the instance is created by at::cuda::graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
// RAII guard for "cudaStreamCaptureMode", a thread-local value
// that controls the error-checking strictness of a capture.
struct C10_CUDA_API CUDAStreamCaptureModeGuard {

View File

@ -607,7 +607,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
return ptr;
}
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override {
std::lock_guard<std::mutex> lk(general_mutex);
auto ptr_val = ptr.get();
// Empty tensor's storage().data() might be a null ptr. As there is no
@ -620,8 +620,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
auto it = ptr_info.find(ptr_val);
TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
c10::cuda::CUDAStream cuda_stream{stream};
UsageStream to_record{cuda_stream.stream(), stream.device_index()};
UsageStream to_record{stream.stream(), stream.device_index()};
if (to_record == it->second.creation_stream) {
TORCH_WARN_ONCE(
"Called record_stream on tensor whose original creation stream "

View File

@ -540,7 +540,7 @@ class DeviceCachingAllocator {
static void local_raw_delete(void* ptr);
class XPUAllocator : public DeviceAllocator {
class XPUAllocator : public Allocator {
private:
std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
@ -576,10 +576,6 @@ class XPUAllocator : public DeviceAllocator {
}
}
bool initialized() override {
return !device_allocators.empty();
}
void malloc(
void** devPtr,
DeviceIndex device,
@ -614,13 +610,13 @@ class XPUAllocator : public DeviceAllocator {
}
}
void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override {
void emptyCache() {
for (auto& da : device_allocators) {
da->emptyCache();
}
}
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
void recordStream(const DataPtr& ptr, XPUStream stream) {
if (!ptr.get()) {
return;
}
@ -630,8 +626,7 @@ class XPUAllocator : public DeviceAllocator {
Block* block = get_allocated_block(ptr.get());
TORCH_CHECK(block, "No allocated block can be found.");
c10::xpu::XPUStream xpu_stream{stream};
device_allocators[block->device]->recordStream(block, xpu_stream);
device_allocators[block->device]->recordStream(block, stream);
}
DataPtr allocate(size_t size) override {
@ -684,17 +679,17 @@ class XPUAllocator : public DeviceAllocator {
": did you call init?");
}
DeviceStats getDeviceStats(DeviceIndex device) override {
DeviceStats getDeviceStats(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getStats();
}
void resetPeakStats(DeviceIndex device) override {
void resetPeakStats(DeviceIndex device) {
assertValidDevice(device);
device_allocators[device]->resetPeakStats();
}
void resetAccumulatedStats(DeviceIndex device) override {
void resetAccumulatedStats(DeviceIndex device) {
assertValidDevice(device);
device_allocators[device]->resetAccumulatedStats();
}

View File

@ -210,8 +210,7 @@ void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) {
void CUDAPluggableAllocator::recordStream(
const c10::DataPtr& ptr,
c10::Stream c10_stream) {
streamType stream{c10_stream};
streamType stream) {
if (record_stream_fn_) {
record_stream_fn_(ptr.get(), stream);
}

View File

@ -122,7 +122,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override;
void* getBaseAllocation(void* ptr, size_t* size) override;
void recordStream(const c10::DataPtr&, c10::Stream stream) override;
void recordStream(const c10::DataPtr&, streamType stream) override;
c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) override;