compile_kernel: Add DLPack test (#163166)
Note to self: I should probably start using ghstack.
This is rebased on top of https://github.com/pytorch/pytorch/pull/163165, so you only need to review this commit: 7387c1becf
This test doesn't add any new functionality; it just ensures that DLPack conversion works correctly with compile_kernel.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163166
Approved by: https://github.com/janeyx99, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 0661ecdb38
Commit: 28c42cc280
@@ -7265,6 +7265,42 @@ class TestCompileKernel(TestCase):
        expected = a + b
        self.assertEqual(c, expected)

    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_compile_kernel_dlpack(self):
        """Test that compile_kernel works with tensors created via DLPack."""
        kernel_source = """
        __global__ void add_tensors(const float* a, const float* b, float* c, int n) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            if (i < n)
                c[i] = a[i] + b[i];
        }
        """

        from torch.cuda import _compile_kernel

        add_kernel = _compile_kernel(kernel_source, "add_tensors")

        N = 512
        a = torch.rand(N, device="cuda", dtype=torch.float32)
        b = torch.rand(N, device="cuda", dtype=torch.float32)

        a_dlpack = torch.utils.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(a))
        b_dlpack = torch.utils.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(b))
        c = torch.empty_like(a)

        threads_per_block = 256
        blocks_per_grid = (N + threads_per_block - 1) // threads_per_block

        add_kernel(
            grid=(blocks_per_grid, 1, 1),
            block=(threads_per_block, 1, 1),
            args=[a_dlpack, b_dlpack, c, N],
        )

        self.assertEqual(c, a + b)
        a_dlpack[0] = 42.0
        self.assertEqual(a[0].item(), 42.0, "DLPack tensors should share memory")


@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
class TestCudaDeviceParametrized(TestCase):
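For reference, the memory-sharing behavior asserted at the end of the test can be reproduced in isolation with a plain DLPack round-trip. This is a minimal sketch, not part of the commit; the tensor names and the CPU device are illustrative choices only:

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

# Round-trip: export x as a DLPack capsule, then re-import it as y.
x = torch.arange(4, dtype=torch.float32)
y = from_dlpack(to_dlpack(x))

# from_dlpack does not copy, so writes through y are visible through x.
y[0] = 42.0
assert x[0].item() == 42.0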