mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[V1][Spec Decode] Fix greedy temperature detection after sampler refactor (#27077)
Signed-off-by: Pradyun Ramadorai <pradyunr@amazon.com>
Co-authored-by: Pradyun Ramadorai <pradyunr@amazon.com>
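Summary of the change, as reflected in the hunks below: after the sampler refactor, InputBatch stores a temperature of 0.0 (instead of the old -1.0 sentinel) for greedy requests. This commit updates the rejection sampler's GREEDY_TEMPERATURE constant to match, threads a new all_random flag through TPUSupportedSamplingMetadata and the sampler so the division-by-zero guard can be skipped for fully random batches, and switches spec-decode greedy detection from `temperature == -1` to an epsilon comparison against _SAMPLING_EPS.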
@@ -15,7 +15,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 logger = init_logger(__name__)
 
 PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
-GREEDY_TEMPERATURE: tl.constexpr = -1
+GREEDY_TEMPERATURE: tl.constexpr = 0
 # Maximum number of speculative draft tokens allowed per request in a single
 # step. This value is chosen to be large enough to handle typical use cases.
 MAX_SPEC_LEN = 128
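The sentinel change here tracks the InputBatch hunk at the end of this diff: greedy requests now persist a temperature of 0.0, so the constant the rejection sampler compares against must match what the batch actually stores. A minimal sketch of the detection this enables (illustrative names, not the Triton kernel itself):

    import torch

    GREEDY_TEMPERATURE = 0  # sentinel written by InputBatch for greedy requests

    def greedy_mask(temperature: torch.Tensor) -> torch.Tensor:
        # A row is greedy iff its stored temperature equals the sentinel.
        return temperature == GREEDY_TEMPERATURE

    print(greedy_mask(torch.tensor([0.0, 0.7, 1.0])))  # tensor([ True, False, False])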
@@ -30,6 +30,7 @@ class TPUSupportedSamplingMetadata:
     top_p: torch.Tensor = None
 
     all_greedy: bool = True
+    all_random: bool = False
 
     # Whether logprobs are to be gathered in this batch of request. To balance
     # out compile time and runtime, a fixed `max_number_logprobs` value is used
@@ -110,6 +111,7 @@ class TPUSupportedSamplingMetadata:
                 xla_device
             ),
             all_greedy=input_batch.all_greedy,
+            all_random=input_batch.all_random,
             # TODO enable more and avoid returning None values
             top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),
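The new all_random field complements the existing all_greedy, so downstream code can distinguish the three batch shapes. A hypothetical helper (not vLLM API) showing the partition:

    def batch_kind(all_greedy: bool, all_random: bool) -> str:
        # all_greedy: every request is greedy -> pure argmax, no divide needed
        # all_random: no request is greedy -> divide by temperature unconditionally
        # neither:    mixed batch -> guard greedy rows before dividing
        if all_greedy:
            return "greedy-only"
        if all_random:
            return "random-only"
        return "mixed"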
@@ -40,7 +40,11 @@ class Sampler(nn.Module):
         self,
         logits: torch.Tensor,
         temp: torch.Tensor,
+        all_random: bool = False,
     ) -> torch.Tensor:
+        # Avoid division by zero for greedy sampling (temperature ~ 0.0).
+        if not all_random:
+            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
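A standalone sketch of the guarded divide, assuming an epsilon of 1e-5 (the actual _SAMPLING_EPS value may differ): dividing a greedy row's logits by its 0.0 temperature would produce inf, so the guard rewrites that row's temperature to 1.0, a no-op divide that leaves the argmax unchanged.

    import torch

    _SAMPLING_EPS = 1e-5  # assumed threshold

    def apply_temperature(logits: torch.Tensor, temp: torch.Tensor,
                          all_random: bool = False) -> torch.Tensor:
        if not all_random:
            # Greedy rows carry temperature 0.0; replace with 1.0 before dividing.
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        return logits.div_(temp.unsqueeze(dim=1))

    logits = torch.randn(2, 8)
    out = apply_temperature(logits, torch.tensor([0.0, 0.7]))
    assert torch.isfinite(out).all()  # greedy row no longer divides by zero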
@@ -56,7 +60,9 @@ class Sampler(nn.Module):
         assert sampling_metadata.temperature is not None
 
         # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        logits = self.apply_temperature(
+            logits, sampling_metadata.temperature, sampling_metadata.all_random
+        )
 
         # Apply min_p.
         if sampling_metadata.min_p is not None:
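Passing all_random through at the call site means the torch.where guard above is skipped entirely when the whole batch is random, so fully random batches pay nothing for the greedy special case.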
@@ -37,6 +37,7 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -1140,8 +1141,15 @@ def compute_probs_and_sample_next_token(
         next_token_ids = logits.argmax(dim=-1)
         return next_token_ids, probs
 
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
+    assert sampling_metadata.temperature is not None
+
+    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0),
+    # consistent with sampler.py's _SAMPLING_EPS threshold.
+    temperature = sampling_metadata.temperature
+    # Avoid division by zero if there are greedy requests.
+    if not sampling_metadata.all_random:
+        is_greedy = temperature < _SAMPLING_EPS
+        temperature = torch.where(is_greedy, 1.0, temperature)
     logits.div_(temperature.view(-1, 1))
     probs = logits.softmax(dim=-1, dtype=torch.float32)
 
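This hunk is the heart of the fix. Once InputBatch started storing 0.0 for greedy requests, the old `temperature == -1` test never matched, and greedy rows in the spec-decode path were divided by 0.0. A self-contained reproduction of that failure mode (assumed epsilon value; not vLLM code):

    import torch

    logits = torch.randn(3, 8)
    temperature = torch.tensor([0.0, 0.8, 1.0])  # row 0 is greedy post-refactor

    # Old detection: the -1 sentinel is gone, so the mask is all False
    # and the greedy row is divided by 0.0, producing inf.
    old_temp = torch.where(temperature == -1, 1.0, temperature)
    assert torch.isinf(logits / old_temp.view(-1, 1)).any()

    # New detection: epsilon comparison catches temperature ~ 0.0.
    _SAMPLING_EPS = 1e-5  # assumed threshold
    new_temp = torch.where(temperature < _SAMPLING_EPS, 1.0, temperature)
    assert torch.isfinite(logits / new_temp.view(-1, 1)).all()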
@@ -215,8 +215,8 @@ class InputBatch:
         sampling_params = request.sampling_params
         assert sampling_params is not None, "pooling requests not supported yet"
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            # Avoid later division by zero.
-            self.temperature_cpu[req_index] = -1.0
+            # Avoid division by zero later in apply_temperature.
+            self.temperature_cpu[req_index] = 0.0
             self.greedy_reqs.add(req_id)
         else:
             self.temperature_cpu[req_index] = sampling_params.temperature
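Taken together, a greedy request now round-trips consistently: registration writes 0.0, and every consumer detects greediness by closeness to zero rather than by a magic -1. A toy end-to-end check (stand-in constants, not vLLM types):

    import torch

    GREEDY, RANDOM = "greedy", "random"  # stand-ins for SamplingType members
    _SAMPLING_EPS = 1e-5                 # assumed threshold

    def stored_temperature(sampling_type: str, temperature: float) -> float:
        # Mirrors the InputBatch change: greedy requests persist 0.0.
        return 0.0 if sampling_type == GREEDY else temperature

    temps = torch.tensor([stored_temperature(GREEDY, 0.0),
                          stored_temperature(RANDOM, 0.8)])
    print(temps < _SAMPLING_EPS)  # tensor([ True, False]) -- greedy row detected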