[Misc] Misc code simplifications (#26450)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-10-09 02:10:06 -07:00
committed by GitHub
parent a83ff278d6
commit ddcbc2f334
6 changed files with 78 additions and 89 deletions

View File

@ -1474,7 +1474,7 @@ class Scheduler(SchedulerInterface):
affected_req_ids.add(request.request_id)
return (affected_req_ids, total_affected_tokens)
return affected_req_ids, total_affected_tokens
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
total_requests_to_reschedule = 0

View File

@ -59,8 +59,7 @@ def check_stop(
sampling_params = request.sampling_params
assert sampling_params is not None
min_tokens = sampling_params.min_tokens
if request.num_output_tokens < min_tokens:
if request.num_output_tokens < sampling_params.min_tokens:
return False
last_token_id = request.output_token_ids[-1]

View File

@ -147,22 +147,20 @@ class RejectionSampler(nn.Module):
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
) -> torch.Tensor:
has_penalties = not sampling_metadata.no_penalties
any_penalties_or_bad_words = (
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
sampling_metadata.bad_words_token_ids or has_penalties
)
output_token_ids = sampling_metadata.output_token_ids
if any_penalties_or_bad_words:
output_token_ids = self._combine_outputs_with_spec_tokens(
sampling_metadata.output_token_ids,
output_token_ids,
sampling_metadata.spec_token_ids,
)
# Calculate indices of target logits.
if (
sampling_metadata.allowed_token_ids_mask is not None
or not sampling_metadata.no_penalties
):
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
num_requests = len(sampling_metadata.output_token_ids)
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
original_indices = torch.arange(num_requests, device="cpu")
@ -180,18 +178,15 @@ class RejectionSampler(nn.Module):
logits.masked_fill_(token_mask, float("-inf"))
# Apply bad words exclusion.
if sampling_metadata.bad_words_token_ids:
if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
apply_bad_words_with_drafts(
logits,
sampling_metadata.bad_words_token_ids,
output_token_ids,
metadata.num_draft_tokens,
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
)
return logits
@staticmethod
def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
@ -218,8 +213,8 @@ class RejectionSampler(nn.Module):
)
return logits
@staticmethod
def _combine_outputs_with_spec_tokens(
self,
output_token_ids: list[list[int]],
spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]:

View File

@ -120,8 +120,8 @@ class Sampler(nn.Module):
)
return sampler_output
@staticmethod
def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
all_random: bool,
@ -132,7 +132,8 @@ class Sampler(nn.Module):
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
return logits.div_(temp.unsqueeze(dim=1))
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@staticmethod
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
return logits.argmax(dim=-1).view(-1)
def sample(
@ -191,11 +192,12 @@ class Sampler(nn.Module):
)
return sampled, processed_logprobs
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
@staticmethod
def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)
@staticmethod
def gather_logprobs(
self,
logprobs: torch.Tensor,
num_logprobs: int,
token_ids: torch.Tensor,
@ -238,8 +240,8 @@ class Sampler(nn.Module):
return LogprobsTensors(indices, logprobs, token_ranks)
@staticmethod
def _combine_outputs_with_spec_tokens(
self,
output_token_ids: list[list[int]],
spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]:
@ -257,8 +259,9 @@ class Sampler(nn.Module):
sampling_metadata: SamplingMetadata,
predict_bonus_token: bool,
) -> torch.Tensor:
bad_words_token_ids = sampling_metadata.bad_words_token_ids
any_penalties_or_bad_words = (
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
bool(bad_words_token_ids) or not sampling_metadata.no_penalties
)
output_token_ids = sampling_metadata.output_token_ids
@ -266,7 +269,7 @@ class Sampler(nn.Module):
# Combine base outputs with spec tokens when speculative decoding
# is enabled.
output_token_ids = self._combine_outputs_with_spec_tokens(
sampling_metadata.output_token_ids,
output_token_ids,
sampling_metadata.spec_token_ids,
)
@ -275,14 +278,8 @@ class Sampler(nn.Module):
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
# Apply bad words exclusion.
if sampling_metadata.bad_words_token_ids:
apply_bad_words(
logits,
sampling_metadata.bad_words_token_ids,
output_token_ids
if output_token_ids is not None
else sampling_metadata.output_token_ids,
)
if bad_words_token_ids:
apply_bad_words(logits, bad_words_token_ids, output_token_ids)
# Apply logits processors which can impact greedy sampling.
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
@ -292,22 +289,21 @@ class Sampler(nn.Module):
logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
return logits
@staticmethod
def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
output_token_ids: Optional[list[list[int]]] = None,
output_token_ids: list[list[int]],
) -> torch.Tensor:
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
output_token_ids
if output_token_ids is not None
else sampling_metadata.output_token_ids,
)
return logits
if sampling_metadata.no_penalties:
return logits
assert sampling_metadata.prompt_token_ids is not None
return apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
output_token_ids,
)

View File

@ -62,10 +62,9 @@ class CachedRequestState:
"provided via prompt_embeds, and its ID is unknown."
)
return self.prompt_token_ids[idx]
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
if idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
else:
return -1
return -1
class InputBatch:
@ -770,14 +769,13 @@ class InputBatch:
not self.no_penalties
or self.logits_processing_needs_token_ids[:num_reqs].any()
)
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = self._make_prompt_token_ids_tensor()
else:
prompt_token_ids = None
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = (
self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
)
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:

View File

@ -1996,7 +1996,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Should be called after attention metadata creation. This just pads
# the second ubatch slice out to the total number of tokens
# (num_tokens + padding)
def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
@staticmethod
def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
padded_second_ubatch_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
@ -2085,12 +2086,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dict[str, Any],
]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
is_first_rank = get_pp_group().is_first_rank
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if (
self.supports_mm_inputs
and get_pp_group().is_first_rank
and is_first_rank
and not self.model_config.is_encoder_decoder
):
# Run the multimodal encoder if any.
@ -2115,7 +2117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output),
}
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
elif self.enable_prompt_embeds and is_first_rank:
# Get the input embeddings for the tokens that are not input embeds,
# then put them into the appropriate positions.
# TODO(qthequartermasterman): Since even when prompt embeds are
@ -2155,7 +2157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
positions = self.positions.gpu[:num_input_tokens]
if get_pp_group().is_first_rank:
if is_first_rank:
intermediate_tensors = None
else:
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
@ -2186,38 +2188,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
sampler_output = self.sampler(
return self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
predict_bonus_token=True,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
self._update_states_after_model_execute(output_token_ids)
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
predict_bonus_token=True,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
self._update_states_after_model_execute(output_token_ids)
return sampler_output
def _bookkeeping_sync(
@ -3741,7 +3742,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
decode_cudagraph_batch_sizes = [
x
for x in self.cudagraph_batch_sizes
if x <= max_num_tokens and x >= self.uniform_decode_query_len
if max_num_tokens >= x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
self._capture_cudagraphs(