commit 5f09cbdb63
parent 4cefa9b49b
Author: Woosuk Kwon
Date:   2023-12-02 16:06:17 -08:00
Committed-by: GitHub

    Fix broken sampler tests (#1896)

    Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>

2 changed files with 41 additions and 21 deletions
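
This ports the sampler tests to the Worker/ModelRunner split: the tests previously instantiated Worker directly, built sampler inputs via worker._prepare_inputs(), and passed input_metadata to the sampler, all of which have since changed. The tests now construct a ModelRunner, collect per-sequence prompt lengths, and feed both into model_runner._prepare_sample(), whose result is what the sampler expects as sampling_metadata. A sketch of the updated flow follows the test-file diff below.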

tests/samplers/test_sampler.py

@@ -8,7 +8,7 @@ import torch
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner
 
 
 class MockLogitsSampler(Sampler):
@@ -27,7 +27,7 @@ class MockLogitsSampler(Sampler):
 def _prepare_test(
     batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
     vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024),
                               device="cuda",
@@ -37,9 +37,8 @@ def _prepare_test(
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    worker = Worker(None, None, None)
-    worker.block_size = 16
-    return input_tensor, fake_logits, sampler, worker
+    model_runner = ModelRunner(None, None, None)
+    return input_tensor, fake_logits, sampler, model_runner
 
 
 RANDOM_SEEDS = list(range(128))
@@ -49,9 +48,11 @@ RANDOM_SEEDS = list(range(128))
 def test_sampler_all_greedy(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -61,11 +62,13 @@ def test_sampler_all_greedy(seed: int):
                 sampling_params=SamplingParams(temperature=0, ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -76,12 +79,14 @@ def test_sampler_all_greedy(seed: int):
 def test_sampler_all_random(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -94,11 +99,13 @@ def test_sampler_all_random(seed: int):
                 ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -108,9 +115,10 @@ def test_sampler_all_random(seed: int):
 def test_sampler_all_beam(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -124,11 +132,13 @@ def test_sampler_all_beam(seed: int):
                 ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler(embedding=None,
             hidden_states=input_tensor,
-            input_metadata=input_metadata)
+            sampling_metadata=sampling_metadata)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -139,10 +149,12 @@ def test_sampler_all_beam(seed: int):
 def test_sampler_mixed(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     seq_group_metadata_list = []
     expected_tokens = []
+    prompt_lens = []
     for i in range(batch_size):
         n = 1
         sampling_type = random.randint(0, 2)
@@ -172,11 +184,13 @@ def test_sampler_mixed(seed: int):
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
@@ -188,7 +202,7 @@ def test_sampler_mixed(seed: int):
 def test_sampler_logits_processors(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
 
     # This sample logits processor gives infinite score to the i-th token,
     # where i is the length of the input sequence.
@@ -198,6 +212,7 @@ def test_sampler_logits_processors(seed: int):
         return logits
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -208,11 +223,13 @@ def test_sampler_logits_processors(seed: int):
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
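
Every test in this file now follows the same preparation flow: build the per-sequence metadata, record each prompt length, hand both to ModelRunner._prepare_sample, and pass the result to the sampler as sampling_metadata. A minimal sketch of that flow, assuming the helpers defined above (_prepare_test, MockLogitsSampler); the request_id and seq_data token values here are illustrative placeholders, not taken from this diff:

import random

batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(batch_size)

seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
    seq_group_metadata_list.append(
        SequenceGroupMetadata(
            request_id=f"test_{i}",  # illustrative id
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},  # illustrative tokens
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        ))
    # Track each prompt length alongside its metadata; _prepare_sample
    # takes both, replacing the old worker._prepare_inputs call.
    prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                 prompt_lens)
sampler_output = sampler(embedding=None,
                         hidden_states=input_tensor,
                         sampling_metadata=sampling_metadata)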

vllm/worker/model_runner.py

@@ -25,7 +25,10 @@ class ModelRunner:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
-        self.sliding_window = model_config.get_sliding_window()
+        # model_config can be None in tests/samplers/test_sampler.py.
+        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
+        self.sliding_window = (model_config.get_sliding_window()
+                               if model_config is not None else None)
 
         self.model = None
         self.block_size = None  # Set after initial profiling.
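
The guard above is the pattern the FIXME describes: the sampler tests construct ModelRunner(None, None, None), so any attribute derived from model_config must tolerate None. A minimal illustration of the same defensive pattern, with hypothetical names:

class Runner:
    def __init__(self, config):
        # config may be None when tests build a bare Runner, so derive
        # dependent attributes conditionally rather than dereferencing
        # config unconditionally.
        self.window = config.get_window() if config is not None else None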