Compare commits


3 Commits

SHA1        Message                                         Date
bcf3c8230d  Merge branch 'main' into woosuk-jf              2025-05-04 11:16:07 -07:00
a01af39aa8  Merge branch 'main' into woosuk-jf              2025-05-03 10:42:43 -07:00
eeb5761cf1  Implement Jump-Forward (Fast-Forward) Decoding  2025-05-01 18:08:52 -07:00
            Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2 changed files with 39 additions and 5 deletions
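Background: jump-forward (fast-forward) decoding exploits moments when the structured-output grammar leaves exactly one legal next token. That token can be appended to the output directly instead of being generated by a model forward pass, and the process repeats until the grammar again offers a real choice. Below is a toy, self-contained sketch of the idea; the dict-based grammar and helper are invented for illustration and are not vLLM or xgrammar code:

# Toy illustration of jump-forward decoding (hypothetical, not vLLM code).
# TOY_GRAMMAR maps a grammar state to {allowed token id: next state}.
TOY_GRAMMAR = {
    0: {10: 1},         # only token 10 is legal -> can jump forward
    1: {11: 2},         # only token 11 is legal -> can jump forward
    2: {12: 3, 13: 3},  # two legal tokens -> the model must choose
    3: {},              # terminated
}

def toy_jump_forward(state: int) -> tuple[list[int], int]:
    """Collect tokens for as long as the grammar allows exactly one choice."""
    forced: list[int] = []
    while len(TOY_GRAMMAR[state]) == 1:
        (token_id, state), = TOY_GRAMMAR[state].items()
        forced.append(token_id)
    return forced, state

print(toy_jump_forward(0))  # ([10, 11], 2): two tokens appended without any model step

The actual changes below ask xgrammar the same question through its token bitmask rather than a lookup table.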


@@ -721,15 +721,29 @@ class Scheduler(SchedulerInterface):
             # the outer lists can be of length > 1.
             new_logprobs = logprobs.slice(req_index, req_index + 1)
 
+            jump_tokens = []
             if new_token_ids and request.use_structured_output:
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # check above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
+                assert request.structured_output_request is not None
+                assert request.structured_output_request.grammar is not None
+                request.structured_output_request.grammar.accept_tokens(
                     req_id, new_token_ids)
+                if not stopped:
+                    jump_tokens = request.structured_output_request.grammar.jump_forward(
+                        req_id)
+                    for token_id in jump_tokens:
+                        request.append_output_token_ids(token_id)
+                        new_token_ids.append(token_id)
+                        stopped = check_stop(request, self.max_model_len)
+                        if stopped:
+                            break
+                    if jump_tokens:
+                        print(f"jump_tokens: {jump_tokens}")
 
             # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
+            if jump_tokens:
+                request.spec_token_ids.clear()
+            elif spec_token_ids is not None:
                 if request.use_structured_output:
                     metadata = request.structured_output_request
                     assert metadata is not None and metadata.grammar is not None
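Reading the scheduler hunk: once the newly sampled tokens are accepted by the grammar and the request has not already stopped, the grammar is asked for jump-forward tokens. Each forced token is appended to the request's output and to new_token_ids, with check_stop applied after every append so a forced EOS or length limit still terminates the request. When any jump tokens were emitted, the request's pending speculative draft tokens are cleared, presumably because the forced tokens change the continuation those drafts were proposed against. The print of jump_tokens reads like temporary debug output on this work-in-progress branch.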


@@ -170,6 +170,26 @@ class XgrammarGrammar(StructuredOutputGrammar):
         self.num_processed_tokens += 1
         return True
 
+    def jump_forward(
+        self,
+        request_id: str,
+    ) -> list[int]:
+        bitmask = xgr.allocate_token_bitmask(1, self.vocab_size)
+        jump_forward_tokens: list[int] = []
+        while not self.is_terminated():
+            self.fill_bitmask(bitmask, 0)
+            is_single, unique_token_id = xgr.testing._is_single_token_bitmask(
+                bitmask,
+                vocab_size=self.vocab_size,
+                index=0,
+            )
+            if not is_single:
+                break
+            self.accept_tokens(request_id, [unique_token_id])
+            jump_forward_tokens.append(unique_token_id)
+        return jump_forward_tokens
+
     def validate_tokens(self, tokens: list[int]) -> list[int]:
         """Checks if the list of tokens are accepted by the FSM in sequence.
         Will not advance the FSM.
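The jump_forward loop relies on xgrammar's token bitmask: fill_bitmask marks which vocabulary ids the grammar currently allows, and the testing helper _is_single_token_bitmask reports whether exactly one id is allowed and, if so, which. A rough pure-Python rendering of that check over a plain boolean list follows; it illustrates only the decision being made, not xgrammar's packed-bitmask implementation:

# Hypothetical stand-in for the single-token check (illustration only).
def is_single_token(allowed: list[bool]) -> tuple[bool, int]:
    """Return (True, token_id) if exactly one token is allowed, else (False, -1)."""
    allowed_ids = [i for i, ok in enumerate(allowed) if ok]
    if len(allowed_ids) == 1:
        return True, allowed_ids[0]
    return False, -1

mask = [False] * 16
mask[7] = True                # the grammar permits only token id 7 next
print(is_single_token(mask))  # (True, 7) -> jump_forward would append token 7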