[Misc][V1] Misc code streamlining (#15723)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-03-28 20:59:47 -07:00
Committed by: GitHub
parent 762b424a52
commit 6d531ad7b8

5 changed files with 32 additions and 38 deletions


@@ -207,10 +207,7 @@ class StatelessProcessGroup:
     def barrier(self):
         """A barrier to synchronize all ranks."""
         for i in range(self.world_size):
-            if i == self.rank:
-                self.broadcast_obj(None, src=self.rank)
-            else:
-                self.broadcast_obj(None, src=i)
+            self.broadcast_obj(None, src=i)
 
     @staticmethod
     def create(
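The simplified barrier works because broadcast_obj() already branches internally on whether src equals the caller's rank, so the caller-side if/else was redundant. Below is a minimal, self-contained sketch of that behavior, assuming a hypothetical queue-backed ToyGroup rather than vLLM's StatelessProcessGroup:

import queue
import threading

class ToyGroup:
    """Stand-in for a stateless process group; one instance per 'rank'."""

    def __init__(self, rank: int, world_size: int,
                 mailboxes: list[list[queue.Queue]]):
        self.rank = rank
        self.world_size = world_size
        # mailboxes[src][dst] carries objects broadcast by `src` to `dst`.
        self.mailboxes = mailboxes

    def broadcast_obj(self, obj, src: int):
        if src == self.rank:
            # This rank is the source: send to every other rank.
            for dst in range(self.world_size):
                if dst != self.rank:
                    self.mailboxes[src][dst].put(obj)
            return obj
        # Otherwise this rank receives from the source.
        return self.mailboxes[src][self.rank].get()

    def barrier(self):
        # Same shape as the simplified vLLM code: no per-rank branch needed,
        # since broadcast_obj() decides to send or receive based on `src`.
        for i in range(self.world_size):
            self.broadcast_obj(None, src=i)

if __name__ == "__main__":
    world_size = 4
    mailboxes = [[queue.Queue() for _ in range(world_size)]
                 for _ in range(world_size)]
    groups = [ToyGroup(r, world_size, mailboxes) for r in range(world_size)]
    threads = [threading.Thread(target=g.barrier) for g in groups]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print("all ranks passed the barrier")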


@@ -269,29 +269,26 @@ class Scheduler(SchedulerInterface):
                 request = self.waiting[0]
 
-                # Waiting request skipping logic
-                is_skipped = False
-
                 # Skip request if the structured output request is still waiting
-                # for FSM.
-                if (not is_skipped
-                        and request.status == RequestStatus.WAITING_FOR_FSM):
+                # for FSM compilation.
+                if request.status == RequestStatus.WAITING_FOR_FSM:
                     structured_output_req = request.structured_output_request
-                    is_skipped = (not structured_output_req
-                                  or not structured_output_req.grammar)
-                    if not is_skipped:
+                    if structured_output_req and structured_output_req.grammar:
                         request.status = RequestStatus.WAITING
+                    else:
+                        self.waiting.popleft()
+                        skipped_waiting_requests.appendleft(request)
+                        continue
 
-                # Skip request if max_loras can't be honored.
-                if (not is_skipped and self.lora_config
-                        and request.lora_request):
-                    req_lora_id = request.lora_request.lora_int_id
-                    is_skipped = (len(scheduled_loras)
-                                  == self.lora_config.max_loras
-                                  and (req_lora_id not in scheduled_loras))
-
-                if is_skipped:
-                    skipped_waiting_requests.appendleft(request)
+                # Check that adding the request still respects the max_loras
+                # constraint.
+                if self.lora_config and request.lora_request and (
+                        len(scheduled_loras) == self.lora_config.max_loras
+                        and request.lora_request.lora_int_id
+                        not in scheduled_loras):
+                    # Scheduling would exceed max_loras, skip.
                     self.waiting.popleft()
+                    skipped_waiting_requests.appendleft(request)
                     continue
 
                 # Get already-cached tokens.
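With the is_skipped flag gone, each skip site now pops the request off the waiting deque, stashes it, and continues. A standalone sketch of that deque-based skip-and-revisit pattern follows; the function name, request strings, and can_schedule predicate are made up, and the real scheduler's re-queueing details differ:

from collections import deque

def schedule_step(waiting: deque, can_schedule) -> list:
    scheduled = []
    skipped = deque()
    while waiting:
        request = waiting[0]
        if not can_schedule(request):
            # Same shape as the diff: drop it from waiting, remember it,
            # and move on to the next candidate.
            waiting.popleft()
            skipped.appendleft(request)
            continue
        waiting.popleft()
        scheduled.append(request)
    # Push skipped requests back to the front, preserving their order.
    waiting.extendleft(skipped)
    return scheduled

q = deque(["a", "b", "c", "d"])
print(schedule_step(q, lambda r: r not in {"b", "d"}))  # ['a', 'c']
print(list(q))                                          # ['b', 'd']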
@@ -602,8 +599,9 @@ class Scheduler(SchedulerInterface):
             # OPTIMIZATION: Avoid list(set) if the set is empty.
             if cached_encoder_input_ids:
                 for input_id in list(cached_encoder_input_ids):
-                    start_pos = request.mm_positions[input_id]["offset"]
-                    num_tokens = request.mm_positions[input_id]["length"]
+                    mm_positions = request.mm_positions[input_id]
+                    start_pos = mm_positions["offset"]
+                    num_tokens = mm_positions["length"]
                     if start_pos + num_tokens <= request.num_computed_tokens:
                         # The encoder output is already processed and stored
                         # in the decoder's KV cache.
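The hoisted mm_positions lookup does not change the freeing condition: an encoder input can be released once its placeholder span is fully covered by the tokens the decoder has already computed. A tiny worked example with made-up numbers:

# Placeholder occupies token positions 10..14 (offset 10, length 5).
mm_positions = {"offset": 10, "length": 5}
start_pos = mm_positions["offset"]
num_tokens = mm_positions["length"]
print(start_pos + num_tokens <= 14)  # False: position 14 not computed yet
print(start_pos + num_tokens <= 15)  # True: whole span is in the KV cache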
@@ -616,25 +614,24 @@ class Scheduler(SchedulerInterface):
 
             stopped = False
             new_logprobs = None
-            new_token_ids: list[int] = []
+            new_token_ids = generated_token_ids
 
             # Append generated tokens and check for stop. Note that if
             # a request is still being prefilled, we expect the model runner
             # to return empty token ids for the request.
-            for output_token_id in generated_token_ids:
+            for num_new, output_token_id in enumerate(new_token_ids, 1):
                 request.append_output_token_ids(output_token_id)
-                new_token_ids.append(output_token_id)
 
                 # Check for stop and update request state.
                 # This must be called before we make the EngineCoreOutput.
                 stopped = check_stop(request, self.max_model_len)
                 if stopped:
                     self._free_request(request)
+                    del new_token_ids[num_new:]  # Trim new tokens if needed.
                     break
 
             # Extract sample logprobs if needed.
-            if (request.sampling_params.logprobs is not None
-                    and logprobs is not None):
+            if request.sampling_params.logprobs is not None and logprobs:
                 # NOTE: once we support N tokens per step (spec decode),
                 # the outer lists can be of length > 1.
                 new_logprobs = logprobs.slice(req_index, req_index + 1)
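The rewritten loop reuses generated_token_ids directly and relies on enumerate(..., 1) plus a slice delete to drop any tokens past the stop point. A standalone sketch of just that idiom (take_until_stop and stop_id are illustrative, not vLLM APIs):

def take_until_stop(token_ids: list, stop_id: int) -> list:
    new_token_ids = list(token_ids)
    for num_new, token_id in enumerate(new_token_ids, 1):
        if token_id == stop_id:
            # Keep tokens up to and including the stop token.
            del new_token_ids[num_new:]
            break
    return new_token_ids

print(take_until_stop([5, 7, 2, 9, 4], stop_id=2))  # [5, 7, 2]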
@@ -644,9 +641,7 @@ class Scheduler(SchedulerInterface):
                 # should not be None if use_structured_output, we have
                 # check above, so safe to ignore type warning
                 request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    request.request_id,
-                    new_token_ids,
-                )
+                    req_id, new_token_ids)
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
@@ -665,7 +660,7 @@ class Scheduler(SchedulerInterface):
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
 
-            self.scheduled_req_ids.remove(request.request_id)
+            self.scheduled_req_ids.remove(req_id)
             if not stopped:
                 new_running.append(request)


@@ -416,9 +416,9 @@ class SyncMPClient(MPClient):
 
         def process_outputs_socket():
            shutdown_socket = ctx.socket(zmq.PAIR)
-            shutdown_socket.bind(shutdown_path)
             out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
             try:
+                shutdown_socket.bind(shutdown_path)
                 poller = zmq.Poller()
                 poller.register(shutdown_socket)
                 poller.register(out_socket)
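Moving shutdown_socket.bind() inside the try block keeps every call that can fail under the same cleanup path, so the sockets are released even if setup raises. A rough sketch of that shape, assuming pyzmq; the function name and path arguments are placeholders, not the real SyncMPClient internals:

import zmq

def run_output_loop(ctx: zmq.Context, shutdown_path: str,
                    output_path: str) -> None:
    shutdown_socket = ctx.socket(zmq.PAIR)
    out_socket = ctx.socket(zmq.PULL)
    try:
        shutdown_socket.bind(shutdown_path)  # may raise, e.g. address in use
        out_socket.connect(output_path)
        poller = zmq.Poller()
        poller.register(shutdown_socket, zmq.POLLIN)
        poller.register(out_socket, zmq.POLLIN)
        while True:
            events = dict(poller.poll())
            if shutdown_socket in events:
                break  # shutdown requested
            if out_socket in events:
                out_socket.recv()  # handle one output message
    finally:
        # Runs whether or not bind()/connect() succeeded.
        shutdown_socket.close(linger=0)
        out_socket.close(linger=0)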


@@ -328,7 +328,7 @@ class OutputProcessor:
             # 2) Detokenize the token ids into text and perform stop checks.
             stop_string = req_state.detokenizer.update(
                 new_token_ids, finish_reason == FinishReason.STOP)
-            if stop_string and finish_reason != FinishReason.STOP:
+            if stop_string:
                 finish_reason = FinishReason.STOP
                 stop_reason = stop_string
 
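The simplified condition means a matched stop string always wins: the finish reason becomes STOP and the matched string is surfaced as stop_reason. A toy illustration with a stand-in FinishReason enum, not the real OutputProcessor:

from enum import Enum

class FinishReason(Enum):
    STOP = "stop"
    LENGTH = "length"

def resolve_finish(engine_reason, stop_string):
    finish_reason, stop_reason = engine_reason, None
    if stop_string:
        finish_reason = FinishReason.STOP
        stop_reason = stop_string
    return finish_reason, stop_reason

print(resolve_finish(FinishReason.LENGTH, "###"))  # stop string forces STOP, '###'
print(resolve_finish(FinishReason.LENGTH, None))   # unchanged: LENGTH, None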


@@ -93,9 +93,11 @@ class Request:
         token_ids: Union[int, list[int]],
     ) -> None:
         if isinstance(token_ids, int):
-            token_ids = [token_ids]
-        self._output_token_ids.extend(token_ids)
-        self._all_token_ids.extend(token_ids)
+            self._output_token_ids.append(token_ids)
+            self._all_token_ids.append(token_ids)
+        else:
+            self._output_token_ids.extend(token_ids)
+            self._all_token_ids.extend(token_ids)
 
     @property
     def num_tokens(self) -> int:
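The new int fast path avoids building a temporary one-element list for the common single-token case. A self-contained sketch using a stripped-down ToyRequest (the real Request tracks more state than this):

from typing import Union

class ToyRequest:
    def __init__(self, prompt_token_ids: list):
        self._all_token_ids = list(prompt_token_ids)
        self._output_token_ids = []

    def append_output_token_ids(self,
                                token_ids: Union[int, list]) -> None:
        if isinstance(token_ids, int):
            # Single token: plain append, no temporary [token_ids] list.
            self._output_token_ids.append(token_ids)
            self._all_token_ids.append(token_ids)
        else:
            self._output_token_ids.extend(token_ids)
            self._all_token_ids.extend(token_ids)

req = ToyRequest([1, 2, 3])
req.append_output_token_ids(7)       # int fast path
req.append_output_token_ids([8, 9])  # list path
print(req._output_token_ids)  # [7, 8, 9]
print(req._all_token_ids)     # [1, 2, 3, 7, 8, 9]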