[core] Multi Step Scheduling (#7000)

Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Author: William Lin
Date: 2024-08-19 13:52:13 -07:00
Committed by: GitHub
Parent: dad961ef5c
Commit: 47b65a5508

13 changed files with 1004 additions and 34 deletions

View File

@@ -311,6 +311,15 @@ steps:
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
- label: Multi-step Tests (4 GPUs) # 10min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
- vllm/
- tests/multi_step/test_correctness.py
commands:
- pytest -v -s multi_step/test_correctness.py
- label: Pipeline Parallelism Test # 23min
working_dir: "/vllm-workspace/tests"
num_gpus: 4

View File

View File

@@ -0,0 +1,85 @@
# Test the AsyncLLMEngine with multi-step-decoding
from typing import List
import pytest
from ..utils import RemoteOpenAIServer
MODELS = [
"JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
NUM_PROMPTS = [10]
DEFAULT_SERVER_ARGS: List[str] = [
"--disable-log-requests",
"--use-v2-block-manager",
"--worker-use-ray",
"--gpu-memory-utilization",
"0.85",
"--swap-space",
"16",
]
async def completions_with_server_args(prompts: List[str], model_name: str,
server_cli_args: List[str]):
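"""Start a RemoteOpenAIServer with the given CLI args, send all prompts as a
single batched completions request (greedy sampling, 5 max tokens), and
return the response."""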
outputs = None
with RemoteOpenAIServer(model_name, server_cli_args) as server:
client = server.get_async_client()
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=5)
assert outputs is not None
return outputs
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("tp_size, pp_size"), [
(1, 1),
(2, 2),
])
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: bool,
num_scheduler_steps: int, num_prompts: int):
prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts
server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]
if eager_mode:
ms_server_args.append("--enforce-eager")
distributed_args = [
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
]
ref_completions = await completions_with_server_args(
prompts, model, server_args + distributed_args)
test_completions = await completions_with_server_args(
prompts, model, ms_server_args + distributed_args)
def get_text_generations(completions):
return [x.text for x in completions.choices]
ref_generations = get_text_generations(ref_completions)
test_generations = get_text_generations(test_completions)
assert ref_generations == test_generations
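# Illustrative usage sketch (not part of the diff): the same baseline vs.
# multi-step comparison the test above performs, runnable as a standalone
# script on a single GPU. It reuses completions_with_server_args, MODELS and
# DEFAULT_SERVER_ARGS defined in this file.
import asyncio

def _texts(completions):
    return [choice.text for choice in completions.choices]

async def _compare_single_gpu():
    prompts = ["Hello, my name is", "The capital of France is"]
    model = MODELS[0]
    baseline = await completions_with_server_args(
        prompts, model, DEFAULT_SERVER_ARGS + ["--enforce-eager"])
    multi_step = await completions_with_server_args(
        prompts, model, DEFAULT_SERVER_ARGS + ["--num-scheduler-steps", "8"])
    # Multi-step scheduling should not change greedy (temperature=0) outputs.
    assert _texts(baseline) == _texts(multi_step)

if __name__ == "__main__":
    asyncio.run(_compare_single_gpu())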

View File

@@ -10,6 +10,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import StatefulModelInput
class MockAttentionBackend(AttentionBackend):
@@ -154,3 +155,79 @@ def test_embedding_model_runner_input():
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None
def test_multi_step_model_runner_input():
sampling_metadata = SamplingMetadata(
["seq_group"],
"selected_token_indices",
"categorized_sample_indices",
"num_prompts",
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
)
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata)
model_input = StatefulModelInput(
frozen_model_input=frozen_model_input,
is_last_step=True,
is_first_multi_step=False,
current_step=4,
last_sampled_token_ids=torch.ones((10, 1)),
is_multi_step=True,
num_queries=8,
num_seqs=5,
cached_outputs=[],
)
assert isinstance(model_input, StatefulModelInput)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
received_frozen_input = received_model_input.frozen_model_input
# Check that the received copy has correct values.
assert isinstance(received_model_input, StatefulModelInput)
assert received_frozen_input.input_tokens is not None
assert (received_frozen_input.input_tokens ==
frozen_model_input.input_tokens).all()
assert received_frozen_input.input_positions is not None
assert (received_frozen_input.input_positions ==
frozen_model_input.input_positions).all()
assert received_frozen_input.multi_modal_kwargs is None
assert (received_frozen_input.multi_modal_kwargs ==
frozen_model_input.multi_modal_kwargs)
assert received_frozen_input.lora_requests is None
assert (received_frozen_input.lora_requests ==
frozen_model_input.lora_requests)
assert received_frozen_input.lora_mapping is None
assert (
received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_frozen_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# For sampling metadata, only selected_token_indices is copied.
assert (received_frozen_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices)
assert received_frozen_input.sampling_metadata.seq_groups is None
# Check the non-frozen fields.
assert received_model_input.is_last_step == model_input.is_last_step
assert (received_model_input.is_first_multi_step ==
model_input.is_first_multi_step)
assert received_model_input.current_step == model_input.current_step
assert (received_model_input.last_sampled_token_ids ==
model_input.last_sampled_token_ids).all()
assert received_model_input.is_multi_step == model_input.is_multi_step

View File

@@ -853,6 +853,12 @@ class EngineArgs:
"in low performance due to small KV cache space. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
self.use_v2_block_manager = True
logger.warning(
"Enabled BlockSpaceManagerV2 because it is "
"required for multi-step (--num-scheduler-steps > 1)")
speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
@@ -881,7 +887,6 @@
)
if self.num_scheduler_steps > 1:
raise NotImplementedError("Multi-step is not yet supported.")
if speculative_config is not None:
raise ValueError("Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)")

View File

@@ -1,9 +1,11 @@
import asyncio
import time
from dataclasses import dataclass
from functools import partial
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
import torch
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
@@ -27,7 +29,8 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
@@ -249,9 +252,25 @@ class RequestTracker:
return not self._new_requests.empty()
@dataclass
class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output: Optional[SamplerOutput] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pipeline_parallel_size = \
self.parallel_config.pipeline_parallel_size
self.cached_scheduler_outputs = [
SchedulerOutputState() for _ in range(pipeline_parallel_size)
]
async def step_async(
self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
@@ -264,13 +283,39 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
# these are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
if not scheduler_outputs.is_empty():
# Execute the model.
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
@@ -279,15 +324,35 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids)
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
output = []
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
else:
request_outputs = []
# Log stats.
self.do_log_stats(scheduler_outputs, output)
@@ -297,6 +362,60 @@ class _AsyncLLMEngine(LLMEngine):
return request_outputs
def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return False
# TODO(will) this is a sanity check for now to make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))
return ref_remaining_steps > 0
def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs) -> None:
self.cached_scheduler_outputs[
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
scheduler_outputs
self.cached_scheduler_outputs[virtual_engine].last_output = None
def _get_last_sampled_token_ids(
self, virtual_engine: int) -> Optional[torch.Tensor]:
cached_last_output = self.cached_scheduler_outputs[
virtual_engine].last_output
if (self.scheduler_config.is_multi_step
and self.parallel_config.pipeline_parallel_size > 1
and cached_last_output is not None
and cached_last_output.sampled_token_ids_cpu is not None):
return cached_last_output.sampled_token_ids_cpu
return None
def _update_cached_scheduler_output(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_cpu is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
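# Self-contained sketch (not vLLM code) of the invariant _has_remaining_steps
# enforces above: every scheduled sequence group must sit on the same step
# count, and the scheduler is only re-invoked once that count reaches zero.
from dataclasses import dataclass
from typing import List

@dataclass
class _ToyState:
    remaining_steps: int

def _toy_has_remaining_steps(states: List[_ToyState]) -> bool:
    if not states:
        return False
    ref = states[0].remaining_steps
    if any(s.remaining_steps != ref for s in states[1:]):
        raise AssertionError("All running sequence groups should "
                             "have the same remaining steps.")
    return ref > 0

assert _toy_has_remaining_steps([_ToyState(3), _ToyState(3)])
assert not _toy_has_remaining_steps([_ToyState(0), _ToyState(0)])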

View File

@@ -69,13 +69,19 @@ class GPUExecutor(ExecutorBase):
distributed_init_method: Optional[str] = None) -> Dict:
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)
if self.speculative_config is None:
worker_kwargs.update(worker_module_name="vllm.worker.worker",
worker_class_name="Worker")
else:
if self.scheduler_config.is_multi_step:
worker_kwargs.update(
worker_module_name="vllm.worker.multi_step_worker",
worker_class_name="MultiStepWorker")
elif self.speculative_config:
worker_kwargs.update(
worker_module_name="vllm.spec_decode.spec_decode_worker",
worker_class_name="create_spec_worker")
else:
worker_kwargs.update(worker_module_name="vllm.worker.worker",
worker_class_name="Worker")
return worker_kwargs
def _create_worker(self,

View File

@@ -94,6 +94,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
elif self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_worker"
worker_class_name = "MultiStepWorker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"

View File

@@ -9,7 +9,6 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
import msgspec
import numpy
import torch
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
@@ -1082,7 +1081,10 @@ class SamplerOutput(
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncLLMEngine to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
@@ -1257,9 +1259,7 @@ class ExecuteModelRequest(
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step
return num_steps - current_step == 1
return first_seq_group.state.remaining_steps == 1
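# Note (assumption, definition not shown in this hunk): the new
# SequenceGroupState.remaining_steps helper is expected to equal
# num_steps - current_step, i.e. exactly the expression it replaces here.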
@property
def current_step(self) -> int:

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
T = TypeVar('T', bound="ModelRunnerInputBase")
T = TypeVar('T', bound="BroadcastableModelInput")
def _add_attn_metadata_broadcastable_dict(
@@ -81,18 +81,26 @@ def _add_sampling_metadata_broadcastable_dict(
sampling_metadata.selected_token_indices)
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.
Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
def _init_frozen_model_input_from_tensor_dict(
frozen_model_input_cls: Type["ModelRunnerInputBase"],
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to initialize a frozen ModelInput based on the broadcastable
fields in the given tensor_dict.
"""
valid_tensor_kwargs = {}
for field in dataclasses.fields(frozen_model_input_cls):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_tensor_kwargs[field.name] = val
frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
tensor_dict["frozen_model_input"] = frozen_model_input
return tensor_dict
class BroadcastableModelInput(ABC):
@abstractmethod
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"""
Extract broadcastable fields. Override for fields that require some
@@ -109,11 +117,25 @@ class ModelRunnerInputBase(ABC):
) -> T:
"""
Pop fields from the given tensor_dict and populate a new instance of
ModelRunnerInputBase.
BroadcastableModelInput.
"""
raise NotImplementedError
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.
Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
"""
pass
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""

View File

@@ -0,0 +1,453 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
import torch
from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
_init_frozen_model_input_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from ..model_executor.model_loader.tensorizer import TensorizerConfig
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
@dataclass
class ModelOutput:
"""The output of a single model forward pass.
The sampler_output_ready_event is set when the tensors in
sampler_output are ready (the model+sampler forward pass has
completed). We use the event to synchronize the GPU->CPU transfer,
which we want to only run when the data has been written to the
GPU tensors. Until the event is ready, the tensors in sampler_output
will have garbage data.
There are two scenarios:
1. The output tensors are ready and we can pythonize them immediately.
2. The output tensors are not ready and we need to wait for the event to be
ready.
"""
sampler_output: SamplerOutput
sampler_output_ready_event: torch.cuda.Event
sampled_token_ids: Optional[torch.Tensor] = None
pythonized: bool = False
def pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor) -> None:
"""Pythonize the output. Blocking."""
if not self.pythonized:
self._pythonize_sampler_output(input_metadata, copy_stream,
pinned_sampled_token_buffer, True)
self.pythonized = True
def maybe_pythonize(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor) -> None:
"""Pythonize the output if ready, else return None. Non-blocking."""
if not self.pythonized:
self.pythonized = self._pythonize_sampler_output(
input_metadata, copy_stream, pinned_sampled_token_buffer,
False)
def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
copy_stream: torch.cuda.Stream,
pinned_sampled_token_buffer: torch.Tensor,
blocking: bool) -> bool:
"""
If blocking is set, will block until the forward pass for the output is
ready and pythonize the output.
"""
assert self.sampled_token_ids is not None
if not blocking and not self.sampler_output_ready_event.query():
return False
if blocking:
self.sampler_output_ready_event.synchronize()
with torch.cuda.stream(copy_stream):
_pythonize_sampler_output(input_metadata, self.sampler_output,
pinned_sampled_token_buffer,
self.sampled_token_ids)
return True
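# Illustrative sketch (not part of the diff) of the event-gated GPU->CPU copy
# pattern ModelOutput relies on: only read the sampler output tensors once the
# recorded CUDA event reports that the producing kernels have finished.
# Requires a CUDA device; names below are illustrative only.
def _event_gated_copy_demo() -> torch.Tensor:
    stream = torch.cuda.current_stream()
    gpu_tokens = torch.randint(0, 32000, (8, 1), device="cuda")
    ready = torch.cuda.Event()
    ready.record(stream)        # marks "model + sampler forward has finished"
    if ready.query():           # non-blocking path (maybe_pythonize)
        return gpu_tokens.cpu()
    ready.synchronize()         # blocking path (pythonize on the last step)
    return gpu_tokens.cpu()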
@dataclass(frozen=False)
class StatefulModelInput(BroadcastableModelInput):
# actual frozen model input dataclass passed to _base_model_runner
frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None
# list of model outputs for each step, may not be all pythonized
cached_outputs: List[ModelOutput] = field(default_factory=list)
# Used to pass the sampled token ids from the last step to the current step
# for TP workers; appended to the end of cached_outputs and consumed by
# advance_step.
last_sampled_token_ids: Optional[torch.Tensor] = None
current_step: int = 0
is_multi_step: bool = True
is_last_step: bool = False
is_first_multi_step: bool = False
# ping-pong data structures for multi-step to wait on the previous step
step_cuda_events: List[torch.cuda.Event] = field(
default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2)
num_seqs: int = -1
num_queries: int = -1
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
assert self.frozen_model_input is not None
tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict()
new_tensor_dict = {
'last_sampled_token_ids': self.last_sampled_token_ids,
'current_step': self.current_step,
'is_multi_step': self.is_multi_step,
'is_last_step': self.is_last_step,
'is_first_multi_step': self.is_first_multi_step,
'num_seqs': self.num_seqs,
'num_queries': self.num_queries,
}
tensor_dict.update(new_tensor_dict)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "StatefulModelInput":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
tensor_dict = _init_frozen_model_input_from_tensor_dict(
ModelInputForGPUWithSamplingMetadata, tensor_dict)
return cls(**tensor_dict)
def record_step_event(self, current_stream: torch.cuda.Stream):
# Record the event for the current step so that the next step can sync
# on it. We modulo by 2 to keep the events in a two-slot circular buffer,
# which also supports attention backends that may be added in the future,
# e.g. FlashInfer, which would want two DecodeWrappers to overlap the CPU
# and the GPU.
self.step_cuda_events[self.current_step & 1] = \
torch.cuda.Event(blocking=True)
self.step_cuda_events[self.current_step & 1].record(current_stream)
def wait_previous_step(self):
# These CUDA events are an explicit synchronization to ensure that
# advance_step() (for other attn backends that may be supported in the
# future) does not clobber any data structures that are also used by any
# enqueued forward steps. For the distributed case, only a single event is
# needed, but for the single-GPU case, since we can let the CPU run much
# further ahead, two events allow us to overlap the advance_step with the
# previous forward (e.g. using two DecodeWrappers for the flashinfer
# backend).
self.step_cuda_events[(self.current_step + 1) & 1].wait()
def add_sampler_output(self,
sampler_output: SamplerOutput,
sampled_token_ids: Optional[torch.Tensor] = None):
self.cached_outputs.append(
ModelOutput(sampler_output=sampler_output,
sampler_output_ready_event=None,
sampled_token_ids=sampled_token_ids,
pythonized=False))
# StatefulModelInput is not a subclass of ModelInputForGPU, but it wraps
# the actual input dataclass and adds multi-step metadata.
# mypy: disable-error-code=type-var
class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# mypy: enable-error-code=type-var
def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
super().__init__(*args, **kwargs)
# uses the base model runner to execute the model and wraps it with
# multi-step logic
self._base_model_runner: GPUModelRunnerBase = base_model_runner
self.is_multi_step = self.scheduler_config.is_multi_step
# used to copy tensors from GPU to CPU asynchronously
self._copy_stream = torch.cuda.Stream()
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
return model_input
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> StatefulModelInput:
frozen_model_input = self._base_model_runner.prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids)
model_input = StatefulModelInput(
frozen_model_input=frozen_model_input,
num_seqs=len(frozen_model_input.seq_lens),
num_queries=len(frozen_model_input.query_lens),
)
return model_input
@torch.inference_mode()
def execute_model(
self,
model_input: StatefulModelInput,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
"""
Execute the model for a single step and update multi-step
metadata
"""
assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
# path for warm up runs
if not model_input.is_multi_step:
return self._base_model_runner.execute_model(
frozen_model_input, kv_caches, intermediate_tensors, num_steps)
# make sure we skip the sampler on the last rank and only pythonize
# if the CPU is ahead.
if self.is_driver_worker and get_pp_group().is_last_rank:
if self.pinned_sampled_token_ids is None:
self.pinned_sampled_token_ids = torch.zeros(
(self.scheduler_config.max_num_seqs, 1),
dtype=torch.long,
device="cpu",
pin_memory=True)
self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
True)
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
True)
# some pre-execute model logic for multi-step:
# - if it's the first step, we need to reset the sampling tensors
# - if it's not the first step, we need to advance the step using the
# appended sampler output from last iteration
# - also maybe pythonize if CPU is ahead of GPU
current_stream = torch.cuda.current_stream()
if not model_input.is_first_multi_step:
# Explicitly block on the previous step's forward to make sure we
# don't clobber any GPU tensors still in use.
# This is not needed for the flash-attn backend, but for other attn
# backends such as flashinfer that perform extra CPU operations on the
# input metadata, we may need to synchronize any CPU operations that
# might clobber enqueued forwards (this prevents the CPU from running
# too far ahead if needed).
model_input.wait_previous_step()
model_input = self._advance_step(
model_input, model_input.cached_outputs[-1].sampler_output)
# Execute the model
output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches,
intermediate_tensors,
num_steps=1)
# record the event for the current step so that the next step can sync
model_input.record_step_event(current_stream)
if get_pp_group().is_last_rank and self.is_driver_worker:
assert len(
output
) == 1, "MultiStepModelRunner requires single-step base_models"
# Event for the pythonization so that we only pythonize if the
# tensors are ready. This may be combinable with the step event.
output_ready_event = torch.cuda.Event()
output_ready_event.record(current_stream)
if self.parallel_config.pipeline_parallel_size > 1:
output[0].sampled_token_ids_cpu = output[
0].sampled_token_ids.cpu()
model_input.cached_outputs.append(
ModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False))
# make sure we don't try to serialize any GPU tensors
output[0].sampled_token_ids = None
output[0].sampled_token_probs = None
output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is
# ready.
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
model_input.current_step += 1
if not get_pp_group().is_last_rank:
# Should be IntermediateTensors
assert isinstance(output, IntermediateTensors)
return output
if not self.is_driver_worker:
return []
# Pythonize the output and block if needed since it is the last step
if model_input.is_last_step:
outputs = []
for output in model_input.cached_outputs:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
outputs.append(output.sampler_output)
return outputs
# should be [SamplerOutput]
return output
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _advance_step(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput:
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
num_seqs = model_input.num_seqs
num_queries = model_input.num_queries
assert num_seqs > 0
assert num_queries > 0
assert num_seqs >= num_queries
attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
attn_metadata.advance_step(num_seqs, num_queries)
# Update GPU tensors
ops.advance_step(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=self.block_size,
input_tokens=frozen_model_input.input_tokens,
sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
input_positions=frozen_model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)
if frozen_model_input.seq_lens is not None:
for i in range(num_queries):
frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
return model_input
def load_model(self) -> None:
return self._base_model_runner.load_model()
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
return self._base_model_runner.save_sharded_state(
path, pattern, max_size)
def save_tensorized_model(self,
tensorizer_config: TensorizerConfig) -> None:
return self._base_model_runner.save_tensorized_model(tensorizer_config)
def profile_run(self) -> None:
return self._base_model_runner.profile_run()
def remove_all_loras(self):
return self._base_model_runner.remove_all_loras()
def capture_model(self, kv_caches: List[List]) -> None:
return self._base_model_runner.capture_model(kv_caches)
@property
def vocab_size(self) -> int:
return self._base_model_runner.vocab_size
def _pythonize_sampler_output(model_input: StatefulModelInput,
output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor) -> None:
""" This function is only called when the output tensors are ready.
See ModelOutput
"""
assert model_input.frozen_model_input is not None
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input.sampling_metadata is not None
# samples generation should have been skipped
assert not output.outputs
pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
# CPU GPU sync
pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
# this will not block as the tensors are already on CPU
samples_list = pinned_buffer.tolist()
sampling_metadata = frozen_model_input.sampling_metadata
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
samples_list):
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
seq_outputs: List[SequenceOutput] = []
assert not seq_group.sampling_params.logits_processors, (
"Logits Processors are not supported in multi-step decoding")
for parent_id, next_token_id in zip(parent_ids, next_token_ids):
# TODO(will): support logprobs
# Hard coded logprob
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
{next_token_id: Logprob(logprob=-1)}))
output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))
assert len(output.outputs) > 0

View File

@@ -0,0 +1,189 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.model_runner_base import BroadcastableModelInput
from vllm.worker.multi_step_model_runner import (MultiStepModelRunner,
StatefulModelInput)
from vllm.worker.worker import Worker, WorkerInput
@dataclass
class MultiStepState:
worker_input: WorkerInput
model_input: StatefulModelInput
class MultiStepWorker(Worker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
base_model_runner = self.model_runner
# for multi-step model, wrap the model runner with MultiStepModelRunner
self.model_runner = MultiStepModelRunner(
base_model_runner,
base_model_runner.model_config,
base_model_runner.parallel_config,
base_model_runner.scheduler_config,
base_model_runner.device_config,
base_model_runner.cache_config,
load_config=base_model_runner.load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=base_model_runner.is_driver_worker,
prompt_adapter_config=base_model_runner.prompt_adapter_config,
observability_config=base_model_runner.observability_config,
)
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
self.multi_step_states: List[
Optional[MultiStepState]] = [None] * pipeline_parallel_size
self.temp_output = None
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput]:
"""
Get the driver input and broadcast it to other workers.
"""
assert self.is_driver_worker
virtual_engine = execute_model_req.virtual_engine
is_first_multi_step = execute_model_req.is_first_multi_step
if is_first_multi_step:
# on first step we prepare the worker input and model input normally
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: StatefulModelInput = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
else:
# on subsequent steps we reuse the worker input and model input
multi_step_state = self.multi_step_states[virtual_engine]
worker_input = multi_step_state.worker_input
model_input = multi_step_state.model_input
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
# clear the cached decode metadata so that it can be recomputed on
# the workers
frozen_model_input.attn_metadata._cached_decode_metadata = None
model_input.is_first_multi_step = is_first_multi_step
model_input.is_last_step = execute_model_req.is_last_step
if not is_first_multi_step:
# we broadcast the last sampled token ids to all TP workers so they
# can update their model input metadata in-place.
self._prepare_last_sampled_token_ids_for_tp_workers(
execute_model_req=execute_model_req, model_input=model_input)
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
return model_input, worker_input
def _prepare_last_sampled_token_ids_for_tp_workers(
self,
execute_model_req: ExecuteModelRequest,
model_input: StatefulModelInput,
) -> None:
"""
Prepare the last sampled token ids for TP workers. If it's the last
PP rank, then the last sampled token ids are already in the model_input.
If it is NOT the last PP rank, then we need to get the last sampled
token ids that are cached in the execute_model_req.
"""
if get_pp_group().is_last_rank:
assert model_input.cached_outputs[
-1].sampler_output.sampled_token_ids is None
assert model_input.cached_outputs[-1].sampled_token_ids is not None
model_input.last_sampled_token_ids = model_input.cached_outputs[
-1].sampled_token_ids
# free sampled token ids from the previous step if it has been
# pythonized. Cannot free the last sampled token ids because
# we need it for GPU advance_step.
for output in model_input.cached_outputs[:-1]:
if output.pythonized:
output.sampled_token_ids = None
else:
# otherwise we need to get the cached sampled token ids from the
# execute_model_req
assert execute_model_req.last_sampled_token_ids is not None
model_input.last_sampled_token_ids = (
execute_model_req.last_sampled_token_ids.cuda())
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)
# free sampled token ids from the previous step.
# TODO(will) we could reuse the sampled token ids tensor from
# the previous step instead.
for output in model_input.cached_outputs[:-1]:
output.sampled_token_ids = None
assert model_input.cached_outputs[-1].sampled_token_ids is not None
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[StatefulModelInput, WorkerInput]]:
"""
Depending on the current state of the request and the multi-step worker,
this method may skip the normal _prepare_model_input and
_prepare_worker_input methods and instead use cached values.
"""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
# This signals that there are no more requests to process for
# now. All workers are running an infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
virtual_engine = execute_model_req.virtual_engine
model_input, worker_input = self._get_driver_input_and_broadcast(
execute_model_req)
assert isinstance(model_input, StatefulModelInput)
if execute_model_req.is_first_multi_step:
# cache the worker input and model input for the next steps
self.multi_step_states[virtual_engine] = MultiStepState(
worker_input=worker_input, model_input=model_input)
# if TP workers
else:
broadcast_data = self._get_worker_input_from_broadcast()
# if the driver has sent an empty input, we should stop the worker
# loop
if broadcast_data is None:
return None
model_input, worker_input = broadcast_data
assert isinstance(model_input, StatefulModelInput)
virtual_engine = worker_input.virtual_engine
if model_input.is_first_multi_step:
pass
# TODO(will) We can cache the worker input and model input for the
# next steps. See the TODO below for details.
else:
# TODO(will) It is also possible to cache and reuse the worker
# input and model input here. The idea is essentially the delta
# optimization for model_inputs: the TP workers cache the model
# input state and we only broadcast the delta needed for the next
# step (sampled_token_ids from the previous step).
assert isinstance(model_input, StatefulModelInput)
# we need to update the last sampled token ids in the model
# input for the workers so that they can run in-place
# advance_step.
model_input.add_sampler_output(
SamplerOutput(outputs=[], sampled_token_ids=None),
model_input.last_sampled_token_ids)
assert model_input is not None
assert worker_input is not None
return model_input, worker_input

View File

@@ -16,7 +16,9 @@ from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
logger = init_logger(__name__)
@@ -220,7 +222,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
raise NotImplementedError
def _get_worker_input_from_broadcast(
self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
@@ -237,7 +239,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelRunnerInputBase, WorkerInput]:
) -> Tuple[BroadcastableModelInput, WorkerInput]:
""" Get the driver input and broadcast it to other workers. """
assert self.is_driver_worker
@@ -259,7 +261,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
"""
Prepare the inputs to ModelRunner and workers.
"""