Compare commits

...

1 Commit

Author  SHA1        Message     Date
        ad7bc7ba0c  Next token  2025-05-28 16:03:09 +02:00


@@ -27,6 +27,8 @@ from typing import Deque, Dict, List, Optional, Set, Tuple, Union
 import torch
 import torch.nn as nn
+from tokenizers import Tokenizer
+from tokenizers.decoders import DecodeStream
 from torch.profiler import profile, schedule, tensorboard_trace_handler
 from tqdm import tqdm
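DecodeStream is the tokenizers library's incremental-detokenization helper: it is fed one token id at a time and returns the newly decoded text, holding the running decode state internally so word boundaries and multi-byte characters come out clean without re-decoding the whole sequence. A minimal sketch of the behavior in isolation ("gpt2" is a placeholder model name, not taken from this patch):

# A minimal sketch of incremental detokenization with DecodeStream;
# "gpt2" is a placeholder model, not taken from this patch.
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")
stream = DecodeStream(skip_special_tokens=True)

for token_id in tokenizer.encode("Hello world").ids:
    # step() returns the newly decoded chunk, or None when this id does
    # not yet complete a printable piece of text.
    chunk = stream.step(tokenizer, token_id)
    if chunk is not None:
        print(chunk, end="")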
@@ -72,6 +74,7 @@ class GenerationOutput:
     error: Optional[str] = None
     status: RequestStatus = RequestStatus.PENDING
     created_time: float = field(default_factory=time.time)
+    next_token: Optional[str] = None  # newly decoded text chunk, set in streaming mode

 @dataclass
@@ -96,6 +99,7 @@ class RequestState:
     eos_token_id: int = -1
     created_time: float = field(default_factory=time.time)
     error: Optional[str] = None
+    next_token: Optional[str] = None  # most recent decoded text chunk

     def current_len(self) -> int:
         """Get the current length of the sequence (prompt + generated tokens)."""
@@ -139,6 +143,7 @@ class RequestState:
             generated_tokens=self.static_outputs,
             logprobs=[],
             error=self.error,
+            next_token=self.next_token,
         )
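Together, the two dataclass changes thread the freshly decoded text from the per-request state into the output objects that consumers see. A condensed sketch of that flow, with both classes abbreviated to the fields this commit touches:

# Condensed sketch of how next_token flows from RequestState into
# GenerationOutput; field lists are abbreviated, not the full classes.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class GenerationOutput:
    generated_tokens: List[int]
    next_token: Optional[str] = None  # newly decoded text chunk

@dataclass
class RequestState:
    static_outputs: List[int]
    next_token: Optional[str] = None

    def to_generation_output(self) -> GenerationOutput:
        return GenerationOutput(
            generated_tokens=self.static_outputs,
            next_token=self.next_token,
        )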
@@ -751,6 +756,9 @@ class ContinuousBatchProcessor:
         self.setup_static_tensors()
+        self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
+        self.decode_stream = DecodeStream(skip_special_tokens=True)

     @traced(standalone=True)
     def setup_static_tensors(self):
         T = self.max_batch_tokens
@@ -982,7 +990,7 @@ class ContinuousBatchProcessor:
     def _maybe_send_output(self, state: RequestState, token: int):
         """Send output to the queue based on streaming mode and request state."""
         if self.streaming:
-            state.next_token = token
+            state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1])
             self.output_queue.put(state.to_generation_output())
         elif state.status == RequestStatus.FINISHED:
             self.output_queue.put(state.to_generation_output())
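Note that step() receives only the newest id, state.static_outputs[-1]; the stream object carries the context from earlier calls. With streaming enabled, every generated token now yields a GenerationOutput whose next_token holds decoded text rather than a raw id. A hypothetical consumer loop (the queue handle and termination check are assumptions, not shown in this patch):

# Hypothetical consumer of the streaming output queue; the
# RequestStatus/GenerationOutput fields mirror those in the patch.
def stream_text(output_queue) -> str:
    pieces = []
    while True:
        out = output_queue.get()  # a GenerationOutput
        if out.error is not None:
            raise RuntimeError(out.error)
        if out.next_token:  # empty/None until a full chunk is decodable
            print(out.next_token, end="", flush=True)
            pieces.append(out.next_token)
        if out.status == RequestStatus.FINISHED:
            return "".join(pieces)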
@@ -1080,6 +1088,7 @@ class ContinuousBatchingManager:
         self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
         self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
         self.profile = getattr(generation_config, "profile", False)
+        self.decode_stream = DecodeStream(skip_special_tokens=True)

     @traced
     def start(self):