Compare commits

...

6 Commits

Author SHA1 Message Date
a8c367127b Revert all changes to torch/cuda/_utils.py 2025-09-17 13:16:34 -07:00
fa839e440c testy test 2025-09-17 13:15:06 -07:00
4ae58a3dd4 simplify nvrtc discovery logic 2025-09-17 13:15:06 -07:00
8e8ec24374 Update _utils.py 2025-09-17 13:15:05 -07:00
4b74106204 lint 2025-09-17 13:15:05 -07:00
693880081c cub and compile_kernel 2025-09-17 13:15:04 -07:00
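
The last commit above wires CUB into the `_compile_kernel` tests. For orientation, here is a minimal sketch of the `torch.cuda._compile_kernel` API as exercised by the diff below; the squaring kernel and variable names are illustrative only and are not part of this PR.

# Hedged sketch of torch.cuda._compile_kernel usage, mirroring the calls in the diff below.
# The kernel body and names here are illustrative, not from the PR.
import torch
from torch.cuda import _compile_kernel

kernel_source = """
extern "C" __global__ void square_kernel(const float* input, float* output, int n) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < n) {
        output[tid] = input[tid] * input[tid];
    }
}
"""

square_kernel = _compile_kernel(kernel_source, "square_kernel")

N = 1024
x = torch.arange(N, dtype=torch.float32, device="cuda")
y = torch.empty_like(x)

threads_per_block = 256
blocks_per_grid = (N + threads_per_block - 1) // threads_per_block
square_kernel(
    grid=(blocks_per_grid, 1, 1),
    block=(threads_per_block, 1, 1),
    args=[x, y, N],
)
# y should now equal x * x

The compiled callable takes grid, block, and args, which matches the launch signature used by the reduction test in the diff.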

@@ -7030,6 +7030,51 @@ class TestCompileKernel(TestCase):
        # Verify results
        self.assertEqual(C_explicit, expected)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc")
    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_compile_kernel_reduction(self):
        # TODO: Not sure if I should be using ATen/cuda/cub.cuh or cub.cuh
        # BlockReduce is not exposed in ATen/cuda/cub.cuh
        kernel_source = """
        #include <cub/block/block_reduce.cuh>

        extern "C" __global__ void reduction_kernel(const float* input, float* output, int n) {
            typedef cub::BlockReduce<float, 256> BlockReduce;
            __shared__ typename BlockReduce::TempStorage temp_storage;

            int tid = threadIdx.x + blockIdx.x * blockDim.x;
            float thread_data = (tid < n) ? input[tid] : 0.0f;
            float aggregate = BlockReduce(temp_storage).Sum(thread_data);

            if (threadIdx.x == 0) {
                atomicAdd(output, aggregate);
            }
        }
        """
        from torch.cuda import _compile_kernel

        reduction_kernel = _compile_kernel(kernel_source, "reduction_kernel")

        N = 1024
        input_data = torch.ones(N, device="cuda")
        output_data = torch.zeros(1, device="cuda")

        threads_per_block = 256
        blocks_per_grid = (N + threads_per_block - 1) // threads_per_block

        reduction_kernel(
            grid=(blocks_per_grid, 1, 1),
            block=(threads_per_block, 1, 1),
            args=[input_data, output_data, N],
        )

        expected = float(N)
        actual = output_data.item()
        self.assertAlmostEqual(actual, expected, places=4)

    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_compile_kernel_as_custom_op(self):
        # Define a simple vector addition kernel