# SPDX-License-Identifier: Apache-2.0

import dataclasses
import os
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import cloudpickle
import torch
import torch.nn as nn

from vllm.config import (ObservabilityConfig, VllmConfig,
                         set_current_vllm_config)
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.distributed.device_communicators.nixl import DynamoNixlConnector
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.remote_prefill import MemoryOpType
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
                        resolve_obj_by_qualname, run_method,
                        update_environment_variables)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)

logger = init_logger(__name__)


class WorkerBase(ABC):
    """Worker interface that allows vLLM to cleanly separate implementations
    for different hardware. Also abstracts control plane communication, e.g.,
    to communicate request metadata to other workers.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.kv_transfer_config = vllm_config.kv_transfer_config
        self.compilation_config = vllm_config.compilation_config
        from vllm.platforms import current_platform
        self.current_platform = current_platform

        self.nixl_connector: Optional[DynamoNixlConnector] = None

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other
        on-device memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks."""
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        """Execute the model loop in a parallel worker.

        You can stop the loop by executing a driver worker with an empty
        output. See `stop_remote_worker_execution_loop` for more details.
        """
        with self.current_platform.inference_mode():
            while True:
                output = self.execute_model(execute_model_req=None)
                if output is None:
                    return None
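
    # Illustrative sketch, not additional API: the loop above exits when the
    # driver broadcasts an empty input. On the driver rank this is exactly
    # what `prepare_input` (defined further below) does when it receives no
    # request:
    #
    #     broadcast_tensor_dict({}, src=0)  # empty dict -> workers return None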

    @abstractmethod
    def get_model(self) -> nn.Module:
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError


class DelegateWorkerBase(WorkerBase):
    """
    A class that delegates all methods to another WorkerBase instance. This is
    useful for creating a WorkerBase that wraps another WorkerBase instance,
    e.g. speculative decoding.
    """
    worker: WorkerBase

    def __init__(
        self,
        *args,
        **kwargs,
    ) -> None:
        vllm_config: VllmConfig = kwargs.get("vllm_config")
        cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
        self.worker = cls(*args, **kwargs)

    def init_device(self) -> None:
        self.worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return self.worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def get_model(self) -> nn.Module:
        return self.worker.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        return self.worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        return self.worker.get_cache_block_size_bytes()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.worker.list_loras()

    def __getattr__(self, attr):
        return getattr(self.worker, attr)
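
# Illustrative sketch, not part of this module: a hypothetical wrapper built
# on the delegation pattern above. `LoggingWorker` is an invented name; only
# the DelegateWorkerBase mechanics are taken from the code above.
#
#     class LoggingWorker(DelegateWorkerBase):
#
#         def execute_model(self, execute_model_req=None):
#             logger.debug("delegating execute_model")
#             return self.worker.execute_model(execute_model_req)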


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        # Raise (not return) the error so all LoRA methods behave the same.
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")


@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Local inputs to each worker. May contain device-specific data. These
    fields should be broadcastable to other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0
    num_steps: int = 1

    # Fields used for NIXL-based remote KV-cache transfers.
    local_block_ids: Optional[List[List[int]]] = None
    staging_block_ids: Optional[List[List[int]]] = None
    remote_block_ids: Optional[List[List[int]]] = None
    remote_engine_id: Optional[List[str]] = None
    notify_msg: Optional[List[str]] = None
    op_type: Optional[List[MemoryOpType]] = None

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """
        Pop fields from the given tensor_dict and populate a new instance of
        WorkerInput.
        """
        return cls(
            num_seq_groups=tensor_dict.pop("num_seq_groups"),
            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
            virtual_engine=tensor_dict["virtual_engine"],
            num_steps=tensor_dict.pop("num_steps"),
            local_block_ids=tensor_dict.pop("local_block_ids"),
            staging_block_ids=tensor_dict.pop("staging_block_ids"),
            remote_block_ids=tensor_dict.pop("remote_block_ids"),
            remote_engine_id=tensor_dict.pop("remote_engine_id"),
            notify_msg=tensor_dict.pop("notify_msg"),
            op_type=tensor_dict.pop("op_type"),
        )

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        Extract broadcastable fields.
        """
        tensor_dict = {
            "num_seq_groups": self.num_seq_groups,
            "blocks_to_swap_in": self.blocks_to_swap_in,
            "blocks_to_swap_out": self.blocks_to_swap_out,
            "blocks_to_copy": self.blocks_to_copy,
            "virtual_engine": self.virtual_engine,
            "num_steps": self.num_steps,
            "local_block_ids": self.local_block_ids,
            "staging_block_ids": self.staging_block_ids,
            "remote_block_ids": self.remote_block_ids,
            "remote_engine_id": self.remote_engine_id,
            "notify_msg": self.notify_msg,
            "op_type": self.op_type,
        }

        return tensor_dict
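
# Illustrative sketch, not additional API: a WorkerInput round-trips through
# its broadcastable dict form. On the driver, `as_broadcastable_tensor_dict`
# feeds `broadcast_tensor_dict`; the other ranks rebuild the instance:
#
#     wi = WorkerInput(num_seq_groups=2, virtual_engine=0, num_steps=1)
#     restored = WorkerInput.from_broadcasted_tensor_dict(
#         wi.as_broadcastable_tensor_dict())
#     assert restored.num_seq_groups == 2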


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in
    distributed mode. Subclasses of this interface should use model runners
    that inherit from ModelRunnerBase, and should only need to implement
    worker-local logic. If custom control plane logic is needed to transfer
    metadata, or if the model runner cannot inherit from ModelRunnerBase, use
    WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase
    observability_config: Optional[ObservabilityConfig] = None

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether broadcast is
        needed to transfer request inputs from the driver worker to other
        workers in the TP group. If a WorkerBase subclass only supports
        single-worker execution, then this method should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular
        virtual engine (PP stream). Used by the default `execute_model`. If
        the worker's model runner does not follow the ModelRunnerBase
        interface, then inherit from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It
        is not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """Get the worker input from the broadcasted tensor dict."""
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)
        if not broadcast_data:
            return None

        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
        if worker_input.num_seq_groups > 0:
            model_input = (
                self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                    broadcast_data))

            kwargs = extract_previous_hidden_states(broadcast_data)

            return model_input, worker_input, kwargs
        else:
            return None, worker_input, {}

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """Get the driver input and broadcast it to other workers."""
        assert self.is_driver_worker

        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        kwargs = extract_previous_hidden_states(execute_model_req)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_data.update(kwargs)
            broadcast_tensor_dict(broadcast_data, src=0)

        if execute_model_req.async_callback:
            model_input = dataclasses.replace(  # type: ignore
                model_input,
                async_callback=execute_model_req.async_callback)

        return model_input, worker_input, kwargs

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """
        Prepare the inputs to ModelRunner and workers.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there are no more requests to process
                    # for now. All workers are running an infinite loop with
                    # broadcast_tensor_dict, and the loop stops when the
                    # driver broadcasts an empty input. Send an empty input
                    # to notify all other workers to stop their execution
                    # loop.
                    broadcast_tensor_dict({}, src=0)
                return None
            return self._get_driver_input_and_broadcast(execute_model_req)
        else:
            return self._get_worker_input_from_broadcast()
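
    # Control-flow sketch of the metadata plane, summarizing the code above
    # (illustrative, not additional API):
    #
    #     driver rank:      prepare_worker_input + prepare_model_input,
    #                       then broadcast_tensor_dict(data, src=0)
    #     non-driver ranks: broadcast_tensor_dict(src=0), then rebuild the
    #                       WorkerInput/model input from the received dict
    #     shutdown:         driver broadcasts {} -> prepare_input returns None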

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        start_time = time.perf_counter()

        inputs = self.prepare_input(execute_model_req)
        if inputs is None:
            return None

        model_input, worker_input, kwargs = inputs
        num_steps = worker_input.num_steps

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups > 0:

            self._read_blocks(worker_input)

            intermediate_tensors = None
            orig_model_execute_time = 0.0
            if not get_pp_group().is_first_rank:
                intermediate_tensors = IntermediateTensors(
                    get_pp_group().recv_tensor_dict(
                        all_gather_group=get_tp_group()))
                if (self.observability_config is not None
                        and self.observability_config.collect_model_execute_time):
                    orig_model_execute_time = intermediate_tensors.tensors.get(
                        "model_execute_time", torch.tensor(0)).item()

            output = self.model_runner.execute_model(
                model_input=model_input,
                kv_caches=self.kv_cache[worker_input.virtual_engine]
                if self.kv_cache is not None else None,
                intermediate_tensors=intermediate_tensors,
                num_steps=num_steps,
                **kwargs,
            )

            model_execute_time = time.perf_counter() - start_time
            if not get_pp_group().is_last_rank:
                # output is IntermediateTensors
                assert isinstance(output, IntermediateTensors)
                if (self.observability_config is not None
                        and self.observability_config.collect_model_execute_time):
                    output.tensors["model_execute_time"] = torch.tensor(
                        model_execute_time + orig_model_execute_time)
                get_pp_group().send_tensor_dict(output.tensors,
                                                all_gather_group=get_tp_group())
                return [None]
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time
                    and output is not None):
                for o in output:
                    o.model_execute_time = (orig_model_execute_time +
                                            model_execute_time)

            self._write_blocks(worker_input)

        else:
            output = []

        # Collect KV-transfer notifications from the non-driver workers: TP
        # rank 0 gathers each rank's new notifications and counts them per
        # request id.
        if self.nixl_connector is not None:
            new_notifs = self.nixl_connector.get_new_notifs()
            rank = get_tp_group().rank
            all_new_notifs = [new_notifs]
            if rank > 0:
                get_tp_group().send_object(new_notifs, dst=0)
            else:
                for i in range(1, get_tp_group().world_size):
                    all_new_notifs.append(get_tp_group().recv_object(src=i))

            request_notif_counter = defaultdict(int)
            for notifs in all_new_notifs:
                for req_ids in notifs.values():
                    for req_id in req_ids:
                        request_notif_counter[req_id] += 1

            if request_notif_counter:
                logger.debug("Request notif counter: %s",
                             request_notif_counter)

            request_done_counter = defaultdict(int)
            # `get_done_tranfers` is the method name as defined by
            # DynamoNixlConnector.
            for req_id in self.nixl_connector.get_done_tranfers():
                request_done_counter[req_id] += 1
        else:
            request_notif_counter = {}
            request_done_counter = {}
        # output is List[SamplerOutput]; the notification and done-transfer
        # counters are returned alongside it.
        return output, request_notif_counter, request_done_counter
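
    # Illustrative sketch of the counter shapes returned above (values are
    # invented): each maps a request id to how many TP ranks reported a
    # notification or a completed transfer for it.
    #
    #     request_notif_counter == {"req-0": 2}   # 2 ranks notified
    #     request_done_counter  == {"req-1": 1}   # 1 rank finished a transfer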

    def _read_blocks(self, worker_input: WorkerInput) -> None:
        # Hook for pulling remote KV-cache blocks before model execution;
        # a no-op here, overridden by workers that support NIXL transfers.
        pass

    def _write_blocks(self, worker_input: WorkerInput) -> None:
        # Hook for pushing KV-cache blocks after model execution; a no-op
        # here, overridden by workers that support NIXL transfers.
        pass

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute model in Single Program Multiple Data (SPMD) fashion.
        All workers take the same request, prepare the input and
        execute the model.
        """
        assert execute_model_req is not None, (
            "_execute_model_spmd() requires each worker to take in an "
            "ExecuteModelRequest")
        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        kwargs = extract_previous_hidden_states(execute_model_req)

        return self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            **kwargs,
        )


class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is
    responsible for lazily initializing the worker and handling the worker's
    lifecycle. We first instantiate the WorkerWrapper, which remembers the
    worker module and class name. Then, after the environment is set up via
    `update_environment_variables`, the real initialization happens in
    `init_worker`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rpc_rank: int = 0,
    ) -> None:
        """
        Initialize the worker wrapper with the given vllm_config and
        rpc_rank. Note: rpc_rank is the rank of the worker in the executor.
        In most cases, it is also the rank of the worker in the distributed
        group. However, when multiple executors work together, they can be
        different. e.g. in the case of SPMD-style offline inference with
        TP=2, users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.
        """
        self.rpc_rank = rpc_rank
        self.worker: Optional[WorkerBase] = None
        # Do not store this `vllm_config`; `init_worker` will set the final
        # one. TODO: investigate if we can remove this field in
        # `WorkerWrapperBase`; `init_cached_hf_modules` should be
        # unnecessary now.
        if vllm_config.model_config is not None:
            # it can be None in tests
            trust_remote_code = vllm_config.model_config.trust_remote_code
            if trust_remote_code:
                # note: lazy import to avoid importing torch before
                # initializing
                from vllm.utils import init_cached_hf_modules
                init_cached_hf_modules()
|
|
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
|
|
"""
|
|
Adjust the rpc_rank based on the given mapping.
|
|
It is only used during the initialization of the executor,
|
|
to adjust the rpc_rank of workers after we create all workers.
|
|
"""
|
|
if self.rpc_rank in rank_mapping:
|
|
self.rpc_rank = rank_mapping[self.rpc_rank]
|
|
|
|
def update_environment_variables(self, envs_list: List[Dict[str,
|
|
str]]) -> None:
|
|
envs = envs_list[self.rpc_rank]
|
|
key = 'CUDA_VISIBLE_DEVICES'
|
|
if key in envs and key in os.environ:
|
|
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
|
|
# suppress the warning in `update_environment_variables`
|
|
del os.environ[key]
|
|
update_environment_variables(envs)
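
    # Illustrative sketch (invented values): `envs_list` holds one env dict
    # per rpc_rank, so with two workers it might look like
    #
    #     [{"CUDA_VISIBLE_DEVICES": "0"}, {"CUDA_VISIBLE_DEVICES": "1"}]
    #
    # and each wrapper applies only the entry at its own rpc_rank.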

    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
        kwargs = all_kwargs[self.rpc_rank]
        self.vllm_config = kwargs.get("vllm_config", None)
        assert self.vllm_config is not None, (
            "vllm_config is required to initialize the worker")
        enable_trace_function_call_for_thread(self.vllm_config)

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
            worker_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_cls)
        else:
            assert isinstance(self.vllm_config.parallel_config.worker_cls,
                              bytes)
            worker_class = cloudpickle.loads(
                self.vllm_config.parallel_config.worker_cls)
        with set_current_vllm_config(self.vllm_config):
            # To make the vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)
        assert self.worker is not None

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        try:
            target = self if self.worker is None else self.worker
            return run_method(target, method, args, kwargs)
        except Exception as e:
            # If the driver worker also executes methods, exceptions in the
            # remaining workers may cause deadlock in RPC frameworks like
            # Ray; see https://github.com/vllm-project/vllm/issues/3455.
            # Print the error and inform the user how to solve it.
            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            raise e

    def __getattr__(self, attr):
        return getattr(self.worker, attr)


def extract_previous_hidden_states(
        data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
        Dict[str, torch.Tensor]:
    """If data contains previous_hidden_states, extract it. This returns a
    dict which can be used directly as additional kwargs in any following
    execute_model calls. This is used in draft models like EAGLE."""
    output = {}

    # When called from a non-driver worker, data is a dict; when called from
    # the driver worker, data is an ExecuteModelRequest.
    if isinstance(data, dict):
        if "previous_hidden_states" in data:
            output["previous_hidden_states"] = data["previous_hidden_states"]
    elif data.previous_hidden_states is not None:
        output["previous_hidden_states"] = data.previous_hidden_states\
            .hidden_states

    return output
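
# Illustrative usage sketch (invented tensor): on the driver the field comes
# from an ExecuteModelRequest, on other ranks from the broadcast dict; either
# way, the result feeds execute_model as **kwargs:
#
#     hidden = {"previous_hidden_states": torch.zeros(1, 8)}
#     kwargs = extract_previous_hidden_states(hidden)
#     assert "previous_hidden_states" in kwargs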