Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[AOTInductor] Use CudaCachingAllocator for memory allocation (#162893)
Summary: Use c10::cuda::CUDACachingAllocator for AOTInductor's initial constant buffer allocation.

Test Plan: Activate the test under test/cpp/aoti_inference/test.cpp.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162893
Approved by: https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 0e9f9c3a61
Commit: 2291199e9b
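As context for the diffs below, here is a minimal sketch (not the shipped code) of the allocation pattern this change introduces: the constant buffer is drawn from the CUDA caching allocator via raw_alloc and owned by an RAII pointer whose deleter calls raw_delete, mirroring the AOT_INDUCTOR_USE_CACHING_ALLOCATOR branch of RAII_gpuMalloc further down. The name alloc_constant_buffer is illustrative only.

// Sketch only: RAII wrapper over the CUDA caching allocator.
#include <c10/cuda/CUDACachingAllocator.h>
#include <cstddef>
#include <functional>
#include <memory>

using RAIIDataPtr = std::unique_ptr<void, std::function<void(void*)>>;

RAIIDataPtr alloc_constant_buffer(size_t num_bytes) {
  // raw_alloc hands out memory from the caching allocator's pools, so the
  // constant blob shows up in the allocator's DeviceStats and can be reused
  // instead of going through a fresh cudaMalloc on every load.
  void* ptr = c10::cuda::CUDACachingAllocator::raw_alloc(num_bytes);
  return RAIIDataPtr(ptr, [](void* p) {
    if (p != nullptr) {
      c10::cuda::CUDACachingAllocator::raw_delete(p);
    }
  });
}

Routing the constant buffer through the caching allocator is also what lets the updated test below verify the allocation by reading the allocator's DeviceStats before and after loading the model.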
@@ -879,12 +879,15 @@ void test_cuda_alloc_test() {
   if (cudaStatus != cudaSuccess || device_idx == -1) {
     throw std::runtime_error("cudaGetDevice failed!");
   }
+
+  c10::cuda::CUDACachingAllocator::emptyCache();
   c10::cuda::CUDACachingAllocator::DeviceStats stats =
       c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
-  size_t initTorchActive = stats.active_bytes[0].current;
+  size_t initTorchActive = stats.allocated_bytes[0].current;
   auto runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
       model_so_path);
-  size_t torchActive = stats.active_bytes[0].current;
+  stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
+  size_t torchActive = stats.allocated_bytes[0].current;
 
   ASSERT_EQ(initTorchActive + DATASIZE, torchActive);
 
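A note on the assertion above: allocated_bytes[0].current is the caching allocator's aggregate count of bytes currently handed out to callers (index 0 is the aggregate stat), so the expectation is that loading the model container raises it by exactly DATASIZE. A small hypothetical helper, assuming the same c10 API the test uses:

#include <c10/cuda/CUDACachingAllocator.h>
#include <cstddef>

// Bytes currently allocated (not merely cached) on the given device.
std::size_t currently_allocated_bytes(c10::DeviceIndex device_idx) {
  const auto stats =
      c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
  return stats.allocated_bytes[0].current;
}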
@@ -1113,8 +1116,7 @@ TEST(AotInductorTest, MultiStreamTestCuda) {
   test_multi_cuda_streams("cuda");
 }
 
-// TODO: ENABLE CUDACachingAllocator Test
-TEST(DISABLED_AotInductorTest, CudaAllocTestCuda) {
+TEST(AotInductorTest, CudaAllocTestCuda) {
   test_cuda_alloc_test();
 }
 #endif
@@ -1345,6 +1345,15 @@ def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]:
     return macros
 
 
+def get_caching_allocator_macro() -> list[str]:
+    from torch._inductor import config
+
+    macros = []
+    if config.aot_inductor.weight_use_caching_allocator:
+        macros.append(" AOT_INDUCTOR_USE_CACHING_ALLOCATOR")
+    return macros
+
+
 def get_cpp_torch_options(
     cpp_compiler: str,
     vec_isa: VecISA,
@@ -1401,6 +1410,7 @@ def get_cpp_torch_options(
     fb_macro_passthrough_args = _use_fb_internal_macros()
 
     mmap_self_macros = get_mmap_self_macro(use_mmap_weights)
+    caching_allocator_macros = get_caching_allocator_macro()
 
     definitions = (
         torch_cpp_wrapper_definitions
@@ -1408,6 +1418,7 @@
         + isa_macros
         + fb_macro_passthrough_args
         + mmap_self_macros
+        + caching_allocator_macros
     )
     include_dirs = (
         sys_libs_include_dirs
@@ -325,10 +325,23 @@ using RAIIDataPtr = std::unique_ptr<void, std::function<void(void*)>>;
 
 // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration)
 RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) {
+#ifdef AOT_INDUCTOR_USE_CACHING_ALLOCATOR
+  // Use caching allocator for allocating GPU memory
+  void* data_ptr = nullptr;
+  AOTI_TORCH_ERROR_CODE_CHECK(
+      aoti_torch_cuda_caching_allocator_raw_alloc(num_bytes, &data_ptr));
+  auto deleter = [](void* ptr) {
+    AOTI_TORCH_ERROR_CODE_CHECK(
+        aoti_torch_cuda_caching_allocator_raw_delete(ptr));
+  };
+  return RAIIDataPtr(data_ptr, deleter);
+#else
+  // Use cudaMalloc directly for allocating GPU memory
   void* data_ptr = nullptr;
   AOTI_RUNTIME_CUDA_CHECK(cudaMalloc((void**)&data_ptr, num_bytes));
   auto deleter = [](void* ptr) { AOTI_RUNTIME_CUDA_CHECK(cudaFree(ptr)); };
   return RAIIDataPtr(data_ptr, deleter);
+#endif
 }
 
 #elif defined(USE_XPU)
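A hedged call-site sketch for the wrapper above (load_constant_blob and host_constants are illustrative, not part of this PR); it assumes it lives in the same file as RAII_gpuMalloc, so the RAIIDataPtr alias and the AOTI_RUNTIME_CUDA_CHECK macro are visible:

#include <cuda_runtime.h>

// Copies host-side constants into a device buffer obtained from
// RAII_gpuMalloc; whichever branch was compiled in, the buffer is released
// automatically (raw_delete or cudaFree) when the returned handle is dropped.
RAIIDataPtr load_constant_blob(const void* host_constants, size_t blob_size) {
  RAIIDataPtr device_blob = RAII_gpuMalloc(blob_size);
  AOTI_RUNTIME_CUDA_CHECK(cudaMemcpy(
      device_blob.get(), host_constants, blob_size, cudaMemcpyHostToDevice));
  return device_blob;
}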
@@ -579,6 +579,15 @@ aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);
 
+// CUDA memory allocation using CUDACachingAllocator
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_caching_allocator_raw_alloc(
+    uint64_t nbytes,
+    void** ret_ptr // returns raw GPU memory pointer
+);
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_cuda_caching_allocator_raw_delete(void* ptr);
+
 #endif // USE_CUDA
 
 // See `ProxyExecutor Design Note` in ir.py for more details
@@ -2,6 +2,7 @@
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
 #include <torch/csrc/inductor/aoti_torch/utils.h>
 
+#include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
 
@@ -53,3 +54,32 @@ AOTITorchError aoti_torch_get_current_cuda_stream(
     *(cudaStream_t*)(ret_stream) = at::cuda::getCurrentCUDAStream(device_index);
   });
 }
+
+AOTITorchError aoti_torch_cuda_caching_allocator_raw_alloc(
+    uint64_t nbytes,
+    void** ret_ptr) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    if (nbytes == 0) {
+      *ret_ptr = nullptr;
+      return AOTI_TORCH_SUCCESS;
+    }
+
+    *ret_ptr = c10::cuda::CUDACachingAllocator::raw_alloc(nbytes);
+
+    if (*ret_ptr == nullptr) {
+      TORCH_CHECK(
+          false,
+          "Failed to allocate ",
+          nbytes,
+          " bytes from CUDA caching allocator");
+    }
+  });
+}
+
+AOTITorchError aoti_torch_cuda_caching_allocator_raw_delete(void* ptr) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    if (ptr != nullptr) {
+      c10::cuda::CUDACachingAllocator::raw_delete(ptr);
+    }
+  });
+}
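To round this out, a minimal caller sketch for the two new shim entry points (the function name and the 1 MiB size are arbitrary, and the include path is assumed to be the header edited above). Error codes are compared against AOTI_TORCH_SUCCESS directly so the snippet stays self-contained:

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#include <stdexcept>

// Allocate through the caching-allocator shim, then release the pointer
// through the matching delete call.
void caching_allocator_roundtrip() {
  void* buf = nullptr;
  if (aoti_torch_cuda_caching_allocator_raw_alloc(1 << 20, &buf) !=
      AOTI_TORCH_SUCCESS) {
    throw std::runtime_error("raw_alloc via shim failed");
  }
  // ... hand buf to kernels or to the constant-copy path ...
  if (aoti_torch_cuda_caching_allocator_raw_delete(buf) != AOTI_TORCH_SUCCESS) {
    throw std::runtime_error("raw_delete via shim failed");
  }
}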