Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
			
		
		
		
[Bugfix][SpecDecode] Apply sampling parameters to target probabilities for consistency in rejection sampling (#10198)
Signed-off-by: jeongin601 <0200angela@gmail.com>
Signed-off-by: jeong_in.bae <jeong_in.bae@navercorp.com>
This commit is contained in:
		| @ -203,7 +203,7 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, | |||||||
| @pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}]) | @pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}]) | ||||||
| @pytest.mark.parametrize("output_len", [64]) | @pytest.mark.parametrize("output_len", [64]) | ||||||
| @pytest.mark.parametrize("batch_size", [1, 32]) | @pytest.mark.parametrize("batch_size", [1, 32]) | ||||||
| @pytest.mark.parametrize("temperature", [0.1, 1.0]) | @pytest.mark.parametrize("temperature", [1.0]) | ||||||
| @pytest.mark.parametrize("seed", [1]) | @pytest.mark.parametrize("seed", [1]) | ||||||
| def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, | def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, | ||||||
|                                     per_test_common_llm_kwargs, |                                     per_test_common_llm_kwargs, | ||||||
|  | |||||||
| @ -90,6 +90,14 @@ def test_create_single_target_seq_group_metadata(k: int): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     assert output.request_id == input_seq_group_metadata.request_id |     assert output.request_id == input_seq_group_metadata.request_id | ||||||
|  |     assert output.sampling_params.repetition_penalty == \ | ||||||
|  |         input_seq_group_metadata.sampling_params.repetition_penalty | ||||||
|  |     assert output.sampling_params.temperature == \ | ||||||
|  |         input_seq_group_metadata.sampling_params.temperature | ||||||
|  |     assert output.sampling_params.top_p == \ | ||||||
|  |         input_seq_group_metadata.sampling_params.top_p | ||||||
|  |     assert output.sampling_params.top_k == \ | ||||||
|  |         input_seq_group_metadata.sampling_params.top_k | ||||||
|     assert len(output.seq_data) == 1 |     assert len(output.seq_data) == 1 | ||||||
|     assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple( |     assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple( | ||||||
|         prompt_tokens) |         prompt_tokens) | ||||||
|  | |||||||
| @ -307,28 +307,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): | |||||||
|         token_ids_to_score = self._get_token_ids_to_score( |         token_ids_to_score = self._get_token_ids_to_score( | ||||||
|             proposal_token_ids[batch_index]) |             proposal_token_ids[batch_index]) | ||||||
|  |  | ||||||
|         # Use simpler sampling parameters apart from for final token |  | ||||||
|         # (in particular don't do seeded sampling) since those sampled tokens |  | ||||||
|         # aren't used. |  | ||||||
|         # We don't replace the sampling_params in the greedy case because |  | ||||||
|         # this also controls whether the probs get modified in the sampler |  | ||||||
|         # (see use of _modify_greedy_probs_inplace there). |  | ||||||
|         sampling_params = input_seq_group_metadata.sampling_params |         sampling_params = input_seq_group_metadata.sampling_params | ||||||
|         non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \ |  | ||||||
|             if sampling_params.temperature else sampling_params |  | ||||||
|  |  | ||||||
|         target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] |         target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] | ||||||
|         last_index = len(token_ids_to_score) - 1 |  | ||||||
|         for i, token_ids in enumerate(token_ids_to_score): |         for i, token_ids in enumerate(token_ids_to_score): | ||||||
|             target_sampling_params = sampling_params if i == last_index \ |  | ||||||
|                 else non_bonus_sampling_params |  | ||||||
|             target_seq_group_metadata_list.append( |             target_seq_group_metadata_list.append( | ||||||
|                 self._create_single_target_seq_group_metadata( |                 self._create_single_target_seq_group_metadata( | ||||||
|                     input_seq_group_metadata, |                     input_seq_group_metadata, | ||||||
|                     input_seq_id, |                     input_seq_id, | ||||||
|                     next(target_seq_ids_iter), |                     next(target_seq_ids_iter), | ||||||
|                     token_ids, |                     token_ids, | ||||||
|                     sampling_params=target_sampling_params, |                     sampling_params=sampling_params, | ||||||
|                 )) |                 )) | ||||||
|  |  | ||||||
|         return target_seq_group_metadata_list |         return target_seq_group_metadata_list | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user