From 507611f9aeaa352c9f784da0623a58ec1d6e8585 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 5 Mar 2024 09:53:01 +0000 Subject: [PATCH] [CUDACachingAllocator] Turn Allocator::allocate into non-const (#120969) Ideally, the method should be non-const since it changes the allocator state. Some const_casts are also removed along the way. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120969 Approved by: https://github.com/albanD --- aten/src/ATen/EmptyTensor.cpp | 2 +- aten/src/ATen/cuda/CachingHostAllocator.cpp | 2 +- aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h | 2 +- aten/src/ATen/mps/MPSAllocator.mm | 2 +- aten/src/ATen/native/TensorFactories.h | 2 +- aten/src/ATen/test/xla_tensor_test.cpp | 2 +- c10/core/Allocator.cpp | 2 +- c10/core/Allocator.h | 4 ++-- c10/core/CPUAllocator.cpp | 4 ++-- c10/core/TensorImpl.h | 2 +- c10/cuda/CUDACachingAllocator.cpp | 6 ++---- c10/cuda/CUDAMallocAsyncAllocator.cpp | 2 +- c10/xpu/XPUCachingAllocator.cpp | 6 ++---- caffe2/core/context_gpu.cu | 4 ++-- test/cpp_extensions/open_registration_extension.cpp | 2 +- test/inductor/extension_backends/extension_device.cpp | 2 +- torch/csrc/cuda/CUDAPluggableAllocator.cpp | 6 ++---- torch/csrc/cuda/CUDAPluggableAllocator.h | 2 +- 18 files changed, 24 insertions(+), 30 deletions(-) diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index c292a3b598b3..369dcdb70f22 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -316,7 +316,7 @@ struct MetaAllocator final : public at::Allocator { static void deleter(void* const pointer) { TORCH_INTERNAL_ASSERT(!pointer); } - DataPtr allocate(const size_t nbytes) const override { + DataPtr allocate(const size_t nbytes) override { return {nullptr, nullptr, &deleter, at::Device(DeviceType::Meta)}; } DeleterFnPtr raw_deleter() const override { diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index 3c63ed71c0ad..442f93cf9c84 100644 
--- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -492,7 +492,7 @@ void CachingHostAllocator_emptyCache() { } struct CUDAHostAllocatorWrapper final : public at::Allocator { - at::DataPtr allocate(size_t size) const override { + at::DataPtr allocate(size_t size) override { auto ptr_and_ctx = getCUDAHostAllocator().allocate(size); return { ptr_and_ctx.first, diff --git a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h index e4c0cec950a6..8e2654bafe90 100644 --- a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h @@ -15,7 +15,7 @@ class HIPAllocatorMasqueradingAsCUDA final : public Allocator { public: explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator) : allocator_(allocator) {} - DataPtr allocate(size_t size) const override { + DataPtr allocate(size_t size) override { DataPtr r = allocator_->allocate(size); r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index())); return r; diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm index ac9bf85f4cdd..76280fb469e5 100644 --- a/aten/src/ATen/mps/MPSAllocator.mm +++ b/aten/src/ATen/mps/MPSAllocator.mm @@ -748,7 +748,7 @@ struct TORCH_API MPSAllocator final : public IMPSAllocator { return &Delete; } - DataPtr allocate(const size_t nbytes) const override { + DataPtr allocate(const size_t nbytes) override { __block id buf = nbytes > 0 ? 
_getAllocImpl().malloc(nbytes, m_usage) : nullptr; return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)}; } diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index 09bbde077e87..f9b2893d768a 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -124,7 +124,7 @@ struct ZeroTensorAllocator final : public at::Allocator { static void deleter(void* const pointer) { TORCH_INTERNAL_ASSERT(!pointer); } - DataPtr allocate(const size_t /*nbytes*/) const override { + DataPtr allocate(const size_t /*nbytes*/) override { return {nullptr, nullptr, &deleter, device_}; } DeleterFnPtr raw_deleter() const override { diff --git a/aten/src/ATen/test/xla_tensor_test.cpp b/aten/src/ATen/test/xla_tensor_test.cpp index 1c9e392f9fe0..bca63697ba9f 100644 --- a/aten/src/ATen/test/xla_tensor_test.cpp +++ b/aten/src/ATen/test/xla_tensor_test.cpp @@ -17,7 +17,7 @@ void* XLAMalloc(ptrdiff_t size) { } struct XLAAllocator final : public at::Allocator { - at::DataPtr allocate(size_t size) const override { + at::DataPtr allocate(size_t size) override { auto* ptr = XLAMalloc(size); return {ptr, ptr, &XLAFree, at::DeviceType::XLA}; } diff --git a/c10/core/Allocator.cpp b/c10/core/Allocator.cpp index 8b6674d078e7..491c85b081e8 100644 --- a/c10/core/Allocator.cpp +++ b/c10/core/Allocator.cpp @@ -4,7 +4,7 @@ namespace c10 { -DataPtr Allocator::clone(const void* data, std::size_t n) const { +DataPtr Allocator::clone(const void* data, std::size_t n) { DataPtr new_data = allocate(n); copy_data(new_data.mutable_get(), data, n); return new_data; diff --git a/c10/core/Allocator.h b/c10/core/Allocator.h index b0dd5a8a6831..412412557a0d 100644 --- a/c10/core/Allocator.h +++ b/c10/core/Allocator.h @@ -160,7 +160,7 @@ inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept { struct C10_API Allocator { virtual ~Allocator() = default; - virtual DataPtr allocate(size_t n) const = 0; + virtual DataPtr 
allocate(size_t n) = 0; // Clones an allocation that came from this allocator. // @@ -171,7 +171,7 @@ struct C10_API Allocator { // attached to the input data. // // Requires: input data was allocated by the same allocator. - DataPtr clone(const void* data, std::size_t n) const; + DataPtr clone(const void* data, std::size_t n); // Checks if DataPtr has a simple context, not wrapped with any out of the // ordinary contexts. diff --git a/c10/core/CPUAllocator.cpp b/c10/core/CPUAllocator.cpp index 04759047b403..144e1b27b6de 100644 --- a/c10/core/CPUAllocator.cpp +++ b/c10/core/CPUAllocator.cpp @@ -17,7 +17,7 @@ namespace c10 { struct C10_API DefaultCPUAllocator final : at::Allocator { DefaultCPUAllocator() = default; - at::DataPtr allocate(size_t nbytes) const override { + at::DataPtr allocate(size_t nbytes) override { void* data = nullptr; try { data = c10::alloc_cpu(nbytes); @@ -103,7 +103,7 @@ class DefaultMobileCPUAllocator final : public at::Allocator { } } - DataPtr allocate(const size_t nbytes) const override { + DataPtr allocate(const size_t nbytes) override { if (C10_UNLIKELY(0u == nbytes)) { return { nullptr, diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index abea631dd45c..228d8b29f701 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2262,7 +2262,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { storage_offset_ == 0); // because we just reallocated return storage_.mutable_data(); } - const Allocator* allocator = storage_.allocator(); + Allocator* allocator = storage_.allocator(); // Storage might have nullptr allocator in rare cases, for example, if // an external memory segment has been wrapped with Tensor and we don't // know how to reallocate it. 
However, in order to preserve legacy C2 diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 094720340082..c44af3710501 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -3106,7 +3106,7 @@ class NativeCachingAllocator : public CUDAAllocator { return cpd; } - DataPtr allocate(size_t size) const override { + DataPtr allocate(size_t size) override { constexpr size_t one_exa_bytes = 1152921504606846976ULL; TORCH_CHECK_WITH( OutOfMemoryError, @@ -3131,9 +3131,7 @@ class NativeCachingAllocator : public CUDAAllocator { } } else { if (size != 0) { - // Allocator declars allocate const!? - const_cast(this)->malloc( - &devPtr, device, size, stream); + this->malloc(&devPtr, device, size, stream); } } diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 7319d21b306c..cb5bc94a66ce 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -405,7 +405,7 @@ void local_raw_delete(void* ptr); // Same pattern as CUDACachingAllocator.cpp. struct CudaMallocAsyncAllocator : public CUDAAllocator { - DataPtr allocate(size_t size) const override { + DataPtr allocate(size_t size) override { constexpr size_t one_exa_bytes = 1152921504606846976ULL; TORCH_CHECK_WITH( OutOfMemoryError, diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index eeb38b1c213a..3b5c4b58593e 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -497,13 +497,11 @@ class XPUAllocator : public Allocator { device_allocators[block->device]->recordStream(block, stream); } - DataPtr allocate(size_t size) const override { + DataPtr allocate(size_t size) override { auto device = c10::xpu::current_device(); void* r = nullptr; if (size != 0) { - // Allocator declares allocate const! 
- const_cast(this)->malloc( - &r, device, size, xpu::getCurrentXPUStream(device)); + this->malloc(&r, device, size, xpu::getCurrentXPUStream(device)); } return {r, r, &local_raw_delete, Device(DeviceType::XPU, device)}; } diff --git a/caffe2/core/context_gpu.cu b/caffe2/core/context_gpu.cu index 6555b9732c9a..ecc933ac7fad 100644 --- a/caffe2/core/context_gpu.cu +++ b/caffe2/core/context_gpu.cu @@ -306,7 +306,7 @@ struct CAFFE2_CUDA_API PinnedCPUAllocator final : public at::Allocator { baseAllocator_ = GetDefaultCPUAllocator(); } ~PinnedCPUAllocator() override {} - at::DataPtr allocate(size_t nbytes) const override { + at::DataPtr allocate(size_t nbytes) override { if (nbytes == 0) { // replicate c10::alloc_cpu behavior - return nullptr return {nullptr, nullptr, &Delete, at::Device(CPU)}; @@ -513,7 +513,7 @@ void TrackMemoryAlloc(size_t nbytes) { struct DefaultCUDAAllocator final : public at::Allocator { DefaultCUDAAllocator() {} ~DefaultCUDAAllocator() override {} - at::DataPtr allocate(size_t nbytes) const override { + at::DataPtr allocate(size_t nbytes) override { // Lock the mutex std::lock_guard lock(CUDAContext::mutex()); // A one-time caffe2 cuda initializer. 
diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index 5cf6ea1df902..5f2eac14aeb8 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -173,7 +173,7 @@ at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) { // A dummy allocator for our custom device, that secretly uses the CPU struct DummyCustomAllocator final : at::Allocator { DummyCustomAllocator() = default; - at::DataPtr allocate(size_t nbytes) const override { + at::DataPtr allocate(size_t nbytes) override { void* data = c10::alloc_cpu(nbytes); return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, custom_device_index)}; } diff --git a/test/inductor/extension_backends/extension_device.cpp b/test/inductor/extension_backends/extension_device.cpp index 2e86cde2dd1c..71f3f5919a9b 100644 --- a/test/inductor/extension_backends/extension_device.cpp +++ b/test/inductor/extension_backends/extension_device.cpp @@ -66,7 +66,7 @@ at::Tensor custom_to_device( // A dummy allocator for our custom device, that secretly uses the CPU struct DummyCustomAllocator final : at::Allocator { DummyCustomAllocator() = default; - at::DataPtr allocate(size_t nbytes) const override { + at::DataPtr allocate(size_t nbytes) override { void* data = c10::alloc_cpu(nbytes); return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; } diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index b9758f46a779..cb7b62387b68 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -94,13 +94,11 @@ void* CUDAPluggableAllocator::malloc( return r; } -c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) const { +c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) { c10::DeviceIndex device = -1; C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); 
cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device); - void* r = - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(this)->malloc(size, device, stream); + void* r = this->malloc(size, device, stream); c10::DataPtr data_ptr = { r, r, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)}; return data_ptr; diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 6cfa663d3d00..22a61e48e4a2 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -71,7 +71,7 @@ struct CUDAPluggableAllocator void* malloc(size_t size, c10::DeviceIndex device, cudaStream_t stream); - c10::DataPtr allocate(size_t size) const override; + c10::DataPtr allocate(size_t size) override; c10::DeleterFnPtr raw_deleter() const override; void* raw_alloc(size_t nbytes) override;