diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index c12f2fd594..24a51288cb 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -244,7 +244,9 @@ def test_schedule_partial_requests():
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[0] for _ in range(len(requests))],
+        # Only the first request has a sampled token id because
+        # the remaining requests are still being prefilled.
+        sampled_token_ids=[[0], [], []],
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -266,7 +268,7 @@ def test_schedule_partial_requests():
 
 
 @pytest.mark.parametrize("enable_prefix_caching", [True, False])
-def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
+def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     """Test scheduling behavior with concurrent partial requests.
 
     This test verifies that: there are multiple long prefill requests in the
@@ -304,7 +306,7 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[[0] for _ in range(len(requests))],
+        sampled_token_ids=[[] for _ in range(len(requests))],
         spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
@@ -325,6 +327,14 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
     # Schedule the third step. All three requests are running.
     # First and second requests are in the decode stage.
     # All the remaining tokens in the third request are processed.
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
     scheduler.update_from_output(output1, model_runner_output)
     output2 = scheduler.schedule()
     assert len(scheduler.running) == 3
diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index ca5ff8fa84..3f3109c148 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -231,8 +231,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     Test that the engine can handle multiple concurrent batches.
     """
 
-    def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
+    def make_request_with_max_tokens(req_id: int,
+                                     max_tokens: int) -> EngineCoreRequest:
         request = make_request()
+        request.request_id = req_id
         request.sampling_params.max_tokens = max_tokens
         return request
 
@@ -279,6 +281,8 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
         # Avoid all requests being scheduled once.
         enable_prefix_caching=False,
         max_num_batched_tokens=10,
+        # Reduce startup time.
+        enforce_eager=True,
     )
     vllm_config = engine_args.create_engine_config()
     engine_core = EngineCore(vllm_config=vllm_config,
@@ -286,13 +290,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
                              executor_class=DummyExecutor)
     assert engine_core.batch_queue is not None
 
-    # Add two requests in a row.
-    req = make_request_with_max_tokens(5)
-    engine_core.add_request(req)
-    req = make_request_with_max_tokens(5)
-    engine_core.add_request(req)
+    # Add two requests in a row. Each request has 12 prompt tokens.
+    req0 = make_request_with_max_tokens(0, 5)
+    engine_core.add_request(req0)
+    req1 = make_request_with_max_tokens(1, 5)
+    engine_core.add_request(req1)
 
-    # First saturate the batch queue.
+    # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue() is None
     assert engine_core.batch_queue.qsize() == 1
     assert engine_core.step_with_batch_queue() is None
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 850687423d..ba7c691306 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -153,9 +153,9 @@ class Scheduler(SchedulerInterface):
 
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
-            if self.scheduler_config.long_prefill_token_threshold > 0:
-                num_new_tokens = min(
-                    num_new_tokens,
+            if (0 < self.scheduler_config.long_prefill_token_threshold <
+                    num_new_tokens):
+                num_new_tokens = (
                     self.scheduler_config.long_prefill_token_threshold)
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
@@ -303,9 +303,9 @@ class Scheduler(SchedulerInterface):
                     num_computed_tokens -= self.block_size
                     num_new_tokens = self.block_size
                     computed_blocks.pop()
-                if self.scheduler_config.long_prefill_token_threshold > 0:
-                    num_new_tokens = min(
-                        num_new_tokens,
+                if (0 < self.scheduler_config.long_prefill_token_threshold <
+                        num_new_tokens):
+                    num_new_tokens = (
                         self.scheduler_config.long_prefill_token_threshold)
                 num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
@@ -433,6 +433,18 @@ class Scheduler(SchedulerInterface):
             grammar_bitmask=grammar_bitmask,
         )
 
+        # Advance the number of computed tokens for the request AFTER
+        # the request is scheduled.
+        # 1. The scheduler_output of the current step has to include the
+        #    original number of scheduled tokens to determine input IDs.
+        # 2. Advancing the number of computed tokens here allows us to
+        #    schedule the prefill request again immediately in the next
+        #    scheduling step.
+        # 3. If some tokens (e.g. spec tokens) are rejected later, the number
+        #    of computed tokens will be adjusted in update_from_output.
+        for req_id, num_scheduled_token in num_scheduled_tokens.items():
+            self.requests[req_id].num_computed_tokens += num_scheduled_token
+
         self.finished_req_ids = set()
         return scheduler_output
 
@@ -561,28 +573,19 @@ class Scheduler(SchedulerInterface):
 
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index]
-            if req_id not in scheduler_output.scheduled_spec_decode_tokens:
-                # When the request's num_computed_tokens catches up
-                # its num_tokens, the request generates output tokens.
-                # Otherwise, we ignore the sampler output for the request.
-                request.num_computed_tokens += num_tokens_scheduled
-                assert request.num_computed_tokens <= request.num_tokens
-            else:
-                # num_computed_tokens_step represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections.
-                # It is calculated as:
-                # num_computed_tokens_step = num_scheduled_tokens -
-                #   num_tokens_rejected,
-                # where num_tokens_rejected is given by:
-                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
-                scheduled_spec_token_ids = (
-                    scheduler_output.scheduled_spec_decode_tokens[req_id])
-                num_computed_tokens_step = num_scheduled_tokens[req_id] - (
-                    len(scheduled_spec_token_ids) + 1 -
-                    len(generated_token_ids))
-                request.num_computed_tokens += num_computed_tokens_step
+            scheduled_spec_token_ids = (
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id))
+            if scheduled_spec_token_ids:
+                # num_computed_tokens represents the number of tokens
+                # processed in the current step, considering scheduled
+                # tokens and rejections. If some tokens are rejected,
+                # num_computed_tokens is decreased by the number of rejected
+                # tokens, which is given by:
+                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
+                num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
+                                       len(generated_token_ids))
+                request.num_computed_tokens -= num_tokens_rejected
 
             cached_encoder_input_ids = (
                 self.encoder_cache_manager.get_cached_input_ids(request))
 
@@ -605,24 +608,26 @@ class Scheduler(SchedulerInterface):
 
             new_logprobs = None
             new_token_ids: list[int] = []
-            if request.num_computed_tokens >= request.num_tokens:
-                for output_token_id in generated_token_ids:
-                    request.append_output_token_ids(output_token_id)
-                    new_token_ids.append(output_token_id)
+            # Append generated tokens and check for stop. Note that if
+            # a request is still being prefilled, we expect the model runner
+            # to return empty token ids for the request.
+            for output_token_id in generated_token_ids:
+                request.append_output_token_ids(output_token_id)
+                new_token_ids.append(output_token_id)
 
-                    # Check for stop and update request state.
-                    # This must be called before we make the EngineCoreOutput.
-                    stopped = check_stop(request, self.max_model_len)
-                    if stopped:
-                        self._free_request(request)
-                        break
+                # Check for stop and update request state.
+                # This must be called before we make the EngineCoreOutput.
+                stopped = check_stop(request, self.max_model_len)
+                if stopped:
+                    self._free_request(request)
+                    break
 
-                # Extract sample logprobs if needed.
-                if request.sampling_params.logprobs is not None:
-                    assert logprobs is not None
-                    # NOTE: once we support N tokens per step (spec decode),
-                    # the outer lists can be of length > 1.
-                    new_logprobs = logprobs.slice(req_index, req_index + 1)
+            # Extract sample logprobs if needed.
+            if (request.sampling_params.logprobs is not None
+                    and logprobs is not None):
+                # NOTE: once we support N tokens per step (spec decode),
+                # the outer lists can be of length > 1.
+                new_logprobs = logprobs.slice(req_index, req_index + 1)
 
             if new_token_ids and request.use_structured_output:
                 # NOTE: structured_output_request
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 69bc68174d..e5b8872a2a 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -107,14 +107,33 @@ class RejectionSampler(nn.Module):
     @staticmethod
     def parse_output(
         output_token_ids: torch.Tensor,
+        ignored_req_idxs: list[int],
         vocab_size: int,
     ) -> list[list[int]]:
+        """Parse the output of the rejection sampler.
+
+        Args:
+            output_token_ids: The sampled token IDs in shape
+                [batch_size, max_spec_len + 1]. The rejected tokens are
+                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
+                and will be filtered out in this function.
+            ignored_req_idxs: The indices of the requests that should not be
+                sampled. This is usually because the request is still in the
+                prefill phase.
+            vocab_size: The size of the vocabulary.
+
+        Returns:
+            A list of lists of token IDs.
+ """ output_token_ids_np = output_token_ids.cpu().numpy() # Create mask for valid tokens. valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (output_token_ids_np < vocab_size)) + + ignored_req_idx_set = set(ignored_req_idxs) outputs = [ row[valid_mask[i]].tolist() + if i not in ignored_req_idx_set else [] for i, row in enumerate(output_token_ids_np) ] return outputs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a85009f1a3..bcf7762b44 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1085,8 +1085,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. - for i, generator in self.input_batch.generators.items(): - req_id = self.input_batch.req_ids[i] + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -1094,7 +1094,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Ignore the sampled token for partial prefills. # Rewind the generator state as if the token was not sampled. # This relies on cuda-specific torch-internal impl details - generator.set_offset(generator.get_offset() - 4) + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + # Record the index of the request that should not be sampled, + # so that we could clear the sampled tokens before returning. + discard_sampled_tokens_req_indices.append(i) # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. @@ -1114,10 +1119,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): if max_gen_len == 1: # No spec decode tokens. valid_sampled_token_ids = sampled_token_ids.tolist() + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() else: # Includes spec decode tokens. valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, self.input_batch.vocab_size) + sampled_token_ids, + discard_sampled_tokens_req_indices, + self.input_batch.vocab_size, + ) if not self.use_spec_decode: spec_token_ids = None