[Core] [Bugfix]: tensor parallel with prompt embeds (#18171)

Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
This commit is contained in:
Nan Qin
2025-05-19 22:21:27 -05:00
committed by GitHub
parent f07a673eb2
commit 9609327fa4
4 changed files with 138 additions and 64 deletions

View File

@ -8,12 +8,13 @@ import weakref
from unittest.mock import Mock
import pytest
import torch
from vllm import LLM
from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from ..conftest import VllmRunner
from ..conftest import HfRunner, VllmRunner
from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test
@ -43,11 +44,26 @@ def test_vllm_gc_ed():
assert weak_llm() is None
def _fix_prompt_embed_outputs(
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
example_prompts: list[str]) -> list[tuple[list[int], str]]:
fixed_vllm_outputs = []
for vllm_output, hf_input, prompt in zip(
vllm_outputs, hf_model.get_inputs(example_prompts),
example_prompts):
hf_input_ids = hf_input["input_ids"].tolist()[0]
fixed_vllm_outputs.append(
(hf_input_ids + vllm_output[0][len(hf_input_ids):],
prompt + vllm_output[1]))
return fixed_vllm_outputs
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models(
monkeypatch: pytest.MonkeyPatch,
hf_runner,
@ -56,8 +72,13 @@ def test_models(
dtype: str,
max_tokens: int,
enforce_eager: bool,
enable_prompt_embeds: bool,
) -> None:
if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")
@ -78,14 +99,25 @@ def test_models(
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)
with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=hf_outputs,
@ -108,6 +140,7 @@ def test_models(
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed(
monkeypatch: pytest.MonkeyPatch,
hf_runner,
@ -117,14 +150,22 @@ def test_models_distributed(
distributed_executor_backend: str,
attention_backend: str,
test_suite: str,
enable_prompt_embeds: bool,
) -> None:
if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if test_suite != TARGET_TEST_SUITE:
pytest.skip(f"Skip test for {test_suite}")
with monkeypatch.context() as monkeypatch_context:
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
# test Ray Compiled Graph
if enable_prompt_embeds:
pytest.skip(
"enable_prompt_embeds does not work with ray compiled dag."
)
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
@ -147,12 +188,26 @@ def test_models_distributed(
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with hf_runner(model, dtype=dtype) as hf_model:
with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=hf_outputs,

View File

@ -430,6 +430,15 @@ class HfRunner:
return all_inputs
def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
all_inputs = self.get_inputs(prompts)
embeddings = []
for inputs in all_inputs:
input_ids = self.wrap_device(inputs)["input_ids"]
embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
embeddings.append(embedding)
return embeddings
def classify(self, prompts: list[str]) -> list[str]:
# output is final logits
all_inputs = self.get_inputs(prompts)

View File

@ -112,12 +112,12 @@ class RequestMetrics:
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at
each position; the first token is from
each position; the first token is from
the target model and is always accepted;
e.g., when it's [10, 8, 4, 2] for a req,
e.g., when it's [10, 8, 4, 2] for a req,
it means there were 10 forward passes in
total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step.
total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step.
"""
arrival_time: float
last_token_time: float
@ -714,9 +714,9 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one from the target
draft_size: The number of speculative tokens plus one from the target
model; equal to max number of tokens a step can generate
for single-draft speculative decoding but larger than
for single-draft speculative decoding but larger than
that for multi-draft SD (currently not supported).
"""
@ -1123,7 +1123,7 @@ class SequenceOutput(
self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}"
f"output_embed.shape={output_embed_shape}, "
f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool:

View File

@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture)
@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""
# Combine and flatten intermediate data.
input_tokens = list[int]()
inputs_embeds_lst = list[torch.Tensor]()
inputs_embeds_list = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append(
inputs_embeds_list.append(
inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype,
device=self.runner.device))
inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0:
if len(inputs_embeds_list) == 0:
inputs_embeds = None
else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
dtype=self.runner.model_config.dtype,
device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens)
@ -1893,51 +1893,61 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
if self.is_driver_worker:
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the
# latency from the start time of the driver worker to the end
# time of the driver worker. The model forward time will then
# end up covering the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if model_input.inputs_embeds is not None:
if self.is_driver_worker:
sampled = broadcast_tensor_dict(
{"token_ids": output.sampled_token_ids})
else:
sampled = broadcast_tensor_dict()
if sampled["token_ids"] is not None:
sampled_token_embeds = self.model.get_input_embeddings(
sampled["token_ids"].squeeze(1))
if self.is_driver_worker:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs
output.sampled_token_embeds = sampled_token_embeds
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[
0].output_embed = token_embed
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs_tensor
if output.sampled_token_ids is not None:
output.sampled_token_embeds = self.model.get_input_embeddings(
output.sampled_token_ids.squeeze(1))
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None