mirror of https://github.com/vllm-project/vllm.git
[Hardware][TPU] Raise errors for unsupported sampling params (#5850)
@@ -20,6 +20,8 @@ from vllm.utils import make_tensor_with_pad
 logger = init_logger(__name__)

 _PAD_SLOT_ID = 0  # FIXME(woosuk)
+# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
+_ENABLE_TOP_P = False


 class TPUModelRunner:
@@ -339,9 +341,34 @@ class TPUModelRunner:
             assert seq_group_metadata.sampling_params is not None
             sampling_params = seq_group_metadata.sampling_params

             # NOTE(woosuk): Here we mimic argmax sampling by applying a very
             # low temperature. This is not accurate.
             t.append(sampling_params.temperature
                      if sampling_params.temperature >= 1e-5 else 1e-5)
+            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
+                raise NotImplementedError(
+                    "Top-p sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
             p.append(sampling_params.top_p)
+            if sampling_params.top_k != -1:
+                raise NotImplementedError(
+                    "Top-k sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
+            if sampling_params.best_of > 1:
+                raise NotImplementedError(
+                    "best_of > 1 is not currently supported by the TPU "
+                    "backend.")
+            if sampling_params.use_beam_search:
+                raise NotImplementedError(
+                    "Beam search is not supported by the TPU backend.")
+            if sampling_params.logprobs is not None:
+                raise NotImplementedError(
+                    "logprobs is not currently supported by the TPU backend.")
+            if sampling_params.prompt_logprobs is not None:
+                raise NotImplementedError(
+                    "prompt_logprobs is not currently supported by the TPU "
+                    "backend.")

         num_paddings = padded_batch_size - len(seq_group_metadata_list)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
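The checks above run inside _prepare_sample, so an unsupported parameter surfaces as a NotImplementedError only when the request reaches the TPU model runner, not when the SamplingParams object is constructed. The 1.0 padding values are neutral: a temperature of 1.0 leaves the logits unscaled, a top-p of 1.0 keeps the whole vocabulary, and the padded rows are never turned into outputs anyway. As a rough illustration (not part of the diff, and assuming vllm's SamplingParams exposes the fields checked above), these are the kinds of requests the backend accepts or rejects:

    # Illustration only, not part of the commit. Assumes vllm's SamplingParams
    # exposes the fields checked above (temperature, top_p, top_k, best_of,
    # use_beam_search, logprobs, prompt_logprobs).
    from vllm import SamplingParams

    # Accepted: greedy-style decoding. A temperature below 1e-5 (including 0.0)
    # is clamped to 1e-5 above, which approximates argmax sampling.
    accepted = SamplingParams(temperature=0.0)

    # Each of these would trip one of the NotImplementedError checks above once
    # the request reaches the TPU model runner.
    rejected = [
        SamplingParams(top_p=0.9),          # top-p is disabled (_ENABLE_TOP_P is False)
        SamplingParams(top_k=40),           # top-k is not supported
        SamplingParams(best_of=2),          # best_of > 1 is not supported
        # beam search is not supported
        SamplingParams(use_beam_search=True, best_of=2, temperature=0.0),
        SamplingParams(logprobs=5),         # logprobs are not supported
        SamplingParams(prompt_logprobs=1),  # prompt_logprobs are not supported
    ]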
@@ -350,35 +377,32 @@ class TPUModelRunner:
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
         return t, p

-    def prepare_inputs(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ):
-        assert seq_group_metadata_list is not None
-        assert len(seq_group_metadata_list) > 0
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        if seq_group_metadata_list[0].is_prompt:
-            inputs = self._prepare_prompt(seq_group_metadata_list)
-        else:
-            inputs = self._prepare_decode(seq_group_metadata_list)
-        padded_batch_size = inputs[0].shape[0]
-        sample_inputs = self._prepare_sample(seq_group_metadata_list,
-                                             padded_batch_size)
-        return inputs + sample_inputs
-
     def _execute_model(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> List[CompletionSequenceGroupOutput]:
-        inputs = self.prepare_inputs(seq_group_metadata_list)
+        # Prepare inputs.
+        assert len(seq_group_metadata_list) > 0
+        # NOTE: We assume that all sequences in the group are all prompts or
+        # all decodes.
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        if is_prompt:
+            inputs = self._prepare_prompt(seq_group_metadata_list)
+        else:
+            inputs = self._prepare_decode(seq_group_metadata_list)
+        padded_batch_size = inputs[0].shape[0]
+        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
+
+        # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:])
-        if not self.is_driver_worker:
-            return []
+                                    *inputs[2:], t, p)
+        # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()

         # NOTE(woosuk): Minimal code to construct the sampler outputs.
         # The TPU backend does not reuse the sampler, since the TPU backend
         # does not support the advanced sampling parameters such as logprobs.
         i = 0
         sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
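The hunk is cut off before the loop body that the "minimal code to construct the sampler outputs" note refers to. As a sketch (assuming vllm's CompletionSequenceGroupOutput, SequenceOutput and Logprob types from vllm.sequence, not necessarily the exact code in this file), mapping the flat next_token_ids list back onto sequence groups looks roughly like this; padding entries at the tail of the batch are simply never consumed:

    # Sketch only: rebuild per-group outputs from a flat list of sampled ids.
    from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                               SequenceOutput)

    def build_sampler_outputs(seq_group_metadata_list, next_token_ids):
        sampler_outputs = []
        i = 0  # flat index into next_token_ids, in batch order
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []
            for seq_id in seq_group_metadata.seq_data:
                token_id = next_token_ids[i]
                i += 1
                # Dummy logprob of 0.0: real logprobs are rejected earlier in
                # _prepare_sample, so nothing downstream relies on this value.
                seq_outputs.append(
                    SequenceOutput(seq_id, token_id, {token_id: Logprob(0.0)}))
            # Second argument: no prompt logprobs, for the same reason.
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return sampler_outputs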
@@ -400,6 +424,7 @@ class TPUModelRunner:
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> SamplerOutput:
         assert seq_group_metadata_list is not None
+        assert len(seq_group_metadata_list) > 0
         if seq_group_metadata_list[0].is_prompt:
             # NOTE(woosuk): To reduce the compilation time, we only compile the
             # prefill inputs with batch size 1. Because the scheduler is not
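The note above is why execute_model special-cases prompts: torch_xla compiles a separate graph for each distinct input shape, so restricting prefill to batch size 1 keeps the number of compiled prefill graphs small, while decode batches are padded to fixed sizes. A rough sketch of that dispatch (hypothetical helper name, not the actual vllm code) is:

    # Hypothetical sketch of the prompt/decode dispatch described by the note;
    # execute_model_sketch and its arguments are illustrative names only.
    def execute_model_sketch(runner, seq_group_metadata_list, kv_caches):
        if seq_group_metadata_list[0].is_prompt:
            # Prefill: run one sequence group at a time so the compiled
            # prefill graph only ever sees batch size 1.
            outputs = []
            for seq_group_metadata in seq_group_metadata_list:
                outputs += runner._execute_model([seq_group_metadata],
                                                 kv_caches)
            return outputs
        # Decode: the whole batch is padded and runs through a single
        # compiled decode graph.
        return runner._execute_model(seq_group_metadata_list, kv_caches)

Per the signature above, the real method then wraps the collected group outputs into a SamplerOutput before returning.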
@@ -492,8 +517,8 @@ class ModelWrapper(nn.Module):
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

         logits = logits / t.unsqueeze(dim=1)
-        # FIXME(woosuk): Disabled top-p sampling since it's too slow.
-        # logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+        if _ENABLE_TOP_P:
+            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
         # FIXME(woosuk): best_of > 1 is not supported.
         next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
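_apply_top_p is called here but its body is outside this diff. For context, a top-p (nucleus) filter over a (batch, vocab) logits tensor with a per-row threshold p of shape (batch, 1) can be written as follows; this is an illustrative sketch, not necessarily the implementation in this file, and the full-vocabulary sort it relies on is presumably the kind of work the FIXME above calls too slow:

    import torch

    def apply_top_p_sketch(logits: torch.Tensor,
                           p: torch.Tensor) -> torch.Tensor:
        # Sort each row's logits in descending order and compute the
        # cumulative probability mass of the tokens preceding each position.
        logits_sorted, indices = logits.sort(dim=-1, descending=True)
        probs_sorted = logits_sorted.softmax(dim=-1)
        mass_before = probs_sorted.cumsum(dim=-1) - probs_sorted
        # Drop every token whose preceding mass already exceeds p; the most
        # likely token is always kept because its preceding mass is 0.
        logits_sorted = logits_sorted.masked_fill(mass_before > p,
                                                  float("-inf"))
        # Scatter the filtered logits back to their original vocabulary slots.
        return torch.full_like(logits, float("-inf")).scatter(-1, indices,
                                                              logits_sorted)

With p filled with 1.0 the mask never fires, which matches the neutral 1.0 padding used in _prepare_sample.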