mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Misc] Misc code simplifications (#26450)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -1474,7 +1474,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
affected_req_ids.add(request.request_id)
|
||||
|
||||
return (affected_req_ids, total_affected_tokens)
|
||||
return affected_req_ids, total_affected_tokens
|
||||
|
||||
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
|
||||
total_requests_to_reschedule = 0
|
||||
|
@ -59,8 +59,7 @@ def check_stop(
|
||||
sampling_params = request.sampling_params
|
||||
assert sampling_params is not None
|
||||
|
||||
min_tokens = sampling_params.min_tokens
|
||||
if request.num_output_tokens < min_tokens:
|
||||
if request.num_output_tokens < sampling_params.min_tokens:
|
||||
return False
|
||||
|
||||
last_token_id = request.output_token_ids[-1]
|
||||
|
@ -147,22 +147,20 @@ class RejectionSampler(nn.Module):
|
||||
sampling_metadata: SamplingMetadata,
|
||||
metadata: SpecDecodeMetadata,
|
||||
) -> torch.Tensor:
|
||||
has_penalties = not sampling_metadata.no_penalties
|
||||
any_penalties_or_bad_words = (
|
||||
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
|
||||
sampling_metadata.bad_words_token_ids or has_penalties
|
||||
)
|
||||
|
||||
output_token_ids = sampling_metadata.output_token_ids
|
||||
if any_penalties_or_bad_words:
|
||||
output_token_ids = self._combine_outputs_with_spec_tokens(
|
||||
sampling_metadata.output_token_ids,
|
||||
output_token_ids,
|
||||
sampling_metadata.spec_token_ids,
|
||||
)
|
||||
|
||||
# Calculate indices of target logits.
|
||||
if (
|
||||
sampling_metadata.allowed_token_ids_mask is not None
|
||||
or not sampling_metadata.no_penalties
|
||||
):
|
||||
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
|
||||
num_requests = len(sampling_metadata.output_token_ids)
|
||||
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
|
||||
original_indices = torch.arange(num_requests, device="cpu")
|
||||
@ -180,18 +178,15 @@ class RejectionSampler(nn.Module):
|
||||
logits.masked_fill_(token_mask, float("-inf"))
|
||||
|
||||
# Apply bad words exclusion.
|
||||
if sampling_metadata.bad_words_token_ids:
|
||||
if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
|
||||
apply_bad_words_with_drafts(
|
||||
logits,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
output_token_ids,
|
||||
metadata.num_draft_tokens,
|
||||
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
metadata: SpecDecodeMetadata,
|
||||
@ -218,8 +213,8 @@ class RejectionSampler(nn.Module):
|
||||
)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _combine_outputs_with_spec_tokens(
|
||||
self,
|
||||
output_token_ids: list[list[int]],
|
||||
spec_token_ids: Optional[list[list[int]]] = None,
|
||||
) -> list[list[int]]:
|
||||
|
@ -120,8 +120,8 @@ class Sampler(nn.Module):
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
@staticmethod
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
all_random: bool,
|
||||
@ -132,7 +132,8 @@ class Sampler(nn.Module):
|
||||
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
@staticmethod
|
||||
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
|
||||
def sample(
|
||||
@ -191,11 +192,12 @@ class Sampler(nn.Module):
|
||||
)
|
||||
return sampled, processed_logprobs
|
||||
|
||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
@staticmethod
|
||||
def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
@staticmethod
|
||||
def gather_logprobs(
|
||||
self,
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: torch.Tensor,
|
||||
@ -238,8 +240,8 @@ class Sampler(nn.Module):
|
||||
|
||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||
|
||||
@staticmethod
|
||||
def _combine_outputs_with_spec_tokens(
|
||||
self,
|
||||
output_token_ids: list[list[int]],
|
||||
spec_token_ids: Optional[list[list[int]]] = None,
|
||||
) -> list[list[int]]:
|
||||
@ -257,8 +259,9 @@ class Sampler(nn.Module):
|
||||
sampling_metadata: SamplingMetadata,
|
||||
predict_bonus_token: bool,
|
||||
) -> torch.Tensor:
|
||||
bad_words_token_ids = sampling_metadata.bad_words_token_ids
|
||||
any_penalties_or_bad_words = (
|
||||
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
|
||||
bool(bad_words_token_ids) or not sampling_metadata.no_penalties
|
||||
)
|
||||
|
||||
output_token_ids = sampling_metadata.output_token_ids
|
||||
@ -266,7 +269,7 @@ class Sampler(nn.Module):
|
||||
# Combine base outputs with spec tokens when speculative decoding
|
||||
# is enabled.
|
||||
output_token_ids = self._combine_outputs_with_spec_tokens(
|
||||
sampling_metadata.output_token_ids,
|
||||
output_token_ids,
|
||||
sampling_metadata.spec_token_ids,
|
||||
)
|
||||
|
||||
@ -275,14 +278,8 @@ class Sampler(nn.Module):
|
||||
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
|
||||
|
||||
# Apply bad words exclusion.
|
||||
if sampling_metadata.bad_words_token_ids:
|
||||
apply_bad_words(
|
||||
logits,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
output_token_ids
|
||||
if output_token_ids is not None
|
||||
else sampling_metadata.output_token_ids,
|
||||
)
|
||||
if bad_words_token_ids:
|
||||
apply_bad_words(logits, bad_words_token_ids, output_token_ids)
|
||||
|
||||
# Apply logits processors which can impact greedy sampling.
|
||||
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
|
||||
@ -292,22 +289,21 @@ class Sampler(nn.Module):
|
||||
logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
output_token_ids: Optional[list[list[int]]] = None,
|
||||
output_token_ids: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
if not sampling_metadata.no_penalties:
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
logits = apply_all_penalties(
|
||||
logits,
|
||||
sampling_metadata.prompt_token_ids,
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.repetition_penalties,
|
||||
output_token_ids
|
||||
if output_token_ids is not None
|
||||
else sampling_metadata.output_token_ids,
|
||||
)
|
||||
return logits
|
||||
if sampling_metadata.no_penalties:
|
||||
return logits
|
||||
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
return apply_all_penalties(
|
||||
logits,
|
||||
sampling_metadata.prompt_token_ids,
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.repetition_penalties,
|
||||
output_token_ids,
|
||||
)
|
||||
|
@ -62,10 +62,9 @@ class CachedRequestState:
|
||||
"provided via prompt_embeds, and its ID is unknown."
|
||||
)
|
||||
return self.prompt_token_ids[idx]
|
||||
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
|
||||
if idx - self.num_prompt_tokens < len(self.output_token_ids):
|
||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||
else:
|
||||
return -1
|
||||
return -1
|
||||
|
||||
|
||||
class InputBatch:
|
||||
@ -770,14 +769,13 @@ class InputBatch:
|
||||
not self.no_penalties
|
||||
or self.logits_processing_needs_token_ids[:num_reqs].any()
|
||||
)
|
||||
if needs_prompt_token_ids:
|
||||
# The prompt tokens are used only for applying penalties or
|
||||
# step pooling during the sampling/pooling process.
|
||||
# Hence copy these tensors only when there are requests which
|
||||
# need penalties/step_pooler to be applied.
|
||||
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
||||
else:
|
||||
prompt_token_ids = None
|
||||
# The prompt tokens are used only for applying penalties or
|
||||
# step pooling during the sampling/pooling process.
|
||||
# Hence copy these tensors only when there are requests which
|
||||
# need penalties/step_pooler to be applied.
|
||||
prompt_token_ids = (
|
||||
self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
|
||||
)
|
||||
|
||||
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
if not self.no_allowed_token_ids:
|
||||
|
@ -1996,7 +1996,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Should be called after attention metadata creation. This just pads
|
||||
# the second ubatch slice out to the total number of tokens
|
||||
# (num_tokens + padding)
|
||||
def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
|
||||
@staticmethod
|
||||
def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
|
||||
padded_second_ubatch_slice = slice(
|
||||
ubatch_slices[1].token_slice.start, num_total_tokens
|
||||
)
|
||||
@ -2085,12 +2086,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dict[str, Any],
|
||||
]:
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
is_first_rank = get_pp_group().is_first_rank
|
||||
|
||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||
# modal outputs after that to ensure the correct order
|
||||
if (
|
||||
self.supports_mm_inputs
|
||||
and get_pp_group().is_first_rank
|
||||
and is_first_rank
|
||||
and not self.model_config.is_encoder_decoder
|
||||
):
|
||||
# Run the multimodal encoder if any.
|
||||
@ -2115,7 +2117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
**self._init_model_kwargs(num_scheduled_tokens),
|
||||
**self._extract_mm_kwargs(scheduler_output),
|
||||
}
|
||||
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
|
||||
elif self.enable_prompt_embeds and is_first_rank:
|
||||
# Get the input embeddings for the tokens that are not input embeds,
|
||||
# then put them into the appropriate positions.
|
||||
# TODO(qthequartermasterman): Since even when prompt embeds are
|
||||
@ -2155,7 +2157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
else:
|
||||
positions = self.positions.gpu[:num_input_tokens]
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if is_first_rank:
|
||||
intermediate_tensors = None
|
||||
else:
|
||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||
@ -2186,38 +2188,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
if spec_decode_metadata is None:
|
||||
sampler_output = self.sampler(
|
||||
return self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
else:
|
||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
|
||||
assert logits is not None
|
||||
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
||||
sampler_output = self.sampler(
|
||||
logits=bonus_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
predict_bonus_token=True,
|
||||
)
|
||||
bonus_token_ids = sampler_output.sampled_token_ids
|
||||
|
||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||||
# separate storage from the original `logits` tensor. Therefore,
|
||||
# it is safe to update `target_logits` in place.
|
||||
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||||
output_token_ids = self.rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
None, # draft_probs
|
||||
target_logits,
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
sampler_output.sampled_token_ids = output_token_ids
|
||||
self._update_states_after_model_execute(output_token_ids)
|
||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
|
||||
assert logits is not None
|
||||
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
||||
sampler_output = self.sampler(
|
||||
logits=bonus_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
predict_bonus_token=True,
|
||||
)
|
||||
bonus_token_ids = sampler_output.sampled_token_ids
|
||||
|
||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||||
# separate storage from the original `logits` tensor. Therefore,
|
||||
# it is safe to update `target_logits` in place.
|
||||
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||||
output_token_ids = self.rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
None, # draft_probs
|
||||
target_logits,
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
sampler_output.sampled_token_ids = output_token_ids
|
||||
self._update_states_after_model_execute(output_token_ids)
|
||||
return sampler_output
|
||||
|
||||
def _bookkeeping_sync(
|
||||
@ -3741,7 +3742,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
decode_cudagraph_batch_sizes = [
|
||||
x
|
||||
for x in self.cudagraph_batch_sizes
|
||||
if x <= max_num_tokens and x >= self.uniform_decode_query_len
|
||||
if max_num_tokens >= x >= self.uniform_decode_query_len
|
||||
]
|
||||
compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
|
||||
self._capture_cudagraphs(
|
||||
|
Reference in New Issue
Block a user