[V1][Spec Decode] Fix greedy temperature detection after sampler refactor (#27077)

Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com>
Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
This commit is contained in:
Pradyun92
2025-10-17 16:27:47 -04:00
committed by GitHub
parent d29483b58a
commit acedc74b1a
5 changed files with 22 additions and 6 deletions

View File

@ -15,7 +15,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__)
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = 0
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 128

View File

@ -30,6 +30,7 @@ class TPUSupportedSamplingMetadata:
top_p: torch.Tensor = None
all_greedy: bool = True
all_random: bool = False
# Whether logprobs are to be gathered in this batch of request. To balance
# out compile time and runtime, a fixed `max_number_logprobs` value is used
@ -110,6 +111,7 @@ class TPUSupportedSamplingMetadata:
xla_device
),
all_greedy=input_batch.all_greedy,
all_random=input_batch.all_random,
# TODO enable more and avoid returning None values
top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),

View File

@ -40,7 +40,11 @@ class Sampler(nn.Module):
self,
logits: torch.Tensor,
temp: torch.Tensor,
all_random: bool = False,
) -> torch.Tensor:
# Avoid division by zero for greedy sampling (temperature ~ 0.0).
if not all_random:
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
return logits.div_(temp.unsqueeze(dim=1))
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@ -56,7 +60,9 @@ class Sampler(nn.Module):
assert sampling_metadata.temperature is not None
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
logits = self.apply_temperature(
logits, sampling_metadata.temperature, sampling_metadata.all_random
)
# Apply min_p.
if sampling_metadata.min_p is not None:

View File

@ -37,6 +37,7 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@ -1140,8 +1141,15 @@ def compute_probs_and_sample_next_token(
next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs
is_greedy = sampling_metadata.temperature == -1
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
assert sampling_metadata.temperature is not None
# Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
# consistent with sampler.py's _SAMPLING_EPS threshold
temperature = sampling_metadata.temperature
# Avoid division by zero if there are greedy requests.
if not sampling_metadata.all_random:
is_greedy = temperature < _SAMPLING_EPS
temperature = torch.where(is_greedy, 1.0, temperature)
logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32)

View File

@ -215,8 +215,8 @@ class InputBatch:
sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
# Should avoid division by zero later when apply_temperature.
self.temperature_cpu[req_index] = 0.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature