Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Core] [Bugfix]: tensor parallel with prompt embeds (#18171)
Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
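For context, here is a minimal sketch of the feature this commit fixes: feeding precomputed prompt embeddings into a tensor-parallel vLLM instance. The exact call shape (the `{"prompt_embeds": ...}` prompt dict, the `enable_prompt_embeds` engine flag, and the example model name) is an assumption inferred from the tests below, not an excerpt from this commit.

# Hedged sketch, not part of this commit. Requires 2 GPUs for TP=2.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "distilbert/distilgpt2"  # example model, also used in the tests below
tok = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

with torch.no_grad():
    input_ids = tok("Hello, my name is", return_tensors="pt")["input_ids"]
    # [seq_len, hidden_size], mirroring HfRunner.get_prompt_embeddings below
    prompt_embeds = hf_model.get_input_embeddings()(input_ids).squeeze(0)

llm = LLM(model=model_name,
          enable_prompt_embeds=True,   # assumed engine arg, as in VllmRunner below
          tensor_parallel_size=2,      # the TP > 1 case this commit fixes
          gpu_memory_utilization=0.7)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)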
@@ -8,12 +8,13 @@ import weakref
 from unittest.mock import Mock

 import pytest
+import torch

-from vllm import LLM
+from vllm import LLM, envs
 from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

-from ..conftest import VllmRunner
+from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test

@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
     assert weak_llm() is None


+def _fix_prompt_embed_outputs(
+        vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
+        example_prompts: list[str]) -> list[tuple[list[int], str]]:
+    fixed_vllm_outputs = []
+    for vllm_output, hf_input, prompt in zip(
+            vllm_outputs, hf_model.get_inputs(example_prompts),
+            example_prompts):
+        hf_input_ids = hf_input["input_ids"].tolist()[0]
+        fixed_vllm_outputs.append(
+            (hf_input_ids + vllm_output[0][len(hf_input_ids):],
+             prompt + vllm_output[1]))
+    return fixed_vllm_outputs
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -56,8 +72,13 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    enable_prompt_embeds: bool,
 ) -> None:

+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")

@@ -78,14 +99,25 @@ def test_models(

     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(
+                    example_prompts)

     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
+                    enable_prompt_embeds=enable_prompt_embeds,
                     gpu_memory_utilization=0.7) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(
+                prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)

     check_outputs_equal(
         outputs_0_lst=hf_outputs,
@@ -108,6 +140,7 @@ def test_models(
     ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
     ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
 ])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    enable_prompt_embeds: bool,
 ) -> None:

+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")

     with monkeypatch.context() as monkeypatch_context:
         if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
             # test Ray Compiled Graph
+            if enable_prompt_embeds:
+                pytest.skip(
+                    "enable_prompt_embeds does not work with ray compiled dag."
+                )
             monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")

@@ -147,12 +188,26 @@ def test_models_distributed(
             dtype=dtype,
             tensor_parallel_size=2,
             distributed_executor_backend=distributed_executor_backend,
+            enable_prompt_embeds=enable_prompt_embeds,
+            gpu_memory_utilization=0.7,
     ) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with hf_runner(model, dtype=dtype) as hf_model:
+                with torch.no_grad():
+                    prompt_embeds = hf_model.get_prompt_embeddings(
+                        example_prompts)
+                vllm_outputs = vllm_model.generate_greedy(
+                    prompt_embeds, max_tokens)
+                vllm_outputs = _fix_prompt_embed_outputs(
+                    vllm_outputs, hf_model, example_prompts)
+                hf_outputs = hf_model.generate_greedy(
+                    example_prompts, max_tokens)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
+            with hf_runner(model, dtype=dtype) as hf_model:
+                hf_outputs = hf_model.generate_greedy(
+                    example_prompts, max_tokens)

     check_outputs_equal(
         outputs_0_lst=hf_outputs,
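The `_fix_prompt_embed_outputs` helper used in the tests above exists because a prompt-embeds request gives vLLM no prompt token ids or prompt text to echo back, so the HF-side prompt has to be spliced back in before `check_outputs_equal` can compare against the HF baseline. An illustrative before/after with made-up values (in particular, the placeholder ids in the prompt positions are an assumption):

# Hypothetical values, for illustration only.
hf_prompt_ids = [15496, 11, 995]                     # ids HF produced for "Hello, world"
vllm_ids, vllm_text = [0, 0, 0, 318, 257], " is a"   # assumed prompt-embeds output
fixed_ids = hf_prompt_ids + vllm_ids[len(hf_prompt_ids):]   # [15496, 11, 995, 318, 257]
fixed_text = "Hello, world" + vllm_text                     # "Hello, world is a"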
@@ -430,6 +430,15 @@ class HfRunner:

         return all_inputs

+    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
+        all_inputs = self.get_inputs(prompts)
+        embeddings = []
+        for inputs in all_inputs:
+            input_ids = self.wrap_device(inputs)["input_ids"]
+            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
+            embeddings.append(embedding)
+        return embeddings
+
     def classify(self, prompts: list[str]) -> list[str]:
         # output is final logits
         all_inputs = self.get_inputs(prompts)
@@ -112,12 +112,12 @@ class RequestMetrics:
                             will include model forward, block/sync across
                             workers, cpu-gpu sync time and sampling time.
        spec_token_acceptance_counts: number of accepted speculative tokens at
-                                     each position; the first token is from
+                                     each position; the first token is from
                                      the target model and is always accepted;
-                                     e.g., when it's [10, 8, 4, 2] for a req,
+                                     e.g., when it's [10, 8, 4, 2] for a req,
                                      it means there were 10 forward passes in
-                                     total, and there were 8, 4, 2 accepted
-                                     tokens at 1st, 2nd, 3rd speculation step.
+                                     total, and there were 8, 4, 2 accepted
+                                     tokens at 1st, 2nd, 3rd speculation step.
    """
    arrival_time: float
    last_token_time: float
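A small worked example of the `spec_token_acceptance_counts` semantics described in the docstring above; the derived acceptance-rate metric at the end is my own illustration, not code from this commit.

# Per the docstring: index 0 counts target-model forward passes for the
# request; index k counts draft tokens accepted at speculation step k.
spec_token_acceptance_counts = [10, 8, 4, 2]

num_forward_passes = spec_token_acceptance_counts[0]   # 10
accepted_per_step = spec_token_acceptance_counts[1:]   # [8, 4, 2]

# Assumed derived metric: fraction of proposed draft tokens that were kept.
acceptance_rate = sum(accepted_per_step) / (
    num_forward_passes * len(accepted_per_step))        # 14 / 30 ≈ 0.47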
@@ -714,9 +714,9 @@ class SequenceGroup:
         trace_headers: OpenTelemetry trace headers.
         prompt_adapter_request: Prompt Adapter request.
         priority: User-defined priority of the request.
-        draft_size: The number of speculative tokens plus one from the target
+        draft_size: The number of speculative tokens plus one from the target
                     model; equal to max number of tokens a step can generate
-                    for single-draft speculative decoding but larger than
+                    for single-draft speculative decoding but larger than
                     that for multi-draft SD (currently not supported).
     """
@@ -1123,7 +1123,7 @@ class SequenceOutput(
             self.output_embed.shape if self.output_embed is not None else None
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
-                f"output_embed.shape={output_embed_shape}"
+                f"output_embed.shape={output_embed_shape}, "
                 f"logprobs={self.logprobs})")

     def __eq__(self, other: object) -> bool:
@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
@@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         """
         # Combine and flatten intermediate data.
         input_tokens = list[int]()
-        inputs_embeds_lst = list[torch.Tensor]()
+        inputs_embeds_list = list[torch.Tensor]()
         token_types = list[int]()
         for inter_data in self.inter_data_list:
             for cur_input_tokens in inter_data.input_tokens:
@@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             for cur_token_types in inter_data.token_types:
                 token_types.extend(cur_token_types)
             if inter_data.inputs_embeds is not None:
-                inputs_embeds_lst.append(
+                inputs_embeds_list.append(
                     inter_data.inputs_embeds.to(
                         dtype=self.runner.model_config.dtype,
                         device=self.runner.device))
         inputs_embeds: Optional[torch.Tensor]
-        if len(inputs_embeds_lst) == 0:
+        if len(inputs_embeds_list) == 0:
             inputs_embeds = None
         else:
-            inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
+            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
                 dtype=self.runner.model_config.dtype,
                 device=self.runner.device)
         assert len(inputs_embeds) == len(input_tokens)
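A minimal sketch of the "combine and flatten" step performed by the builder code above: per-sequence embedding tensors are concatenated into one [total_tokens, hidden_size] tensor that lines up one-to-one with the flattened `input_tokens` list (sizes and placeholder ids here are made up):

import torch

hidden_size = 8
# Two sequences contribute 3 and 5 token embeddings respectively.
inputs_embeds_list = [torch.randn(3, hidden_size), torch.randn(5, hidden_size)]
input_tokens = [0] * 3 + [0] * 5   # illustrative placeholder ids for embeds-only inputs

inputs_embeds = torch.cat(inputs_embeds_list, dim=0)   # shape [8, hidden_size]
assert len(inputs_embeds) == len(input_tokens)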
@@ -1893,51 +1893,61 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)

+        if self.is_driver_worker:
+            if model_input.async_callback is not None:
+                model_input.async_callback()
+
+            # Sample the next token.
+            assert isinstance(self.sampler, Sampler)
+            orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
+            if model_input.inputs_embeds is not None:
+                self.sampler.include_gpu_probs_tensor = True
+
+            output: SamplerOutput = self.sampler(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            if (self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time
+                    and output is not None):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                # If there are multiple workers, we are still tracking the
+                # latency from the start time of the driver worker to the end
+                # time of the driver worker. The model forward time will then
+                # end up covering the communication time as well.
+                output.model_forward_time = (orig_model_forward_time +
+                                             model_forward_time)
+
+        if model_input.inputs_embeds is not None:
+            if self.is_driver_worker:
+                sampled = broadcast_tensor_dict(
+                    {"token_ids": output.sampled_token_ids})
+            else:
+                sampled = broadcast_tensor_dict()
+            if sampled["token_ids"] is not None:
+                sampled_token_embeds = self.model.get_input_embeddings(
+                    sampled["token_ids"].squeeze(1))
+                if self.is_driver_worker:
+                    self.sampler.include_gpu_probs_tensor = \
+                        orig_include_gpu_probs
+
+                    output.sampled_token_embeds = sampled_token_embeds
+
+                    for token_embed, sequence_group_output in zip(
+                            output.sampled_token_embeds, output.outputs):
+                        assert len(sequence_group_output.samples) == 1
+                        sequence_group_output.samples[
+                            0].output_embed = token_embed
+
         if not self.is_driver_worker:
             return []

-        if model_input.async_callback is not None:
-            model_input.async_callback()
-
-        # Sample the next token.
-        assert isinstance(self.sampler, Sampler)
-        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
-        if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = True
-
-        output: SamplerOutput = self.sampler(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time
-                and output is not None):
-            model_forward_end.synchronize()
-            model_forward_time = model_forward_start.elapsed_time(
-                model_forward_end)
-            orig_model_forward_time = 0.0
-            if intermediate_tensors is not None:
-                orig_model_forward_time = intermediate_tensors.tensors.get(
-                    "model_forward_time", torch.tensor(0.0)).item()
-            # If there are multiple workers, we are still tracking the latency
-            # from the start time of the driver worker to the end time of the
-            # driver worker. The model forward time will then end up covering
-            # the communication time as well.
-            output.model_forward_time = (orig_model_forward_time +
-                                         model_forward_time)
-
-        if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = \
-                orig_include_gpu_probs_tensor
-            if output.sampled_token_ids is not None:
-                output.sampled_token_embeds = self.model.get_input_embeddings(
-                    output.sampled_token_ids.squeeze(1))
-
-                for token_embed, sequence_group_output in zip(
-                        output.sampled_token_embeds, output.outputs):
-                    assert len(sequence_group_output.samples) == 1
-                    sequence_group_output.samples[0].output_embed = token_embed
-
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
             assert model_input.sampling_metadata is not None
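Why the added block broadcasts: under tensor parallelism only the driver worker runs the sampler, but every rank needs the sampled token ids so it can participate in `get_input_embeddings` (the vocab embedding is typically TP-sharded, so the lookup is a collective operation) and produce the embeddings fed back in at the next step. A condensed, hedged sketch of that pattern follows; it is not the literal vLLM code, and `model` / `is_driver_worker` stand in for the runner's attributes.

# Hedged sketch of the driver-to-workers synchronization used above.
from vllm.distributed import broadcast_tensor_dict

def embed_sampled_tokens(model, sampled_token_ids, is_driver_worker):
    # The driver broadcasts the token ids it just sampled; the other tensor
    # parallel ranks receive them from the same broadcast.
    if is_driver_worker:
        payload = broadcast_tensor_dict({"token_ids": sampled_token_ids})
    else:
        payload = broadcast_tensor_dict()
    token_ids = payload["token_ids"]
    if token_ids is None:
        return None
    # Every rank takes part in the embedding lookup, mirroring
    # model.get_input_embeddings(...) in the diff above.
    return model.get_input_embeddings(token_ids.squeeze(1))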