mirror of https://github.com/vllm-project/vllm.git
[core] Multi Step Scheduling (#7000)
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
@@ -311,6 +311,15 @@ steps:
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

- label: Multi-step Tests (4 GPUs) # 10min
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  source_file_dependencies:
  - vllm/
  - tests/multi_step/test_correctness.py
  commands:
  - pytest -v -s multi_step/test_correctness.py

- label: Pipeline Parallelism Test # 23min
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
0   tests/multi_step/__init__.py          Normal file
85  tests/multi_step/test_correctness.py  Normal file
@@ -0,0 +1,85 @@
# Test the AsyncLLMEngine with multi-step-decoding

from typing import List

import pytest

from ..utils import RemoteOpenAIServer

MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
NUM_PROMPTS = [10]

DEFAULT_SERVER_ARGS: List[str] = [
    "--disable-log-requests",
    "--use-v2-block-manager",
    "--worker-use-ray",
    "--gpu-memory-utilization",
    "0.85",
    "--swap-space",
    "16",
]


async def completions_with_server_args(prompts: List[str], model_name: str,
                                       server_cli_args: List[str]):

    outputs = None
    with RemoteOpenAIServer(model_name, server_cli_args) as server:
        client = server.get_async_client()
        outputs = await client.completions.create(model=model_name,
                                                  prompt=prompts,
                                                  temperature=0,
                                                  stream=False,
                                                  max_tokens=5)
    assert outputs is not None

    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("tp_size, pp_size"), [
    (1, 1),
    (2, 2),
])
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
                          pp_size: int, eager_mode: bool,
                          num_scheduler_steps: int, num_prompts: int):

    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

    server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
    ms_server_args = DEFAULT_SERVER_ARGS + \
        ["--num-scheduler-steps", f"{num_scheduler_steps}"]

    if eager_mode:
        ms_server_args.append("--enforce-eager")

    distributed_args = [
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
    ]

    # Reference run uses the default single-step server; test run uses the
    # multi-step server.
    ref_completions = await completions_with_server_args(
        prompts, model, server_args + distributed_args)
    test_completions = await completions_with_server_args(
        prompts, model, ms_server_args + distributed_args)

    def get_text_generations(completions):
        return [x.text for x in completions.choices]

    ref_generations = get_text_generations(ref_completions)
    test_generations = get_text_generations(test_completions)
    assert ref_generations == test_generations
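For a sense of how the helper above composes outside of pytest, the sketch below is not part of this commit; it assumes the test module is importable as tests.multi_step.test_correctness and that a GPU is free. It compares a default single-step server against a multi-step server directly.

import asyncio

from tests.multi_step.test_correctness import (DEFAULT_SERVER_ARGS,
                                               completions_with_server_args)

MODEL = "JackFram/llama-160m"
PROMPTS = ["Hello, my name is", "The capital of France is"]


async def compare():
    # Baseline: default single-step scheduling.
    ref = await completions_with_server_args(
        PROMPTS, MODEL, DEFAULT_SERVER_ARGS + ["--enforce-eager"])
    # Multi-step: schedule once, then run 8 decode steps per scheduler call.
    ms = await completions_with_server_args(
        PROMPTS, MODEL, DEFAULT_SERVER_ARGS + ["--num-scheduler-steps", "8"])
    # With temperature=0 the generations should match token for token.
    assert [c.text for c in ref.choices] == [c.text for c in ms.choices]


if __name__ == "__main__":
    asyncio.run(compare())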
@@ -10,6 +10,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import StatefulModelInput


class MockAttentionBackend(AttentionBackend):
@@ -154,3 +155,79 @@ def test_embedding_model_runner_input():
                       None) == getattr(attn_metadata, field.name, None)
    # Pooling metadata is not broadcast.
    assert received_model_input.pooling_metadata is None


def test_multi_step_model_runner_input():
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )

    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend))

    received_frozen_input = received_model_input.frozen_model_input

    # Check that received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    assert received_frozen_input.input_tokens is not None
    assert (received_frozen_input.input_tokens ==
            frozen_model_input.input_tokens).all()
    assert received_frozen_input.input_positions is not None
    assert (received_frozen_input.input_positions ==
            frozen_model_input.input_positions).all()
    assert received_frozen_input.multi_modal_kwargs is None
    assert (received_frozen_input.multi_modal_kwargs ==
            frozen_model_input.multi_modal_kwargs)
    assert received_frozen_input.lora_requests is None
    assert (received_frozen_input.lora_requests ==
            frozen_model_input.lora_requests)
    assert received_frozen_input.lora_mapping is None
    assert (
        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_frozen_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # Check the non-frozen fields.
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step
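The round trip exercised by this test is what happens between the driver and its tensor-parallel workers at runtime. The sketch below is a rough outline of that exchange, assuming an initialized vLLM distributed group; the real logic (which also folds WorkerInput fields into the same broadcast) lives in MultiStepWorker further down in this diff.

from vllm.distributed import broadcast_tensor_dict
from vllm.worker.multi_step_model_runner import StatefulModelInput


def driver_broadcast(model_input: StatefulModelInput) -> None:
    # The driver flattens the stateful input into plain tensors and scalars...
    broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(), src=0)


def worker_receive(attn_backend) -> StatefulModelInput:
    # ...and every other worker rebuilds an equivalent StatefulModelInput,
    # re-inflating the frozen ModelInputForGPUWithSamplingMetadata inside it.
    data = broadcast_tensor_dict(src=0)
    return StatefulModelInput.from_broadcasted_tensor_dict(
        data, attn_backend=attn_backend)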
@@ -853,6 +853,12 @@ class EngineArgs:
                "in low performance due to small KV cache space. Consider "
                "setting --max-model-len to a smaller value.", max_model_len)

        if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
            self.use_v2_block_manager = True
            logger.warning(
                "Enabled BlockSpaceManagerV2 because it is "
                "required for multi-step (--num-scheduler-steps > 1)")

        speculative_config = SpeculativeConfig.maybe_create_spec_config(
            target_model_config=model_config,
            target_parallel_config=parallel_config,
@@ -881,7 +887,6 @@ class EngineArgs:
        )

        if self.num_scheduler_steps > 1:
            raise NotImplementedError("Multi-step is not yet supported.")
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
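In user-facing terms, the change above means --num-scheduler-steps > 1 now silently upgrades the block manager instead of failing, while combining it with speculative decoding is rejected. A minimal sketch follows; field names are taken from this diff, and EngineArgs.create_engine_config() is assumed to be the entry point that runs this validation.

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="JackFram/llama-160m",
    num_scheduler_steps=8,       # enable multi-step decoding
    use_v2_block_manager=False,  # will be forced to True with a warning
)
args.create_engine_config()
assert args.use_v2_block_manager  # BlockSpaceManagerV2 is required for multi-step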
@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
|
||||
Optional, Set, Tuple, Type, Union)
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from typing_extensions import assert_never
|
||||
|
||||
@ -27,7 +29,8 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
@ -249,9 +252,25 @@ class RequestTracker:
|
||||
return not self._new_requests.empty()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerOutputState:
|
||||
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
|
||||
last_output: Optional[SamplerOutput] = None
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
|
||||
|
||||
class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
pipeline_parallel_size = \
|
||||
self.parallel_config.pipeline_parallel_size
|
||||
self.cached_scheduler_outputs = [
|
||||
SchedulerOutputState() for _ in range(pipeline_parallel_size)
|
||||
]
|
||||
|
||||
async def step_async(
|
||||
self, virtual_engine: int
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
@ -264,13 +283,39 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
||||
virtual_engine].schedule()
|
||||
# these are cached outputs from previous iterations. None if on first
|
||||
# iteration
|
||||
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
|
||||
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
|
||||
scheduler_outputs = cached_outputs.scheduler_outputs
|
||||
# skip the scheduler if there are any remaining steps in the seq groups.
|
||||
# This ensures that the scheduler is only called again when the current
|
||||
# batch has completed.
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
||||
virtual_engine].schedule()
|
||||
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and scheduler_outputs.num_lookahead_slots > 0):
|
||||
# cache the scheduler outputs for the next iteration if we have
|
||||
# lookahead slots
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
virtual_engine, seq_group_metadata_list, scheduler_outputs)
|
||||
|
||||
assert seq_group_metadata_list is not None
|
||||
assert scheduler_outputs is not None
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
# Execute the model.
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
# Check if we have a cached last_output from the previous iteration.
|
||||
# For supporting PP this is probably the best way to pass the
|
||||
# sampled_token_ids, as a separate broadcast over all the PP stages
|
||||
# will cause one virtual engine's microbatch to block the pipeline.
|
||||
last_sampled_token_ids = \
|
||||
self._get_last_sampled_token_ids(virtual_engine)
|
||||
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
@ -279,15 +324,35 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
virtual_engine=virtual_engine,
|
||||
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
||||
running_queue_size=scheduler_outputs.running_queue_size,
|
||||
finished_requests_ids=finished_requests_ids)
|
||||
finished_requests_ids=finished_requests_ids,
|
||||
# We use ExecuteModelRequest to pass the last sampled_token_ids
|
||||
# to each of the non-last PP stages for in-place prepare_input.
|
||||
last_sampled_token_ids=last_sampled_token_ids)
|
||||
# Execute the model.
|
||||
output = await self.model_executor.execute_model_async(
|
||||
execute_model_req)
|
||||
# we need to do this here so that last step's sampled_token_ids can
|
||||
# be passed to the next iteration for PP.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self._update_cached_scheduler_output(virtual_engine, output)
|
||||
else:
|
||||
output = []
|
||||
|
||||
request_outputs = self._process_model_outputs(
|
||||
output, scheduler_outputs.scheduled_seq_groups,
|
||||
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
|
||||
# Finish the current step for all the sequence groups.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
for seq_group in seq_group_metadata_list:
|
||||
seq_group.finish_step()
|
||||
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
# clear the cache if we have finished all the steps
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self.cached_scheduler_outputs[
|
||||
virtual_engine] = SchedulerOutputState()
|
||||
request_outputs = self._process_model_outputs(
|
||||
output, scheduler_outputs.scheduled_seq_groups,
|
||||
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
|
||||
else:
|
||||
request_outputs = []
|
||||
|
||||
# Log stats.
|
||||
self.do_log_stats(scheduler_outputs, output)
|
||||
@ -297,6 +362,60 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
|
||||
return request_outputs
|
||||
|
||||
def _has_remaining_steps(
|
||||
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
|
||||
) -> bool:
|
||||
if (not self.scheduler_config.is_multi_step
|
||||
or not seq_group_metadata_list):
|
||||
return False
|
||||
|
||||
# TODO(will) this is a sanity check for now to make sure that all the
|
||||
# seqs are on the same steps. Eventually we will want to do some sort of
|
||||
# dynamic scheduling when doing multi-step decoding.
|
||||
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
|
||||
if any([
|
||||
seq_group.state.remaining_steps != ref_remaining_steps
|
||||
for seq_group in seq_group_metadata_list[1:]
|
||||
]):
|
||||
raise AssertionError(("All running sequence groups should "
|
||||
"have the same remaining steps."))
|
||||
|
||||
return ref_remaining_steps > 0
|
||||
|
||||
def _cache_scheduler_outputs_for_multi_step(
|
||||
self, virtual_engine: int,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
scheduler_outputs: SchedulerOutputs) -> None:
|
||||
self.cached_scheduler_outputs[
|
||||
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
|
||||
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
|
||||
scheduler_outputs
|
||||
self.cached_scheduler_outputs[virtual_engine].last_output = None
|
||||
|
||||
def _get_last_sampled_token_ids(
|
||||
self, virtual_engine: int) -> Optional[torch.Tensor]:
|
||||
cached_last_output = self.cached_scheduler_outputs[
|
||||
virtual_engine].last_output
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and self.parallel_config.pipeline_parallel_size > 1
|
||||
and cached_last_output is not None
|
||||
and cached_last_output.sampled_token_ids_cpu is not None):
|
||||
return cached_last_output.sampled_token_ids_cpu
|
||||
return None
|
||||
|
||||
def _update_cached_scheduler_output(
|
||||
self, virtual_engine: int,
|
||||
output: List[Optional[SamplerOutput]]) -> None:
|
||||
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
|
||||
and output[0] is not None):
|
||||
last_output = output[-1]
|
||||
assert last_output is not None
|
||||
assert last_output.sampled_token_ids_cpu is not None
|
||||
assert last_output.sampled_token_ids is None
|
||||
assert last_output.sampled_token_probs is None
|
||||
self.cached_scheduler_outputs[
|
||||
virtual_engine].last_output = last_output
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||
"""Stop the remote worker execution loop."""
|
||||
await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||
|
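Stripped of engine details, the caching logic that step_async() now implements per virtual engine looks roughly like the standalone sketch below. The names are hypothetical; the real method additionally threads last_sampled_token_ids through ExecuteModelRequest so non-last pipeline-parallel ranks can prepare inputs in place.

from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class CachedSchedule:
    seq_groups: Optional[List[Any]] = None
    scheduler_outputs: Optional[Any] = None
    last_output: Optional[Any] = None


def multi_step_step(engine, virtual_engine: int,
                    cache: List[CachedSchedule]) -> list:
    cached = cache[virtual_engine]

    # Only invoke the scheduler when the previous multi-step batch is done.
    if not engine.has_remaining_steps(cached.seq_groups):
        seq_groups, sched_out = engine.scheduler[virtual_engine].schedule()
        if engine.is_multi_step and sched_out.num_lookahead_slots > 0:
            cache[virtual_engine] = CachedSchedule(seq_groups, sched_out)
    else:
        seq_groups = cached.seq_groups
        sched_out = cached.scheduler_outputs

    output = engine.execute_model(seq_groups, sched_out,
                                  last_sampled=cached.last_output)

    # Every call consumes one step of the batch; request outputs are only
    # produced once all num_scheduler_steps forward passes have finished.
    for sg in seq_groups:
        sg.finish_step()
    if engine.has_remaining_steps(seq_groups):
        return []
    cache[virtual_engine] = CachedSchedule()
    return engine.process_outputs(output, sched_out)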
@@ -69,13 +69,19 @@ class GPUExecutor(ExecutorBase):
                           distributed_init_method: Optional[str] = None) -> Dict:
        worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                distributed_init_method)
        if self.speculative_config is None:
            worker_kwargs.update(worker_module_name="vllm.worker.worker",
                                 worker_class_name="Worker")
        else:

        if self.scheduler_config.is_multi_step:
            worker_kwargs.update(
                worker_module_name="vllm.worker.multi_step_worker",
                worker_class_name="MultiStepWorker")
        elif self.speculative_config:
            worker_kwargs.update(
                worker_module_name="vllm.spec_decode.spec_decode_worker",
                worker_class_name="create_spec_worker")
        else:
            worker_kwargs.update(worker_module_name="vllm.worker.worker",
                                 worker_class_name="Worker")

        return worker_kwargs

    def _create_worker(self,
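Both executors now route to the new MultiStepWorker. Note that GPUExecutor checks is_multi_step before the speculative config, while RayGPUExecutor below checks them in the opposite order; the combination only differs in configurations that EngineArgs already rejects. The dispatch reduces to something like this hypothetical helper, which mirrors the GPUExecutor ordering:

from typing import Tuple


def select_worker(is_multi_step: bool, has_spec_config: bool) -> Tuple[str, str]:
    if is_multi_step:
        return "vllm.worker.multi_step_worker", "MultiStepWorker"
    if has_spec_config:
        return "vllm.spec_decode.spec_decode_worker", "create_spec_worker"
    return "vllm.worker.worker", "Worker"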
@@ -94,6 +94,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
        if self.speculative_config is not None:
            worker_module_name = "vllm.spec_decode.spec_decode_worker"
            worker_class_name = "create_spec_worker"
        elif self.scheduler_config.is_multi_step:
            worker_module_name = "vllm.worker.multi_step_worker"
            worker_class_name = "MultiStepWorker"
        else:
            worker_module_name = "vllm.worker.worker"
            worker_class_name = "Worker"
@@ -9,7 +9,6 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                    Tuple, Union, cast)

import msgspec
import numpy
import torch

from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
@@ -1082,7 +1081,10 @@ class SamplerOutput(

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    sampled_token_ids_numpy: Optional[numpy.ndarray] = None
    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
@@ -1257,9 +1259,7 @@ class ExecuteModelRequest(
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
        assert first_seq_group.state is not None
        num_steps = first_seq_group.state.num_steps
        current_step = first_seq_group.state.current_step
        return num_steps - current_step == 1
        return first_seq_group.state.remaining_steps == 1

    @property
    def current_step(self) -> int:
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.model_executor import SamplingMetadata

T = TypeVar('T', bound="ModelRunnerInputBase")
T = TypeVar('T', bound="BroadcastableModelInput")


def _add_attn_metadata_broadcastable_dict(
@@ -81,18 +81,26 @@ def _add_sampling_metadata_broadcastable_dict(
            sampling_metadata.selected_token_indices)


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different methods
    of converting from the global ExecuteModelRequest produced by the LLM
    engine to the worker-local ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should define a
    ModelRunnerInputBase subclass, add their required fields, and specify how to
    serialize/deserialize a ModelInput for broadcast between workers.
def _init_frozen_model_input_from_tensor_dict(
        frozen_model_input_cls: Type["ModelRunnerInputBase"],
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize a frozen ModelInput based on broadcastable
    """
    valid_tensor_kwargs = {}
    for field in dataclasses.fields(frozen_model_input_cls):
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_tensor_kwargs[field.name] = val

    frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
    tensor_dict["frozen_model_input"] = frozen_model_input
    return tensor_dict


class BroadcastableModelInput(ABC):

    @abstractmethod
    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
@@ -109,11 +117,25 @@ class ModelRunnerInputBase(ABC):
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        ModelRunnerInputBase.
        BroadcastableModelInput.
        """
        raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different methods
    of converting from the global ExecuteModelRequest produced by the LLM
    engine to the worker-local ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should define a
    ModelRunnerInputBase subclass, add their required fields, and specify how to
    serialize/deserialize a ModelInput for broadcast between workers.
    """
    pass


class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.
    """
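The refactor above splits the serialization contract out of ModelRunnerInputBase so that wrapper inputs such as StatefulModelInput can share it without themselves being GPU model-runner inputs. A minimal sketch of a custom subclass is shown below; the field names are invented for illustration, and only the two abstract methods shown above are assumed.

import dataclasses
from typing import Any, Dict, Optional

import torch

from vllm.worker.model_runner_base import BroadcastableModelInput


@dataclasses.dataclass(frozen=True)
class ToyBroadcastableInput(BroadcastableModelInput):
    input_tokens: Optional[torch.Tensor] = None
    step: int = 0

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # Only plain tensors and scalars go over the wire.
        return {"input_tokens": self.input_tokens, "step": self.step}

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict: Dict[str, Any],
                                     attn_backend=None) -> "ToyBroadcastableInput":
        return cls(**tensor_dict)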
453  vllm/worker/multi_step_model_runner.py  Normal file
@@ -0,0 +1,453 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
try:
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
except ModuleNotFoundError:
|
||||
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
|
||||
from vllm.attention.backends.rocm_flash_attn import (
|
||||
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
Logprob, SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceOutput)
|
||||
from vllm.worker.model_runner import (GPUModelRunnerBase,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
from vllm.worker.model_runner_base import (
|
||||
BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
|
||||
_init_frozen_model_input_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
|
||||
from ..model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelOutput:
|
||||
"""The output of a single model forward pass.
|
||||
|
||||
The sampler_output_ready_event is set when the tensors in
|
||||
sampler_output are ready (the model+sampler forward pass has
|
||||
completed). We use the event to synchronize the GPU->CPU transfer,
|
||||
which we want to only run when the data has been written to the
|
||||
GPU tensors. Until the event is ready, the tensors in sampler_output
|
||||
will have garbage data.
|
||||
|
||||
There are two scenarios:
|
||||
1. The output tensors are ready and we can pythonize them immediately.
|
||||
2. The output tensors are not ready and we need to wait for the event to be
|
||||
ready.
|
||||
"""
|
||||
sampler_output: SamplerOutput
|
||||
sampler_output_ready_event: torch.cuda.Event
|
||||
sampled_token_ids: Optional[torch.Tensor] = None
|
||||
pythonized: bool = False
|
||||
|
||||
def pythonize(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.cuda.Stream,
|
||||
pinned_sampled_token_buffer: torch.Tensor) -> None:
|
||||
"""Pythonize the output. Blocking."""
|
||||
if not self.pythonized:
|
||||
self._pythonize_sampler_output(input_metadata, copy_stream,
|
||||
pinned_sampled_token_buffer, True)
|
||||
self.pythonized = True
|
||||
|
||||
def maybe_pythonize(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.cuda.Stream,
|
||||
pinned_sampled_token_buffer: torch.Tensor) -> None:
|
||||
"""Pythonize the output if ready, else return None. Non-blocking."""
|
||||
if not self.pythonized:
|
||||
self.pythonized = self._pythonize_sampler_output(
|
||||
input_metadata, copy_stream, pinned_sampled_token_buffer,
|
||||
False)
|
||||
|
||||
def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.cuda.Stream,
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
blocking: bool) -> bool:
|
||||
"""
|
||||
If blocking is set, will block until the forward pass for the output is
|
||||
ready and pythonize the output.
|
||||
"""
|
||||
assert self.sampled_token_ids is not None
|
||||
if not blocking and not self.sampler_output_ready_event.query():
|
||||
return False
|
||||
|
||||
if blocking:
|
||||
self.sampler_output_ready_event.synchronize()
|
||||
with torch.cuda.stream(copy_stream):
|
||||
_pythonize_sampler_output(input_metadata, self.sampler_output,
|
||||
pinned_sampled_token_buffer,
|
||||
self.sampled_token_ids)
|
||||
return True
|
||||
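The pattern described in the ModelOutput docstring pairs each step's sampler output with a CUDA event and only attempts the GPU-to-CPU conversion once that event reports completion, except on the final step where it blocks. A standalone sketch of that gate follows; the names are hypothetical and a CUDA device is assumed.

import torch


def produce() -> tuple:
    # Stand-in for the sampler output written by a forward pass.
    data = torch.randint(0, 32000, (8, 1), device="cuda")
    ready = torch.cuda.Event()
    ready.record()  # recorded on the stream that produced `data`
    return data, ready


def maybe_pythonize(data: torch.Tensor, ready: torch.cuda.Event,
                    blocking: bool):
    if not blocking and not ready.query():
        return None          # GPU not finished yet; try again on a later step
    if blocking:
        ready.synchronize()  # last step: we must have the result now
    return data.cpu().tolist()


if __name__ == "__main__":
    data, ready = produce()
    print(maybe_pythonize(data, ready, blocking=True))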
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class StatefulModelInput(BroadcastableModelInput):
|
||||
# actual frozen model input dataclass passed to _base_model_runner
|
||||
frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None
|
||||
|
||||
# list of model outputs for each step, may not be all pythonized
|
||||
cached_outputs: List[ModelOutput] = field(default_factory=list)
|
||||
|
||||
# used to pass sampled token ids from the last step to the current step for
|
||||
# TP workers. Used to append to end of outputs and used by advance_step
|
||||
last_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
current_step: int = 0
|
||||
is_multi_step: bool = True
|
||||
is_last_step: bool = False
|
||||
is_first_multi_step: bool = False
|
||||
# ping-pong data structures for multi-step to wait on the previous step
|
||||
step_cuda_events: List[torch.cuda.Event] = field(
|
||||
default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2)
|
||||
num_seqs: int = -1
|
||||
num_queries: int = -1
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
assert self.frozen_model_input is not None
|
||||
tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict()
|
||||
new_tensor_dict = {
|
||||
'last_sampled_token_ids': self.last_sampled_token_ids,
|
||||
'current_step': self.current_step,
|
||||
'is_multi_step': self.is_multi_step,
|
||||
'is_last_step': self.is_last_step,
|
||||
'is_first_multi_step': self.is_first_multi_step,
|
||||
'num_seqs': self.num_seqs,
|
||||
'num_queries': self.num_queries,
|
||||
}
|
||||
tensor_dict.update(new_tensor_dict)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "StatefulModelInput":
|
||||
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
tensor_dict = _init_frozen_model_input_from_tensor_dict(
|
||||
ModelInputForGPUWithSamplingMetadata, tensor_dict)
|
||||
|
||||
return cls(**tensor_dict)
|
||||
|
||||
def record_step_event(self, current_stream: torch.cuda.Stream):
|
||||
# record the event for the current step so that the next step can sync
|
||||
# on it. We modulo by 2 to keep the events in a circular buffer and
|
||||
# support any attn backends that may be supported in the future. ie
|
||||
# Flashinfer would want two DecodeWrappers to overlap the CPU and GPU.
|
||||
self.step_cuda_events[self.current_step & 1] = \
|
||||
torch.cuda.Event(blocking=True)
|
||||
self.step_cuda_events[self.current_step & 1].record(current_stream)
|
||||
|
||||
def wait_previous_step(self):
|
||||
# These cuda events are an explicit synchronization to ensure that
|
||||
# advance_step() (for other attn backends that may be supported in the
|
||||
# future) does not clobber any data structures that are also used by any
|
||||
# enqueued forward steps. For the distributed case, only a single event is
|
||||
# needed, but for single GPU case, since we can let the CPU run much
|
||||
# further ahead, two events allow us to overlap the advance_step with
|
||||
# the previous forward (ie using two DecodeWrappers for flashinfer
|
||||
# backend)
|
||||
self.step_cuda_events[(self.current_step + 1) & 1].wait()
|
||||
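The modulo-2 indexing in record_step_event() and wait_previous_step() amounts to a two-slot ring: step N records into slot N % 2 and waits on the event recorded by step N - 1 in the other slot, which is what lets two steps be in flight at once. Spelled out as a tiny runnable example:

for current_step in range(6):
    record_slot = current_step & 1
    wait_slot = (current_step + 1) & 1  # slot written by the previous step
    print(f"step {current_step}: wait on slot {wait_slot}, "
          f"record into slot {record_slot}")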
|
||||
def add_sampler_output(self,
|
||||
sampler_output: SamplerOutput,
|
||||
sampled_token_ids: Optional[torch.Tensor] = None):
|
||||
self.cached_outputs.append(
|
||||
ModelOutput(sampler_output=sampler_output,
|
||||
sampler_output_ready_event=None,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
pythonized=False))
|
||||
|
||||
|
||||
# StatefulModelInput is not a subclass of ModelInputForGPU, but it wraps the
# actual input dataclass and adds multi-step metadata.
|
||||
# mypy: disable-error-code=type-var
|
||||
class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
# mypy: enable-error-code=type-var
|
||||
|
||||
def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# uses the base model runner to execute the model and wraps it with
|
||||
# multi-step logic
|
||||
self._base_model_runner: GPUModelRunnerBase = base_model_runner
|
||||
|
||||
self.is_multi_step = self.scheduler_config.is_multi_step
|
||||
# used to copy tensors from GPU to CPU asynchronously
|
||||
self._copy_stream = torch.cuda.Stream()
|
||||
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
|
||||
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
))
|
||||
return model_input
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> StatefulModelInput:
|
||||
frozen_model_input = self._base_model_runner.prepare_model_input(
|
||||
seq_group_metadata_list, virtual_engine, finished_requests_ids)
|
||||
|
||||
model_input = StatefulModelInput(
|
||||
frozen_model_input=frozen_model_input,
|
||||
num_seqs=len(frozen_model_input.seq_lens),
|
||||
num_queries=len(frozen_model_input.query_lens),
|
||||
)
|
||||
return model_input
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: StatefulModelInput,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
||||
"""
|
||||
Execute the model for a single step and update multi-step
|
||||
metadata
|
||||
"""
|
||||
assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
# path for warm up runs
|
||||
if not model_input.is_multi_step:
|
||||
return self._base_model_runner.execute_model(
|
||||
frozen_model_input, kv_caches, intermediate_tensors, num_steps)
|
||||
|
||||
# make sure we skip the sampler on the last rank and only pythonize
|
||||
# if CPU is ahead.
|
||||
if self.is_driver_worker and get_pp_group().is_last_rank:
|
||||
if self.pinned_sampled_token_ids is None:
|
||||
self.pinned_sampled_token_ids = torch.zeros(
|
||||
(self.scheduler_config.max_num_seqs, 1),
|
||||
dtype=torch.long,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
|
||||
self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
|
||||
True)
|
||||
if frozen_model_input.sampling_metadata:
|
||||
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
|
||||
True)
|
||||
|
||||
# some pre-execute model logic for multi-step:
|
||||
# - if it's the first step, we need to reset the sampling tensors
|
||||
# - if it's not the first step, we need to advance the step using the
|
||||
# appended sampler output from last iteration
|
||||
# - also maybe pythonize if CPU is ahead of GPU
|
||||
|
||||
current_stream = torch.cuda.current_stream()
|
||||
if not model_input.is_first_multi_step:
|
||||
# Explicitly block on the previous step's forward to make sure we
|
||||
# don't clobber any GPU tensors still in use.
|
||||
# This is not needed for flashattn backend, but for other attn
|
||||
# backends such as flashinfer that performs extra CPU operations on
|
||||
# input metadata we may need to synchronize any CPU operations that
|
||||
# might clobber enqueued forwards. (prevents CPU from running too
|
||||
# far ahead if needed)
|
||||
model_input.wait_previous_step()
|
||||
model_input = self._advance_step(
|
||||
model_input, model_input.cached_outputs[-1].sampler_output)
|
||||
|
||||
# Execute the model
|
||||
output = self._base_model_runner.execute_model(frozen_model_input,
|
||||
kv_caches,
|
||||
intermediate_tensors,
|
||||
num_steps=1)
|
||||
|
||||
# record the event for the current step so that the next step can sync
|
||||
model_input.record_step_event(current_stream)
|
||||
|
||||
if get_pp_group().is_last_rank and self.is_driver_worker:
|
||||
assert len(
|
||||
output
|
||||
) == 1, "MultiStepModelRunner requires single-step base_models"
|
||||
|
||||
# event for the pythonization so that we only pythonize if the
|
||||
# tensors are ready. May be able to be combined with the step event
|
||||
output_ready_event = torch.cuda.Event()
|
||||
output_ready_event.record(current_stream)
|
||||
if self.parallel_config.pipeline_parallel_size > 1:
|
||||
output[0].sampled_token_ids_cpu = output[
|
||||
0].sampled_token_ids.cpu()
|
||||
model_input.cached_outputs.append(
|
||||
ModelOutput(output[0], output_ready_event,
|
||||
output[0].sampled_token_ids, False))
|
||||
# make sure we don't try to serialize any GPU tensors
|
||||
output[0].sampled_token_ids = None
|
||||
output[0].sampled_token_probs = None
|
||||
output[0].logprobs = None
|
||||
# Pythonize the output if CPU is ahead and the previous step is
|
||||
# ready.
|
||||
for model_output in model_input.cached_outputs:
|
||||
model_output.maybe_pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
model_input.current_step += 1
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
# Should be IntermediateTensors
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
return output
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
# Pythonize the output and block if needed since it is the last step
|
||||
if model_input.is_last_step:
|
||||
outputs = []
|
||||
for output in model_input.cached_outputs:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
outputs.append(output.sampler_output)
|
||||
return outputs
|
||||
|
||||
# should be [SamplerOutput]
|
||||
return output
|
||||
|
||||
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
|
||||
num_queries):
|
||||
|
||||
assert sampling_metadata.num_prompts == 0
|
||||
assert len(sampling_metadata.seq_groups) == num_queries
|
||||
assert sampling_metadata.selected_token_indices.shape == (
|
||||
num_queries, )
|
||||
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
|
||||
|
||||
# Verify that all sequences are decodes
|
||||
for i in range(num_queries):
|
||||
seq_group = sampling_metadata.seq_groups[i]
|
||||
|
||||
assert seq_group.is_prompt is False # No prompt
|
||||
assert seq_group.prompt_logprob_indices == [] # No prompt
|
||||
assert seq_group.sample_indices == [i] # Simple
|
||||
assert seq_group.seq_len is None # Decode
|
||||
assert seq_group.query_len is None # Decode
|
||||
|
||||
def _advance_step(self, model_input: StatefulModelInput,
|
||||
out: SamplerOutput) -> StatefulModelInput:
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
|
||||
num_seqs = model_input.num_seqs
|
||||
num_queries = model_input.num_queries
|
||||
assert num_seqs > 0
|
||||
assert num_queries > 0
|
||||
assert num_seqs >= num_queries
|
||||
|
||||
attn_metadata = frozen_model_input.attn_metadata
|
||||
assert isinstance(attn_metadata, FlashAttentionMetadata)
|
||||
attn_metadata.advance_step(num_seqs, num_queries)
|
||||
|
||||
# Update GPU tensors
|
||||
ops.advance_step(
|
||||
num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=self.block_size,
|
||||
input_tokens=frozen_model_input.input_tokens,
|
||||
sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
|
||||
input_positions=frozen_model_input.input_positions,
|
||||
seq_lens=attn_metadata.seq_lens_tensor,
|
||||
slot_mapping=attn_metadata.slot_mapping,
|
||||
block_tables=attn_metadata.block_tables)
|
||||
|
||||
if frozen_model_input.seq_lens is not None:
|
||||
for i in range(num_queries):
|
||||
frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
|
||||
|
||||
return model_input
|
||||
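For intuition, the sketch below is a pure-PyTorch reference for what the fused ops.advance_step kernel does logically for a pure-decode batch; the real kernel also rewrites slot_mapping from the block tables on the GPU, which is omitted here. The function name is invented for illustration.

import torch


def advance_step_reference(input_tokens: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor,
                           sampled_token_ids: torch.Tensor) -> None:
    # The token sampled at step i becomes the input token of step i + 1;
    # positions and sequence lengths each advance by one.
    input_tokens.copy_(sampled_token_ids.view(-1))
    input_positions.add_(1)
    seq_lens.add_(1)


if __name__ == "__main__":
    tokens = torch.zeros(4, dtype=torch.long)
    positions = torch.tensor([5, 9, 3, 7])
    lens = torch.tensor([6, 10, 4, 8])
    sampled = torch.tensor([[11], [12], [13], [14]])
    advance_step_reference(tokens, positions, lens, sampled)
    assert tokens.tolist() == [11, 12, 13, 14]
    assert lens.tolist() == [7, 11, 5, 9]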
|
||||
def load_model(self) -> None:
|
||||
return self._base_model_runner.load_model()
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
return self._base_model_runner.save_sharded_state(
|
||||
path, pattern, max_size)
|
||||
|
||||
def save_tensorized_model(self,
|
||||
tensorizer_config: TensorizerConfig) -> None:
|
||||
return self._base_model_runner.save_tensorized_model(tensorizer_config)
|
||||
|
||||
def profile_run(self) -> None:
|
||||
return self._base_model_runner.profile_run()
|
||||
|
||||
def remove_all_loras(self):
|
||||
return self._base_model_runner.remove_all_loras()
|
||||
|
||||
def capture_model(self, kv_caches: List[List]) -> None:
|
||||
return self._base_model_runner.capture_model(kv_caches)
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._base_model_runner.vocab_size
|
||||
|
||||
|
||||
def _pythonize_sampler_output(model_input: StatefulModelInput,
|
||||
output: SamplerOutput,
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor) -> None:
|
||||
""" This function is only called when the output tensors are ready.
|
||||
See ModelOutput
|
||||
"""
|
||||
|
||||
assert model_input.frozen_model_input is not None
|
||||
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input.sampling_metadata is not None
|
||||
# samples generation should have been skipped
|
||||
assert not output.outputs
|
||||
|
||||
pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
|
||||
|
||||
# CPU GPU sync
|
||||
pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
|
||||
|
||||
# this will not block as the tensors are already on CPU
|
||||
samples_list = pinned_buffer.tolist()
|
||||
|
||||
sampling_metadata = frozen_model_input.sampling_metadata
|
||||
|
||||
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
|
||||
samples_list):
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_token_ids = sample_result
|
||||
parent_ids = [0]
|
||||
seq_outputs: List[SequenceOutput] = []
|
||||
if seq_group.sampling_params.logits_processors:
|
||||
assert len(seq_group.sampling_params.logits_processors) == 0, (
|
||||
"Logits Processors are not supported in multi-step decoding")
|
||||
for parent_id, next_token_id in zip(parent_ids, next_token_ids):
|
||||
# TODO(will): support logprobs
|
||||
# Hard coded logprob
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id,
|
||||
{next_token_id: Logprob(logprob=-1)}))
|
||||
output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))
|
||||
assert len(output.outputs) > 0
|
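_pythonize_sampler_output above relies on a single pre-allocated page-locked buffer (pinned_sampled_token_ids in the runner) so the per-step GPU-to-CPU copy never re-allocates host memory, and the subsequent tolist() is pure CPU work. A reduced sketch, assuming a CUDA device and an arbitrary example buffer size:

import torch

MAX_NUM_SEQS = 256
pinned = torch.zeros((MAX_NUM_SEQS, 1), dtype=torch.long, pin_memory=True)


def pythonize(sampled_token_ids: torch.Tensor) -> list:
    buf = pinned[:sampled_token_ids.shape[0]]
    buf.copy_(sampled_token_ids, non_blocking=False)  # the CPU/GPU sync point
    return buf.tolist()                               # no further blocking


if __name__ == "__main__":
    ids = torch.randint(0, 32000, (8, 1), device="cuda")
    print(pythonize(ids))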
189  vllm/worker/multi_step_worker.py  Normal file
@@ -0,0 +1,189 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.worker.model_runner_base import BroadcastableModelInput
|
||||
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
|
||||
StatefulModelInput)
|
||||
from vllm.worker.worker import Worker, WorkerInput
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiStepState:
|
||||
worker_input: WorkerInput
|
||||
model_input: StatefulModelInput
|
||||
|
||||
|
||||
class MultiStepWorker(Worker):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
base_model_runner = self.model_runner
|
||||
# for multi-step model, wrap the model runner with MultiStepModelRunner
|
||||
self.model_runner = MultiStepModelRunner(
|
||||
base_model_runner,
|
||||
base_model_runner.model_config,
|
||||
base_model_runner.parallel_config,
|
||||
base_model_runner.scheduler_config,
|
||||
base_model_runner.device_config,
|
||||
base_model_runner.cache_config,
|
||||
load_config=base_model_runner.load_config,
|
||||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=base_model_runner.is_driver_worker,
|
||||
prompt_adapter_config=base_model_runner.prompt_adapter_config,
|
||||
observability_config=base_model_runner.observability_config,
|
||||
)
|
||||
|
||||
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
self.multi_step_states: List[
|
||||
Optional[MultiStepState]] = [None] * pipeline_parallel_size
|
||||
self.temp_output = None
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput]:
|
||||
"""
|
||||
Get the driver input and broadcast it to other workers.
|
||||
"""
|
||||
assert self.is_driver_worker
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
is_first_multi_step = execute_model_req.is_first_multi_step
|
||||
if is_first_multi_step:
|
||||
# on first step we prepare the worker input and model input normally
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: StatefulModelInput = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
else:
|
||||
# on subsequent steps we reuse the worker input and model input
|
||||
multi_step_state = self.multi_step_states[virtual_engine]
|
||||
worker_input = multi_step_state.worker_input
|
||||
model_input = multi_step_state.model_input
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
# clear the cached decode metadata so that it can be recomputed on
|
||||
# the workers
|
||||
frozen_model_input.attn_metadata._cached_decode_metadata = None
|
||||
|
||||
model_input.is_first_multi_step = is_first_multi_step
|
||||
model_input.is_last_step = execute_model_req.is_last_step
|
||||
|
||||
if not is_first_multi_step:
|
||||
# we broadcast the last sampled token ids to all TP workers so they
|
||||
# can update their model input metadata in-place.
|
||||
self._prepare_last_sampled_token_ids_for_tp_workers(
|
||||
execute_model_req=execute_model_req, model_input=model_input)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
return model_input, worker_input
|
||||
|
||||
def _prepare_last_sampled_token_ids_for_tp_workers(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
model_input: StatefulModelInput,
|
||||
) -> None:
|
||||
"""
|
||||
Prepare the last sampled token ids for TP workers. If it's the last
|
||||
PP rank, then the last sampled token ids are already in the model_input.
|
||||
If it is NOT the last PP rank, then we need to get the last sampled
|
||||
token that is cached in the execute_model_req.
|
||||
"""
|
||||
if get_pp_group().is_last_rank:
|
||||
assert model_input.cached_outputs[
|
||||
-1].sampler_output.sampled_token_ids is None
|
||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = model_input.cached_outputs[
|
||||
-1].sampled_token_ids
|
||||
# free sampled token ids from the previous step if it has been
|
||||
# pythonized. Cannot free the last sampled token ids because
|
||||
# we need it for GPU advance_step.
|
||||
for output in model_input.cached_outputs[:-1]:
|
||||
if output.pythonized:
|
||||
output.sampled_token_ids = None
|
||||
else:
|
||||
# otherwise we need to get the cached sampled token ids from the
|
||||
# execute_model_req
|
||||
assert execute_model_req.last_sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = (
|
||||
execute_model_req.last_sampled_token_ids.cuda())
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
# free sampled token ids from the previous step.
|
||||
# TODO(will) we could reuse the sampled token ids tensor from
|
||||
# the previous step instead.
|
||||
for output in model_input.cached_outputs[:-1]:
|
||||
output.sampled_token_ids = None
|
||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[Tuple[StatefulModelInput, WorkerInput]]:
|
||||
"""
|
||||
Depending on the current state of the request and multi step worker,
|
||||
this method may skip the normal _prepare_model_input and
|
||||
_prepare_worker_input methods and instead use cached values.
|
||||
"""
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
# This signals that there are no more requests to process for
|
||||
# now. All workers are running an infinite loop with
|
||||
# broadcast_tensor_dict, and it stops the loop when the
|
||||
# driver broadcasts an empty input. Send an empty input to
|
||||
# notify all other workers to stop their execution loop.
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
model_input, worker_input = self._get_driver_input_and_broadcast(
|
||||
execute_model_req)
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
if execute_model_req.is_first_multi_step:
|
||||
# cache the worker input and model input for the next steps
|
||||
self.multi_step_states[virtual_engine] = MultiStepState(
|
||||
worker_input=worker_input, model_input=model_input)
|
||||
# if TP workers
|
||||
else:
|
||||
broadcast_data = self._get_worker_input_from_broadcast()
|
||||
# if the driver has sent an empty input, we should stop the worker
|
||||
# loop
|
||||
if broadcast_data is None:
|
||||
return None
|
||||
model_input, worker_input = broadcast_data
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
if model_input.is_first_multi_step:
|
||||
pass
|
||||
# TODO(will) Can cache the worker input and model input for the
|
||||
# next steps. See below for details
|
||||
else:
|
||||
# TODO(will) possible to also cache and reuse the cached worker
|
||||
# input and model input. The idea is essentially the delta
|
||||
# optimization for model_inputs. Where the TP workers can cache
|
||||
# the model input states and we only broadcast the delta needed
|
||||
# for the next step (sampled_token_ids from the previous step)
|
||||
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
# we need to update the last sampled token ids in the model
|
||||
# input for the workers so that they can run inplace
|
||||
# advance_step
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
assert model_input is not None
|
||||
assert worker_input is not None
|
||||
return model_input, worker_input
|
@@ -16,7 +16,9 @@ from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread,
                        update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)

logger = init_logger(__name__)

@@ -220,7 +222,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
        raise NotImplementedError

    def _get_worker_input_from_broadcast(
        self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
        self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
        """ Get the worker input from the broadcasted tensor dict. """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
@@ -237,7 +239,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelRunnerInputBase, WorkerInput]:
    ) -> Tuple[BroadcastableModelInput, WorkerInput]:
        """ Get the driver input and broadcast it to other workers. """
        assert self.is_driver_worker

@@ -259,7 +261,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
        """
        Prepare the inputs to ModelRunner and workers.
        """