mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
47
tests/kernels/test_copy_subranges.py
Normal file
47
tests/kernels/test_copy_subranges.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
SEEDS = [0]
|
||||||
|
CUDA_DEVICES = [
|
||||||
|
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
def test_copy_subranges(seed, device):
|
||||||
|
torch.set_default_device(device)
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
|
num_rows = 1024
|
||||||
|
num_cols = 1024
|
||||||
|
src_matrix = torch.zeros(num_rows,
|
||||||
|
num_cols,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
dst_matrix = torch.zeros(num_rows,
|
||||||
|
num_cols,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
diff_matrix = torch.zeros(num_rows, 2, device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
for i in range(num_rows):
|
||||||
|
start_idx = random.randint(0, num_cols - 1)
|
||||||
|
end_idx = random.randint(start_idx, num_cols - 1)
|
||||||
|
num_diffs = end_idx - start_idx
|
||||||
|
|
||||||
|
src_matrix[i, start_idx:end_idx] = torch.randint(0,
|
||||||
|
100, (num_diffs, ),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
diff_matrix[i, 0] = start_idx
|
||||||
|
diff_matrix[i, 1] = num_diffs
|
||||||
|
|
||||||
|
ops.copy_subranges(src_matrix, diff_matrix, dst_matrix, num_rows)
|
||||||
|
assert torch.allclose(src_matrix, dst_matrix, rtol=0, atol=0)
|
||||||
Reference in New Issue
Block a user