mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Adds Issue#153109 as a test for CUDAPluggableAllocator (#163575)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163575 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
9fd53a2bdc
commit
4dab208d97
@ -1,10 +1,83 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDACachingAllocator.h>
|
#include <c10/cuda/CUDACachingAllocator.h>
|
||||||
|
|
||||||
#include <ATen/test/allocator_clone_test.h>
|
#include <ATen/test/allocator_clone_test.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
|
||||||
|
|
||||||
|
// Maps each live CUDA allocation to the size requested in logging_malloc so
// that logging_free can verify the size reported on free matches the size
// reported on alloc (regression check for Issue #153109).
std::unordered_map<void*, size_t> allocation_sizes;
|
||||||
|
|
||||||
|
void* logging_malloc(size_t size, int device, cudaStream_t stream) {
|
||||||
|
void* ptr;
|
||||||
|
cudaMalloc(&ptr, size);
|
||||||
|
allocation_sizes[ptr] = size;
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free hook installed into CUDAPluggableAllocator: verifies that the `size`
// reported on free matches the size recorded at allocation time, then releases
// the memory. Throws so that a mismatch (Issue #153109) fails the test loudly.
// `device` and `stream` are part of the required signature but unused.
//
// Fix over the original: one hash lookup instead of three (find + operator[]
// + erase-by-key); the error-path behavior is unchanged.
void logging_free(void* ptr, size_t size, int device, cudaStream_t stream) {
  auto it = allocation_sizes.find(ptr);
  if (it == allocation_sizes.end()) {
    throw std::runtime_error("free of unknown ptr");
  }
  if (it->second != size) {
    throw std::runtime_error("free mismatch");
  }
  cudaFree(ptr);
  allocation_sizes.erase(it);
}
|
||||||
|
|
||||||
|
// Regression test for Issue #153109: unique_consecutive must report the same
// size to the pluggable allocator's free hook as it did to the malloc hook.
// The logging allocator above throws if the sizes ever disagree.
TEST(TestTorchUnique, UniqueComparisonTest) {
  if (!at::cuda::is_available()) return;

  auto custom_allocator =
      torch::cuda::CUDAPluggableAllocator::createCustomAllocator(logging_malloc, logging_free);
  torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(custom_allocator);

  // Run the command 3 times; the first 2 will pass and the third invocation
  // will have different sizes in alloc and free if the test fails.
  for (int iteration = 0; iteration < 3; ++iteration) {
    // Simple sorted input with repeated runs: 0x3, 1x2, 2x1, 3x4, 5x1.
    at::Tensor sorted_tensor = at::tensor(
        {0, 0, 0, 1, 1, 2, 3, 3, 3, 3, 5},
        at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));

    // This operation will call malloc/free with different sizes on the same
    // pointer when the regression is present.
    auto [unique_values, unused_inverse, unique_counts] =
        at::unique_consecutive(sorted_tensor, false, true, 0);
    (void)unused_inverse;

    // Everything below only validates that the op still computed the right
    // answer while running under the logging allocator.
    const int64_t kExpectedUnique = 5;
    EXPECT_EQ(unique_values.size(0), kExpectedUnique);
    EXPECT_EQ(unique_counts.size(0), kExpectedUnique);

    // Copy to CPU before touching individual elements.
    at::Tensor host_values = unique_values.cpu();
    at::Tensor host_counts = unique_counts.cpu();
    auto values_acc = host_values.accessor<float, 1>();
    auto counts_acc = host_counts.accessor<int64_t, 1>();

    const float expected_values[] = {0.0f, 1.0f, 2.0f, 3.0f, 5.0f};
    const int64_t expected_counts[] = {3, 2, 1, 4, 1};
    for (int64_t i = 0; i < kExpectedUnique; ++i) {
      EXPECT_EQ(values_acc[i], expected_values[i]);
      EXPECT_EQ(counts_acc[i], expected_counts[i]);
    }
  }
}
|
||||||
|
|
||||||
// Exercises the shared allocator-clone test harness against the CUDA caching
// allocator; a no-op on machines without a CUDA device.
TEST(AllocatorTestCUDA, test_clone) {
  if (at::cuda::is_available()) {
    test_allocator_clone(c10::cuda::CUDACachingAllocator::get());
  }
}
|
||||||
|
@ -50,6 +50,7 @@ run_if_exists cuda_complex_test
|
|||||||
run_if_exists cuda_complex_math_test
|
run_if_exists cuda_complex_math_test
|
||||||
run_if_exists cuda_cub_test
|
run_if_exists cuda_cub_test
|
||||||
run_if_exists cuda_atomic_ops_test
|
run_if_exists cuda_atomic_ops_test
|
||||||
|
run_if_exists cuda_allocator_test
|
||||||
|
|
||||||
if [ "$VALGRIND" == "ON" ]; then
|
if [ "$VALGRIND" == "ON" ]; then
|
||||||
# NB: As these tests are invoked by valgrind, let's leave them for now as it's
|
# NB: As these tests are invoked by valgrind, let's leave them for now as it's
|
||||||
|
Reference in New Issue
Block a user