diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index f5b075e83b..43ecdff382 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -15,7 +15,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 logger = init_logger(__name__)
 
 PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
-GREEDY_TEMPERATURE: tl.constexpr = -1
+GREEDY_TEMPERATURE: tl.constexpr = 0
 # Maximum number of speculative draft tokens allowed per request in a single
 # step. This value is chosen to be large enough to handle typical use cases.
 MAX_SPEC_LEN = 128
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index c4bc88e615..0c1a22e84e 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -30,6 +30,7 @@ class TPUSupportedSamplingMetadata:
     top_p: torch.Tensor = None
 
     all_greedy: bool = True
+    all_random: bool = False
 
     # Whether logprobs are to be gathered in this batch of request. To balance
     # out compile time and runtime, a fixed `max_number_logprobs` value is used
@@ -110,6 +111,7 @@ class TPUSupportedSamplingMetadata:
                 xla_device
             ),
             all_greedy=input_batch.all_greedy,
+            all_random=input_batch.all_random,
             # TODO enable more and avoid returning None values
             top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device),
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device),
diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py
index f81f3a0eef..8f0463c76c 100644
--- a/vllm/v1/sample/tpu/sampler.py
+++ b/vllm/v1/sample/tpu/sampler.py
@@ -40,7 +40,11 @@ class Sampler(nn.Module):
         self,
         logits: torch.Tensor,
         temp: torch.Tensor,
+        all_random: bool = False,
     ) -> torch.Tensor:
+        # Avoid division by zero for greedy sampling (temperature ~ 0.0).
+        if not all_random:
+            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@@ -56,7 +60,9 @@ class Sampler(nn.Module):
         assert sampling_metadata.temperature is not None
 
         # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        logits = self.apply_temperature(
+            logits, sampling_metadata.temperature, sampling_metadata.all_random
+        )
 
         # Apply min_p.
         if sampling_metadata.min_p is not None:
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 6d5d0b2614..60418e53c1 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -37,6 +37,7 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import _SAMPLING_EPS
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -1140,8 +1141,15 @@ def compute_probs_and_sample_next_token(
         next_token_ids = logits.argmax(dim=-1)
         return next_token_ids, probs
 
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
+    assert sampling_metadata.temperature is not None
+
+    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0),
+    # consistent with sampler.py's _SAMPLING_EPS threshold.
+    temperature = sampling_metadata.temperature
+    # Avoid division by zero if there are greedy requests.
+    if not sampling_metadata.all_random:
+        is_greedy = temperature < _SAMPLING_EPS
+        temperature = torch.where(is_greedy, 1.0, temperature)
     logits.div_(temperature.view(-1, 1))
     probs = logits.softmax(dim=-1, dtype=torch.float32)
 
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index efd107b097..f52d92afab 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -215,8 +215,8 @@ class InputBatch:
         sampling_params = request.sampling_params
         assert sampling_params is not None, "pooling requests not supported yet"
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            # Avoid later division by zero.
-            self.temperature_cpu[req_index] = -1.0
+            # Avoid division by zero later in apply_temperature.
+            self.temperature_cpu[req_index] = 0.0
             self.greedy_reqs.add(req_id)
         else:
             self.temperature_cpu[req_index] = sampling_params.temperature
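
For context (not part of the patch), a minimal standalone sketch of the guard that the TPU sampler and the EAGLE draft path now share: greedy requests carry temperature 0.0, so near-zero temperatures are rewritten to 1.0 before the logits are divided, which avoids inf/nan and leaves the argmax of greedy rows unchanged. The 1e-5 threshold and the helper name scale_logits below are illustrative assumptions, not vLLM API.

import torch

_SAMPLING_EPS = 1e-5  # assumed threshold, mirroring the _SAMPLING_EPS used above

def scale_logits(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # Greedy rows store temperature 0.0; dividing by it would produce inf/nan.
    # Rewrite near-zero temperatures to 1.0, which leaves their argmax intact.
    safe_temp = torch.where(
        temperature < _SAMPLING_EPS, torch.ones_like(temperature), temperature
    )
    return logits / safe_temp.unsqueeze(dim=1)

# One greedy request (0.0) and one random request (0.7) over a 4-token vocab.
logits = torch.randn(2, 4)
scaled = scale_logits(logits, torch.tensor([0.0, 0.7]))
assert torch.equal(scaled[0].argmax(), logits[0].argmax())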