Remove hard dependencies of speculative decoding on CUDA workers (#10587)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Author: Chendi.Xue
Date: 2024-11-26 19:57:11 -06:00 (committed by GitHub)
Parent: 2f0a0a17a4
Commit: 0a71900bc9
19 changed files with 219 additions and 77 deletions

View File

@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str):
target_worker.init_device.assert_called_once()
metrics_collector.init_gpu_tensors.assert_called_once()
spec_decode_sampler.init_gpu_tensors.assert_called_once()
metrics_collector.init_tensors.assert_called_once()
spec_decode_sampler.init_tensors.assert_called_once()
@pytest.mark.parametrize("acceptance_sampler_method",

View File

@ -990,6 +990,7 @@ class ParallelConfig:
# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
worker_cls: str = "auto"
sd_worker_cls: str = "auto"
world_size: int = field(init=False)
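
The new sd_worker_cls field above lets a platform record which concrete worker the speculative-decode wrapper should drive, while worker_cls names the wrapper itself. A minimal sketch of that selection logic, using a simplified stand-in for vLLM's ParallelConfig rather than the real class:

from dataclasses import dataclass

@dataclass
class SimpleParallelConfig:  # stand-in for vllm.config.ParallelConfig
    worker_cls: str = "auto"
    sd_worker_cls: str = "auto"

def pick_workers(cfg: SimpleParallelConfig, platform_worker: str,
                 speculative: bool) -> None:
    # With speculative decoding enabled, worker_cls becomes the spec-decode
    # wrapper and the platform-specific worker moves into sd_worker_cls.
    if cfg.worker_cls != "auto":
        return
    if speculative:
        cfg.worker_cls = \
            "vllm.spec_decode.spec_decode_worker.create_spec_worker"
        cfg.sd_worker_cls = platform_worker
    else:
        cfg.worker_cls = platform_worker

cfg = SimpleParallelConfig()
pick_workers(cfg, "vllm.worker.cpu_worker.CPUWorker", speculative=True)
print(cfg.worker_cls, cfg.sd_worker_cls)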

View File

@ -43,6 +43,21 @@ class SpecDecodeBaseSampler(nn.Module):
dtype=torch.long,
device=device)
def init_tensors(self,
device: Union[int, str],
device_type: Union[torch.device, str] = 'cuda') -> None:
assert self.num_accepted_tokens is None
if isinstance(device_type, torch.device):
device_type = device_type.type
if isinstance(device, int):
device = f"{device_type}:{device}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
@property
def probs_dtype(self):
return torch.float32
@ -77,7 +92,7 @@ class SpecDecodeBaseSampler(nn.Module):
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
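
The init_tensors method added above generalizes init_gpu_tensors: the rank may arrive as an int or a device string, and the device type may be a torch.device or a plain string. A standalone sketch of that device handling (the function name and usage here are illustrative, not the vLLM API):

from typing import Union

import torch

def resolve_device(device: Union[int, str],
                   device_type: Union[torch.device, str] = "cuda") -> str:
    # Mirrors the init_tensors logic: collapse a torch.device to its type
    # string and turn an integer rank into "<type>:<rank>".
    if isinstance(device_type, torch.device):
        device_type = device_type.type
    if isinstance(device, int):
        device = f"{device_type}:{device}"
    return device

print(resolve_device(0, "cpu"))                # cpu:0
print(resolve_device(1, torch.device("cpu")))  # cpu:1
num_accepted_tokens = torch.tensor(0, dtype=torch.long,
                                   device=resolve_device(0, "cpu"))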

View File

@ -86,4 +86,10 @@ class CpuPlatform(Platform):
parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend = "mp"
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
if vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"

View File

@ -106,6 +106,8 @@ class CudaPlatformBase(Platform):
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
@ -236,4 +238,4 @@ try:
if not isinstance(pynvml, _MockModule):
CudaPlatform.log_warnings()
except ModuleNotFoundError:
CudaPlatform.log_warnings()
CudaPlatform.log_warnings()

View File

@ -20,8 +20,9 @@ except (ModuleNotFoundError, ImportError) as err:
from vllm.logger import init_logger
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerWrapperBase)
logger = init_logger(__name__)
@ -33,7 +34,7 @@ debug_advance_input = False
allow_gpu_advance_step = True
class TP1DraftModelRunner(ModelRunner):
class TP1DraftModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always executes k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
any broadcasting inside execute_model).
"""
def __init__(self, *args, **kwargs):
if kwargs.get("return_hidden_states"):
def __init__(self, model_runner: ModelRunnerBase):
if hasattr(
model_runner,
"return_hidden_states") and model_runner.return_hidden_states:
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)
super().__init__(*args, **kwargs)
super().__init__(model_runner)
self.indices_of_seq_with_bonus_tokens = None
@ -73,10 +75,8 @@ class TP1DraftModelRunner(ModelRunner):
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
def _gpu_advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata,
last_output: SamplerOutput
) -> ModelInputForGPUWithSamplingMetadata:
def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
last_output: SamplerOutput) -> ModelRunnerInputBase:
# Currently, we expect "decode mode" only
assert not model_input.is_prompt
@ -168,7 +168,7 @@ class TP1DraftModelRunner(ModelRunner):
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
model_input: ModelRunnerInputBase,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Set
from typing import Optional, Set, Union
import torch
@ -75,9 +75,11 @@ class SpeculativeProposer(ABC):
class SpeculativeScorer(ABC):
def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
def __init__(self, scorer_worker: WorkerBase,
device: Union[torch.device, str], vocab_size: int):
self._scorer_worker = scorer_worker
if isinstance(device, torch.device):
device = device.type
self._device = device
self._vocab_size = vocab_size

View File

@ -9,21 +9,22 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
from vllm.worker.worker_base import WorkerWrapperBase
class MedusaWorker(NonLLMProposerWorkerBase, Worker):
class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
"""Worker for Medusa.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super().__init__(kwargs.get("vllm_config"))
self.init_worker(*args, **kwargs)
# Lazy initialization list.
self._proposer: Top1Proposer
def init_device(self):
super().init_device()
self.worker.init_device()
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]

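MedusaWorker (and MultiStepWorker below) now composes a platform worker through WorkerWrapperBase instead of subclassing the CUDA Worker: the wrapper is built from the vllm_config, init_worker instantiates the worker class named in the config, and later calls go through self.worker. A rough, self-contained sketch of that composition pattern (these classes are simplified stand-ins, not the vLLM ones):

class FakeWorker:  # stand-in for a platform worker such as CPUWorker
    def __init__(self, rank: int):
        self.rank = rank
        self.device_initialized = False

    def init_device(self) -> None:
        self.device_initialized = True

class WorkerWrapperSketch:  # stand-in for WorkerWrapperBase
    def __init__(self, vllm_config):
        self.vllm_config = vllm_config
        self.worker = None

    def init_worker(self, *args, **kwargs):
        # The real wrapper resolves parallel_config.worker_cls by name;
        # here we simply instantiate the stand-in worker directly.
        self.worker = FakeWorker(*args, **kwargs)

    def __getattr__(self, attr):
        return getattr(self.worker, attr)

class MedusaLikeWorker(WorkerWrapperSketch):
    def __init__(self, *args, **kwargs):
        super().__init__(kwargs.pop("vllm_config", None))
        self.init_worker(*args, **kwargs)

    def init_device(self) -> None:
        # Delegate explicitly, as the reworked MedusaWorker.init_device does.
        self.worker.init_device()

worker = MedusaLikeWorker(0, vllm_config={"dummy": True})
worker.init_device()
print(worker.device_initialized)  # True, resolved through the wrapped worker
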
View File

@ -1,11 +1,12 @@
import time
from typing import Callable, Optional
from typing import Callable, Optional, Union
import msgspec
import torch
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
@ -81,8 +82,20 @@ class AsyncMetricsCollector:
self._rank = rank
self._copy_stream = torch.cuda.Stream()
def init_tensors(self,
rank: int,
device_type: Union[torch.device, str] = 'cuda') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
device_type = device_type.type
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# Currently uses cuda.Event; skip on any non-cuda-alike platform.
if not current_platform.is_cuda_alike():
return None
# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:

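The metrics collector follows the same pattern: init_tensors only creates the CUDA copy stream when the device type is actually 'cuda', and rejection-sampling metrics are skipped on non-CUDA-alike platforms because they rely on cuda.Event. A small sketch of that guard, with the platform check reduced to a plain boolean argument for illustration:

from typing import Optional, Union

import torch

class MetricsCollectorSketch:
    def __init__(self) -> None:
        self._rank: Optional[int] = None
        self._copy_stream: Optional[torch.cuda.Stream] = None

    def init_tensors(self,
                     rank: int,
                     device_type: Union[torch.device, str] = "cuda") -> None:
        self._rank = rank
        if isinstance(device_type, torch.device):
            device_type = device_type.type
        # Only CUDA needs a dedicated stream for async metric copies.
        if device_type == "cuda":
            self._copy_stream = torch.cuda.Stream()

    def maybe_collect(self, is_cuda_alike: bool):
        # cuda.Event-based timing is unavailable on CPU-like backends.
        if not is_cuda_alike:
            return None
        return "collect as before"

collector = MetricsCollectorSketch()
collector.init_tensors(rank=0, device_type=torch.device("cpu"))
print(collector._copy_stream)                        # None on CPU
print(collector.maybe_collect(is_cuda_alike=False))  # None
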
View File

@ -5,17 +5,21 @@ from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
from vllm.worker.worker_base import WorkerWrapperBase
class MultiStepWorker(Worker, ProposerWorkerBase):
class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super().__init__(kwargs.get("vllm_config"))
self.init_worker(*args, **kwargs)
# Lazy initialization list.
self._proposer: SpeculativeProposer
def init_device(self) -> None:
super().init_device()
self.worker.init_device()
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
@ -51,6 +56,18 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
True)
def determine_num_available_blocks(self) -> Tuple[int, int]:
return self.worker.determine_num_available_blocks()
def get_cache_block_size_bytes(self) -> int:
return self.worker.get_cache_block_size_bytes()
def initialize_cache(self, *args, **kwargs) -> None:
self.worker.initialize_cache(*args, **kwargs)
def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
return self.worker.execute_model(*args, **kwargs)
@torch.inference_mode()
def sampler_output(
self,
@ -75,7 +92,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
@ -92,7 +109,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model(
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
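
Because TP1DraftModelRunner is now only imported on CUDA-alike platforms, MultiStepWorker guards the fast multi-step path and otherwise falls back to running the wrapped worker once per draft step. A condensed sketch of that control flow (the worker, runner and request objects are placeholders):

from typing import Any, List, Optional

def run_draft_steps(worker: Any, runner: Any, request: Any, sample_len: int,
                    is_cuda_alike: bool,
                    draft_runner_cls: Optional[type] = None) -> List[Any]:
    # Fast path: the GPU multi-step draft runner advances its inputs on
    # device, so one call produces all sample_len outputs.
    if (is_cuda_alike and draft_runner_cls is not None
            and isinstance(runner, draft_runner_cls)
            and runner.supports_gpu_multi_step(request)):
        return worker.execute_model(execute_model_req=request)

    # Portable path: call the wrapped worker once per draft step.
    outputs: List[Any] = []
    for _ in range(sample_len):
        step_output = worker.execute_model(execute_model_req=request)
        assert len(step_output) == 1, "composing multistep workers not supported"
        outputs.append(step_output[0])
    return outputs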

View File

@ -22,6 +22,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
# Get local_rank/vocab_size from kwargs attribute
self.local_rank = kwargs["local_rank"]
self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
self.device_type = kwargs.get("device_type", "cuda")
# Lazy initialization list.
self._proposer: Top1Proposer
@ -34,7 +35,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
def init_device(self):
self.device = torch.device(f"cuda:{self.local_rank}")
self.device = torch.device(f"{self.device_type}:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None
# Current NGramWorker only supports Top1Proposer

View File

@ -14,12 +14,16 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
@ -36,8 +40,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase,
WorkerWrapperBase)
logger = init_logger(__name__)
@ -53,7 +57,11 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs = kwargs.copy()
kwargs["model_runner_cls"] = TargetModelRunner
target_worker = Worker(*args, **kwargs)
target_worker_config = copy.deepcopy(vllm_config)
target_worker_config.parallel_config.worker_cls =\
target_worker_config.parallel_config.sd_worker_cls
target_worker = WorkerWrapperBase(vllm_config=target_worker_config)
target_worker.init_worker(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
@ -65,6 +73,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_config.model_config,
vllm_config.load_config,
)
speculative_config.draft_parallel_config.worker_cls =\
draft_worker_config.parallel_config.sd_worker_cls
draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa
# TODO allow draft-model specific load config.
@ -125,7 +135,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@classmethod
def create_worker(
cls,
scorer_worker: Worker,
scorer_worker: WorkerBase,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
@ -145,6 +155,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'vllm_config'].parallel_config
if ngram_prompt_lookup_max > 0:
draft_worker_kwargs[
"device_type"] = scorer_worker.device_config.device.type
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
@ -158,8 +170,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
@ -306,8 +319,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.load_model()
self.proposer_worker.load_model()
self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank)
self._metrics.init_tensors(self.rank, device_type=self.device)
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device)
scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
@ -1111,11 +1125,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
raise NotImplementedError
def start_profile(self):
if isinstance(self.scorer_worker, Worker):
if isinstance(self.scorer_worker, WorkerBase):
self.scorer_worker.start_profile()
def stop_profile(self):
if isinstance(self.scorer_worker, Worker):
if isinstance(self.scorer_worker, WorkerBase):
self.scorer_worker.stop_profile()
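
create_spec_worker no longer constructs the CUDA Worker directly: it deep-copies the config, promotes sd_worker_cls to worker_cls for both the target and the draft side, and lets WorkerWrapperBase build whichever platform worker that names. A simplified sketch of that re-wiring, with stand-ins for vLLM's config objects:

import copy
from dataclasses import dataclass, field

@dataclass
class ParallelCfg:  # stand-in for vllm.config.ParallelConfig
    worker_cls: str = "auto"
    sd_worker_cls: str = "auto"

@dataclass
class VllmCfg:  # stand-in for vllm.config.VllmConfig
    parallel_config: ParallelCfg = field(default_factory=ParallelCfg)

def build_target_config(vllm_config: VllmCfg) -> VllmCfg:
    # The target worker gets its own config copy whose worker_cls is the
    # platform worker recorded in sd_worker_cls (CPUWorker, Worker, ...),
    # so the wrapper can build the right class by name.
    target_cfg = copy.deepcopy(vllm_config)
    target_cfg.parallel_config.worker_cls = \
        target_cfg.parallel_config.sd_worker_cls
    return target_cfg

cfg = VllmCfg(ParallelCfg(
    worker_cls="vllm.spec_decode.spec_decode_worker.create_spec_worker",
    sd_worker_cls="vllm.worker.cpu_worker.CPUWorker"))
target_cfg = build_target_config(cfg)
print(target_cfg.parallel_config.worker_cls)  # the CPU worker, not the wrapper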

View File

@ -1,12 +1,12 @@
from typing import List, Optional
from vllm.config import VllmConfig
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerWrapperBase)
class TargetModelRunner(ModelRunner):
class TargetModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities that are ultimately selected
may not be the same ones selected by the target model's sampling. This means
@ -18,32 +18,21 @@ class TargetModelRunner(ModelRunner):
requested or not.
"""
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
):
def __init__(self, model_runner: ModelRunnerBase):
# An internal boolean member variable indicating whether token log
# probabilities are needed.
super().__init__(model_runner)
self.disable_logprobs = True
super().__init__(
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
return_hidden_states=return_hidden_states,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
model_input: ModelInputForGPUWithSamplingMetadata = super(
).prepare_model_input(seq_group_metadata_list, virtual_engine,
finished_requests_ids)
finished_requests_ids: Optional[List[str]] = None,
) -> ModelRunnerInputBase:
model_input: ModelRunnerInputBase =\
self.model_runner.prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids)
# If token log probabilities are disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities are enabled then synchronize all the

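TargetModelRunner keeps only what it actually changes, forcing disable_logprobs and post-processing prepare_model_input, and reaches everything else through the wrapped runner, so the same class now sits on top of either the GPU or the CPU model runner. A minimal sketch of overriding a single method on such a wrapper (the runner below is a placeholder, and the dict stands in for the real model-input object):

class BaseRunnerSketch:  # placeholder for a platform ModelRunner
    def prepare_model_input(self, seqs, virtual_engine=0,
                            finished_requests_ids=None):
        return {"seqs": seqs, "virtual_engine": virtual_engine}

class TargetRunnerSketch:
    """Overrides one method and forwards the call to the wrapped runner."""

    def __init__(self, model_runner):
        self.model_runner = model_runner
        self.disable_logprobs = True  # later overwritten from the spec config

    def prepare_model_input(self, seqs, virtual_engine=0,
                            finished_requests_ids=None):
        # Ask the wrapped runner for the model input, then adjust it (the
        # real code toggles sampling_metadata.skip_sampler_cpu_output).
        model_input = self.model_runner.prepare_model_input(
            seqs, virtual_engine, finished_requests_ids)
        model_input["skip_sampler_cpu_output"] = self.disable_logprobs
        return model_input

runner = TargetRunnerSketch(BaseRunnerSketch())
print(runner.prepare_model_input(["seq0"]))
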
View File

@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SequenceGroupMetadata,
SequenceOutput)
@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs):
Arguments:
msg (string): message to associate with the range
"""
torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
try:
if current_platform.is_cuda_alike():
torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
try:
yield
finally:
torch.cuda.nvtx.range_pop()
else:
yield
finally:
torch.cuda.nvtx.range_pop()
class Timer:

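nvtx_range now only touches torch.cuda.nvtx on CUDA-alike platforms and degrades to a plain pass-through elsewhere, which is what keeps CPU spec-decode from crashing inside profiling annotations. A runnable sketch of the same guarded context manager, with the platform check replaced by torch.cuda.is_available() for simplicity:

from contextlib import contextmanager

import torch

@contextmanager
def nvtx_range_sketch(msg: str, *args, **kwargs):
    # Push an NVTX range only when CUDA is actually usable; otherwise the
    # context manager is a no-op so CPU-only runs do not crash.
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
        try:
            yield
        finally:
            torch.cuda.nvtx.range_pop()
    else:
        yield

with nvtx_range_sketch("scoring batch {}", 0):
    torch.ones(4).sum()  # profiled region on CUDA, plain execution on CPU
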
View File

@ -80,6 +80,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
is_prompt: Optional[bool] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@ -395,6 +396,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
*args,
**kwargs,
):
@ -403,19 +405,25 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
self.pin_memory = False
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
needs_attn_backend = (num_attn_heads != 0
or self.model_config.is_attention_free)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
) if needs_attn_backend else None
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
@ -444,6 +452,15 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
return builder.build() # type: ignore
# sampler property will be used by spec_decode_worker
@property
def sampler(self):
return self.model.sampler
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
@ -480,9 +497,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
pin_memory=False,
generators=generators)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
virtual_engine=virtual_engine)
virtual_engine=virtual_engine,
is_prompt=is_prompt)
@torch.no_grad()
def execute_model(
@ -491,16 +511,22 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
previous_hidden_states: Optional[torch.Tensor] = None,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device)
execute_model_kwargs = {}
if previous_hidden_states is not None:
execute_model_kwargs.update(
{"previous_hidden_states": previous_hidden_states})
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_states = model_executable(
@ -509,6 +535,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**execute_model_kwargs,
**multimodal_kwargs,
)
@ -525,4 +552,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if self.return_hidden_states:
# we only need to pass the hidden states of the most recent token
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
return [output]
def generate_proposals(self, *args, **kwargs):
return self.model.generate_proposals(*args, **kwargs)
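
The CPU model runner now exposes what the spec-decode stack expects from a runner: a sampler property, vocab_size, an is_prompt flag on the model input, hidden states in the output when return_hidden_states is set, and an optional previous_hidden_states kwarg. A small sketch of how that optional kwarg is threaded into the model call without changing the signature for ordinary models (names are illustrative):

from typing import Any, Dict, Optional

import torch

def build_execute_kwargs(
        previous_hidden_states: Optional[torch.Tensor]) -> Dict[str, Any]:
    # Forward the kwarg only when the target actually produced hidden states
    # (mlp_speculator / medusa / eagle drafts); other models are unaffected.
    kwargs: Dict[str, Any] = {}
    if previous_hidden_states is not None:
        kwargs["previous_hidden_states"] = previous_hidden_states
    return kwargs

print(build_execute_kwargs(None))                      # {}
print(build_execute_kwargs(torch.zeros(1, 8)).keys())  # previous_hidden_states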

View File

@ -128,6 +128,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
distributed_init_method: str,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[CPUModelRunner]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
@ -151,6 +152,16 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.task == "embedding":
ModelRunnerClass = CPUEmbeddingModelRunner
@ -159,7 +170,11 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CPUCacheEngine]
@ -197,7 +212,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
self.device = torch.device("cpu")
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
@ -297,6 +312,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
def execute_worker(
self,
worker_input: WorkerInput,

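CPUWorker now mirrors the GPU Worker on two points: it requests hidden states from the runner only when the draft model actually consumes them, and when a model_runner_cls wrapper (such as TargetModelRunner) is supplied it wraps the already-built runner rather than replacing its class. A hedged sketch of the hidden-state decision, with a simplified stand-in for the speculative config:

from dataclasses import dataclass
from typing import Optional

@dataclass
class SpecCfgSketch:  # stand-in for the relevant SpeculativeConfig fields
    draft_model: str
    draft_model_type: str

def speculative_args(spec_cfg: Optional[SpecCfgSketch],
                     target_model: str) -> dict:
    # Hidden states are only requested when the draft model is a separate
    # hidden-state-consuming model (medusa / mlp_speculator / eagle).
    if (spec_cfg is None
            or spec_cfg.draft_model == target_model
            or spec_cfg.draft_model_type not in
            ("medusa", "mlp_speculator", "eagle")):
        return {}
    return {"return_hidden_states": True}

print(speculative_args(None, "llama"))  # {}
print(speculative_args(SpecCfgSketch("my-mlp-spec", "mlp_speculator"),
                       "llama"))        # {'return_hidden_states': True}
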
View File

@ -289,3 +289,18 @@ class ModelRunnerBase(ABC, Generic[T]):
self.generators.pop(request_id, None)
return self.generators
class ModelRunnerWrapperBase:
"""
Wraps a ModelRunnerBase instance and delegates attribute access to it.
"""
def __init__(
self,
model_runner: ModelRunnerBase,
) -> None:
self.model_runner: ModelRunnerBase = model_runner
def __getattr__(self, attr):
return getattr(self.model_runner, attr)
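
ModelRunnerWrapperBase (like the __getattr__ added to WorkerWrapperBase below) is the whole delegation mechanism: __getattr__ runs only when normal attribute lookup on the wrapper fails, so anything the wrapper defines takes precedence and everything else falls through to the wrapped runner. A tiny standalone demonstration of that behavior:

class InnerRunner:
    vocab_size = 32000

    def load_model(self) -> str:
        return "loaded"

class RunnerWrapper:
    def __init__(self, model_runner) -> None:
        self.model_runner = model_runner

    def __getattr__(self, attr):
        # Invoked only for attributes not found on the wrapper itself.
        return getattr(self.model_runner, attr)

wrapper = RunnerWrapper(InnerRunner())
print(wrapper.vocab_size)    # 32000, resolved on the wrapped runner
print(wrapper.load_model())  # "loaded", methods are forwarded as well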

View File

@ -74,9 +74,7 @@ class Worker(LocalOrDistributedWorkerBase):
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
elif model_config.task == "embedding":
if model_config.task == "embedding":
ModelRunnerClass = EmbeddingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner
@ -86,6 +84,9 @@ class Worker(LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine]

View File

@ -466,6 +466,9 @@ class WorkerWrapperBase:
logger.exception(msg)
raise e
def __getattr__(self, attr):
return getattr(self.worker, attr)
def extract_previous_hidden_states(
data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \