Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
Compare commits

3 commits: v0.11.1rc1...woosuk-jf
Author | SHA1 | Date
---|---|---
 | bcf3c8230d |
 | a01af39aa8 |
 | eeb5761cf1 |
```diff
@@ -721,15 +721,29 @@ class Scheduler(SchedulerInterface):
                 # the outer lists can be of length > 1.
                 new_logprobs = logprobs.slice(req_index, req_index + 1)
 
+            jump_tokens = []
             if new_token_ids and request.use_structured_output:
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # check above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
+                assert request.structured_output_request is not None
+                assert request.structured_output_request.grammar is not None
+                request.structured_output_request.grammar.accept_tokens(
                     req_id, new_token_ids)
 
+                if not stopped:
+                    jump_tokens = request.structured_output_request.grammar.jump_forward(
+                        req_id)
+                    for token_id in jump_tokens:
+                        request.append_output_token_ids(token_id)
+                        new_token_ids.append(token_id)
+                        stopped = check_stop(request, self.max_model_len)
+                        if stopped:
+                            break
+                    if jump_tokens:
+                        print(f"jump_tokens: {jump_tokens}")
+
             # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
+            if jump_tokens:
+                request.spec_token_ids.clear()
+            elif spec_token_ids is not None:
                 if request.use_structured_output:
                     metadata = request.structured_output_request
                     assert metadata is not None and metadata.grammar is not None
```
```diff
@@ -170,6 +170,26 @@ class XgrammarGrammar(StructuredOutputGrammar):
             self.num_processed_tokens += 1
         return True
 
+    def jump_forward(
+        self,
+        request_id: str,
+    ) -> list[int]:
+        bitmask = xgr.allocate_token_bitmask(1, self.vocab_size)
+        jump_forward_tokens: list[int] = []
+        while not self.is_terminated():
+            self.fill_bitmask(bitmask, 0)
+            is_single, unique_token_id = xgr.testing._is_single_token_bitmask(
+                bitmask,
+                vocab_size=self.vocab_size,
+                index=0,
+            )
+            if not is_single:
+                break
+
+            self.accept_tokens(request_id, [unique_token_id])
+            jump_forward_tokens.append(unique_token_id)
+        return jump_forward_tokens
+
     def validate_tokens(self, tokens: list[int]) -> list[int]:
         """Checks if the list of tokens are accepted by the FSM in sequence.
         Will not advance the FSM.
```
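The jump-forward loop above asks one question per position: does the current bitmask allow exactly one token? If so, that token can be accepted without running the model; as soon as more than one token is legal, control returns to sampling. A toy illustration of that check on plain Python sets follows; the naive `is_single_token` here is not the real `xgr.testing._is_single_token_bitmask`, which operates on the packed bitmask.

```python
def is_single_token(allowed: set[int]) -> tuple[bool, int]:
    # (True, token_id) when exactly one token is legal, else (False, -1).
    if len(allowed) == 1:
        return True, next(iter(allowed))
    return False, -1


def jump_forward(allowed_sets: list[set[int]]) -> list[int]:
    """Consume positions while exactly one token is legal, then stop."""
    jumped: list[int] = []
    for allowed in allowed_sets:
        single, token_id = is_single_token(allowed)
        if not single:
            break
        jumped.append(token_id)
    return jumped


# E.g. a JSON grammar that has just emitted `{"name"` may force `":` next:
print(jump_forward([{11}, {25}, {3, 9}]))  # [11, 25]
```

In the real backend each iteration also advances the matcher via `accept_tokens`, so the bitmask filled for the next position already reflects the token that was just jumped over.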