Files
vllm-dev/tests/test_sequence.py
2025-08-18 02:58:11 -07:00

137 lines
5.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
SequenceData, SequenceOutput)
from .core.utils import create_dummy_prompt
@pytest.fixture
def sample_outputs():
    """Build five single-sample group outputs with output tokens 0..4."""
    group_outputs = []
    for token_id in range(5):
        sample = SequenceOutput(parent_seq_id=0,
                                output_token=token_id,
                                logprobs={})
        group_outputs.append(
            CompletionSequenceGroupOutput(samples=[sample],
                                          prompt_logprobs=None))
    return group_outputs
@pytest.fixture
def sampler_output(sample_outputs):
    """Wrap the ``sample_outputs`` fixture in a ``SamplerOutput``."""
    wrapped = SamplerOutput(outputs=sample_outputs)
    return wrapped
def test_sampler_output_initialization(sampler_output, sample_outputs):
    """Check length and unset sampled tensors on a fresh SamplerOutput."""
    expected_len = len(sample_outputs)
    assert len(sampler_output) == expected_len

    # The fixture passes only `outputs`, so both sampled_* fields are
    # expected to be None.
    assert sampler_output.sampled_token_probs is None
    assert sampler_output.sampled_token_ids is None
def test_sampler_output_getitem(sampler_output, sample_outputs):
    """Indexing a SamplerOutput should match indexing the source list."""
    index = 2
    assert sampler_output[index] == sample_outputs[index]
def test_sampler_output_setitem(sampler_output):
    """Item assignment should replace the entry at the given index."""
    index = 2
    replacement_sample = SequenceOutput(parent_seq_id=0,
                                        output_token=99,
                                        logprobs={})
    replacement = CompletionSequenceGroupOutput(
        samples=[replacement_sample], prompt_logprobs=None)

    sampler_output[index] = replacement

    assert sampler_output[index] == replacement
def test_sampler_output_len(sampler_output, sample_outputs):
    """``len()`` of a SamplerOutput should equal the number of outputs."""
    actual = len(sampler_output)
    expected = len(sample_outputs)
    assert actual == expected
def test_sampler_output_eq(sample_outputs):
    """Equality should depend on the wrapped outputs, not list identity."""
    baseline = SamplerOutput(outputs=sample_outputs)
    # A shallow copy holds the same elements in a distinct list object.
    same_contents = SamplerOutput(outputs=sample_outputs.copy())
    # Dropping the last element must make the two compare unequal.
    one_fewer = SamplerOutput(outputs=sample_outputs[:-1])

    assert baseline == same_contents
    assert baseline != one_fewer
def test_sequence_data_prefill():
    """Track computed/uncomputed token counts through prefill and reset."""
    seq_data = SequenceData.from_seqs([1, 2, 3, 4])

    def check_counts(uncomputed, computed):
        # Same assertion order as the individual checks: uncomputed first.
        assert seq_data.get_num_uncomputed_tokens() == uncomputed
        assert seq_data.get_num_computed_tokens() == computed

    # Nothing has been computed yet for the four prompt tokens.
    check_counts(uncomputed=4, computed=0)

    # advance by 2
    seq_data.update_num_computed_tokens(2)
    check_counts(uncomputed=2, computed=2)

    # advance by 1
    seq_data.update_num_computed_tokens(1)
    check_counts(uncomputed=1, computed=3)

    # append tokens and reset, simulating recompute
    seq_data.append_token_id(1, logprob=0.0)
    seq_data.reset_state_for_recompute()
    check_counts(uncomputed=5, computed=0)
def test_sequence_group_stage():
    """Check when a sequence group leaves prefill, before and after reset."""
    # NOTE(review): the second argument is presumably the prompt length
    # (12 tokens) -- confirm against create_dummy_prompt.
    _, seq_group = create_dummy_prompt("1", 12)
    assert seq_group.is_prefill() is True

    # Each entry: (tokens to mark computed, expected is_prefill() afterwards).
    first_pass = ((6, True), (5, True), (1, False))
    for num_tokens, expect_prefill in first_pass:
        seq_group.update_num_computed_tokens(num_tokens)
        assert seq_group.is_prefill() is expect_prefill

    seqs = seq_group.get_seqs()
    assert len(seqs) == 1
    seqs[0].data.append_token_id(1, logprob=0.0)

    # Reset every sequence; the group should report prefill again.
    for seq in seq_group.get_seqs():
        seq.reset_state_for_recompute()
    assert seq_group.is_prefill() is True

    # One extra token was appended above, so one more token must be
    # computed than in the first pass before prefill ends.
    second_pass = ((5, True), (7, True), (1, False))
    for num_tokens, expect_prefill in second_pass:
        seq_group.update_num_computed_tokens(num_tokens)
        assert seq_group.is_prefill() is expect_prefill
def test_sequence_intermediate_tensors_equal():
    """Exercise IntermediateTensors equality over type, keys and values."""

    class AnotherIntermediateTensors(IntermediateTensors):
        pass

    def int32_zeros(shape):
        # All tensors in this test are int32 zeros; only the shape varies.
        return torch.zeros(shape, dtype=torch.int32)

    # A subclass instance is expected to differ from a base instance even
    # when both wrap an empty dict.
    base_instance = IntermediateTensors({})
    subclass_instance = AnotherIntermediateTensors({})
    assert base_instance != subclass_instance

    # Two empty instances of the same class are equal.
    empty_a = IntermediateTensors({})
    empty_b = IntermediateTensors({})
    assert empty_a == empty_b

    # Same tensor contents under different keys -> unequal.
    keyed_one = IntermediateTensors({"1": int32_zeros([2, 4])})
    keyed_two = IntermediateTensors({"2": int32_zeros([2, 4])})
    assert keyed_one != keyed_two

    # Same key, tensors of different shape -> unequal.
    shape_2x4 = IntermediateTensors({"1": int32_zeros([2, 4])})
    shape_2x5 = IntermediateTensors({"1": int32_zeros([2, 5])})
    assert shape_2x4 != shape_2x5

    # Same key, tensors equal in shape and value -> equal.
    match_a = IntermediateTensors({"1": int32_zeros([2, 4])})
    match_b = IntermediateTensors({"1": int32_zeros([2, 4])})
    assert match_a == match_b