Compare commits


1 Commit

Commit 4d49410a02 (2025-08-06 14:09:46 -07:00): AOTI use CudaCachingAllocator

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
6 changed files with 91 additions and 3 deletions
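
Since the Summary above is empty, here is a minimal usage sketch of the new knob from the Python side; it is illustrative only and not part of the change. It assumes the flag lands at torch._inductor.config.aot_inductor.weight_use_caching_allocator (matching the config.py hunk below) and uses torch._export.aot_compile / aot_load as the compile and load entry points; the model and shapes are arbitrary.

import torch
from torch._inductor import config as inductor_config

# Opting in routes the AOTI constant blob (the packaged weights) through
# c10's CUDACachingAllocator instead of a direct cudaMalloc.
inductor_config.aot_inductor.weight_use_caching_allocator = True

model = torch.nn.Linear(16, 16).cuda()
example_inputs = (torch.randn(8, 16, device="cuda"),)

# Compile to a shared library and load it back; exact entry points are assumptions here.
so_path = torch._export.aot_compile(model, example_inputs)
runner = torch._export.aot_load(so_path, device="cuda")
out = runner(*example_inputs)

With the flag on, the weights are carved out of PyTorch's CUDA caching allocator pool rather than a separate cudaMalloc, so they participate in the same memory accounting and block reuse as regular tensors.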

View File

@@ -879,15 +879,36 @@ void test_cuda_alloc_test() {
if (cudaStatus != cudaSuccess || device_idx == -1) {
throw std::runtime_error("cudaGetDevice failed!");
}
// Clear any existing cached memory to get a clean baseline
c10::cuda::CUDACachingAllocator::emptyCache();
// Set the environment variable to enable CUDACachingAllocator for weights
setenv("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR", "1", 1);
c10::cuda::CUDACachingAllocator::DeviceStats stats =
c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
size_t initTorchActive = stats.active_bytes[0].current;
auto runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
model_so_path);
// Force CUDA synchronization to ensure all allocations are complete
cudaDeviceSynchronize();
// Get fresh stats after creating the runner
stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
size_t torchActive = stats.active_bytes[0].current;
// The test expects that when weight_use_caching_allocator is enabled,
// AOTInductor will allocate weights through PyTorch's CUDACachingAllocator
// instead of direct CUDA memory allocation. This should show up as an
// increase in the active_bytes stats.
ASSERT_EQ(initTorchActive + DATASIZE, torchActive);
// Clean up the environment variable
unsetenv("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR");
auto actual_output_tensors =
runner->run(data_loader.attr(inputs_attr.c_str()).toTensorList().vec());
ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
@@ -1114,7 +1135,7 @@ TEST(AotInductorTest, MultiStreamTestCuda) {
}
// TODO: ENABLE CUDACachingAllocator Test
TEST(DISABLED_AotInductorTest, CudaAllocTestCuda) {
TEST(AotInductorTest, CudaAllocTestCuda) {
test_cuda_alloc_test();
}
#endif

View File

@@ -1454,6 +1454,15 @@ class aot_inductor:
os.environ.get("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR", "0") == "1"
)
def __setattr__(self, name: str, value: Any) -> None:
# When weight_use_caching_allocator is set to True, also set the environment variable
# so that the runtime code can access it
if name == "weight_use_caching_allocator" and value:
os.environ["AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR"] = "1"
elif name == "weight_use_caching_allocator" and not value:
os.environ["AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR"] = "0"
super().__setattr__(name, value)
# Experimental. Flag to control whether to include weight in .so
package_constants_in_so: bool = True
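
A quick, hypothetical check of the propagation the __setattr__ hook above is meant to provide. It assumes the class is reachable as torch._inductor.config.aot_inductor and that attribute assignment actually passes through this hook; depending on how the inductor config module installs these classes, the exact plumbing may differ.

import os
from torch._inductor import config as inductor_config

inductor_config.aot_inductor.weight_use_caching_allocator = True
print(os.environ.get("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR"))  # expected: "1"

inductor_config.aot_inductor.weight_use_caching_allocator = False
print(os.environ.get("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR"))  # expected: "0"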

View File

@@ -10,9 +10,11 @@
#endif
#include <fcntl.h>
#include <cstdlib>
#include <optional>
#include <regex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
@@ -74,6 +76,18 @@ RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) {
return RAIIDataPtr(data_ptr, deleter);
}
// NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration)
RAIIDataPtr RAII_gpuMallocCaching(size_t num_bytes) {
void* data_ptr = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_cuda_caching_allocator_raw_alloc(&data_ptr, num_bytes));
auto deleter = [](void* ptr) {
AOTI_TORCH_ERROR_CODE_CHECK(
aoti_torch_cuda_caching_allocator_raw_delete(ptr));
};
return RAIIDataPtr(data_ptr, deleter);
}
#elif defined(USE_XPU)
RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) {
@@ -324,7 +338,19 @@ class AOTInductorModelBase {
return;
}
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
constant_blob_ = RAII_gpuMalloc(blob_size);
// Check if we should use PyTorch's CUDACachingAllocator for weight
// management
const char* use_caching_allocator =
std::getenv("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR");
if (use_caching_allocator && std::string(use_caching_allocator) == "1") {
#ifdef USE_CUDA
constant_blob_ = RAII_gpuMallocCaching(blob_size);
#else
constant_blob_ = RAII_gpuMalloc(blob_size);
#endif
} else {
constant_blob_ = RAII_gpuMalloc(blob_size);
}
#else
constant_blob_ = RAII_cpuMalloc(blob_size);
#endif

View File

@@ -2,6 +2,7 @@
#include <algorithm>
#include <condition_variable>
#include <cstdlib>
#include <deque>
#include <mutex>
#include <shared_mutex>
@@ -682,7 +683,19 @@ class AOTInductorModelContainer {
RAIIDataPtr allocate_constant_blob() {
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
return RAII_gpuMalloc(blob_size_);
// Check if we should use PyTorch's CUDACachingAllocator for weight
// management
const char* use_caching_allocator =
std::getenv("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR");
if (use_caching_allocator && std::string(use_caching_allocator) == "1") {
#ifdef USE_CUDA
return RAII_gpuMallocCaching(blob_size_);
#else
return RAII_gpuMalloc(blob_size_);
#endif
} else {
return RAII_gpuMalloc(blob_size_);
}
#else
return RAII_cpuMalloc(blob_size_);
#endif // USE_CUDA

View File

@@ -514,6 +514,12 @@ aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cuda_caching_allocator_raw_alloc(void** ptr, size_t size);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_cuda_caching_allocator_raw_delete(void* ptr);
#endif // USE_CUDA
// See `ProxyExecutor Design Note` in ir.py for more details

View File

@@ -2,6 +2,7 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
@@ -53,3 +54,15 @@ AOTITorchError aoti_torch_get_current_cuda_stream(
*(cudaStream_t*)(ret_stream) = at::cuda::getCurrentCUDAStream(device_index);
});
}
AOTITorchError aoti_torch_cuda_caching_allocator_raw_alloc(
void** ptr,
size_t size) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ptr = c10::cuda::CUDACachingAllocator::raw_alloc(size); });
}
AOTITorchError aoti_torch_cuda_caching_allocator_raw_delete(void* ptr) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ c10::cuda::CUDACachingAllocator::raw_delete(ptr); });
}