Add kernel test

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2024-12-26 11:23:52 -08:00
parent ca4f9e69a8
commit 27e8eb2e94

View File

@ -0,0 +1,47 @@
import random
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_copy_subranges(seed, device):
torch.set_default_device(device)
current_platform.seed_everything(seed)
num_rows = 1024
num_cols = 1024
src_matrix = torch.zeros(num_rows,
num_cols,
device=device,
dtype=torch.int32)
dst_matrix = torch.zeros(num_rows,
num_cols,
device=device,
dtype=torch.int32)
diff_matrix = torch.zeros(num_rows, 2, device=device, dtype=torch.int32)
for i in range(num_rows):
start_idx = random.randint(0, num_cols - 1)
end_idx = random.randint(start_idx, num_cols - 1)
num_diffs = end_idx - start_idx
src_matrix[i, start_idx:end_idx] = torch.randint(0,
100, (num_diffs, ),
device=device,
dtype=torch.int32)
diff_matrix[i, 0] = start_idx
diff_matrix[i, 1] = num_diffs
ops.copy_subranges(src_matrix, diff_matrix, dst_matrix, num_rows)
assert torch.allclose(src_matrix, dst_matrix, rtol=0, atol=0)