Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
Remove hard-dependencies of Speculative decode to CUDA workers (#10587)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
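In outline, the change below makes the speculative-decoding workers compose a platform-selected worker instead of inheriting from the CUDA Worker: each platform writes a dotted class path into parallel_config.worker_cls / sd_worker_cls, and the proposer and target workers wrap whatever that path resolves to and forward device setup and model execution to it. A minimal sketch of that composition pattern, with made-up names rather than the actual vLLM classes:

# Hedged sketch of the composition pattern used in this commit; the class
# names here are hypothetical and do not exist in vLLM.
class PlatformWorker:
    """Stands in for whichever worker the platform layer selects (CPU or GPU)."""

    def init_device(self) -> None:
        print("device initialized")

    def execute_model(self, request):
        return [f"sampler output for {request}"]


class SpecProposerWorker:
    """Wraps the platform worker and delegates to it instead of subclassing it."""

    def __init__(self, worker: PlatformWorker) -> None:
        self.worker = worker

    def init_device(self) -> None:
        self.worker.init_device()  # delegate device setup to the wrapped worker

    def propose(self, request, steps: int):
        # Run the wrapped worker several times, as a multi-step proposer does.
        return [self.worker.execute_model(request) for _ in range(steps)]


# proposer = SpecProposerWorker(PlatformWorker()); proposer.init_device()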
@@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str):
 
     target_worker.init_device.assert_called_once()
 
-    metrics_collector.init_gpu_tensors.assert_called_once()
-    spec_decode_sampler.init_gpu_tensors.assert_called_once()
+    metrics_collector.init_tensors.assert_called_once()
+    spec_decode_sampler.init_tensors.assert_called_once()
 
 
 @pytest.mark.parametrize("acceptance_sampler_method",
@@ -990,6 +990,7 @@ class ParallelConfig:
 
     # the full name of the worker class to use. If "auto", the worker class
     # will be determined based on the platform.
     worker_cls: str = "auto"
+    sd_worker_cls: str = "auto"
 
     world_size: int = field(init=False)
@@ -43,6 +43,21 @@ class SpecDecodeBaseSampler(nn.Module):
                                                dtype=torch.long,
                                                device=device)
 
+    def init_tensors(self,
+                     device: Union[int, str],
+                     device_type: Union[torch.device, str] = 'cuda') -> None:
+        assert self.num_accepted_tokens is None
+        if isinstance(device_type, torch.device):
+            device_type = device_type.type
+        if isinstance(device, int):
+            device = f"{device_type}:{device}"
+        self.num_accepted_tokens = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=device)
+        self.num_emitted_tokens = torch.tensor(0,
+                                               dtype=torch.long,
+                                               device=device)
+
     @property
     def probs_dtype(self):
         return torch.float32
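The new init_tensors generalizes init_gpu_tensors: the caller may pass a rank (int) or a full device string, and device_type may be a torch.device or a plain string, so the acceptance counters can be allocated on CPU as well as CUDA. A standalone sketch of the same normalization logic, assuming only PyTorch (normalize_device is a hypothetical helper name, not part of vLLM):

# Standalone sketch of the device normalization used by init_tensors above.
from typing import Union

import torch


def normalize_device(device: Union[int, str],
                     device_type: Union[torch.device, str] = "cuda") -> str:
    """Turn a rank or device string into something torch.tensor(device=...) accepts."""
    if isinstance(device_type, torch.device):
        device_type = device_type.type          # e.g. torch.device("cpu") -> "cpu"
    if isinstance(device, int):
        device = f"{device_type}:{device}"      # e.g. rank 0 on CPU -> "cpu:0"
    return device


counter = torch.tensor(0, dtype=torch.long, device=normalize_device(0, "cpu"))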
@@ -77,7 +92,7 @@ class SpecDecodeBaseSampler(nn.Module):
             tensor is [batch_size, k + num_bonus_tokens]
         """
         batch_size, k = substitute_token_ids.shape
-        bonus_token_ids = bonus_token_ids.squeeze()
+        bonus_token_ids = bonus_token_ids.squeeze(-1)
         # Determine the index of the first False value for each row.
         limits = (accepted == 0).max(1).indices
         limits[~(accepted == 0).any(1)] = k
@@ -86,4 +86,10 @@ class CpuPlatform(Platform):
                            parallel_config.distributed_executor_backend)
             parallel_config.distributed_executor_backend = "mp"
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
+            if vllm_config.speculative_config:
+                parallel_config.worker_cls = \
+                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                parallel_config.sd_worker_cls = \
+                    "vllm.worker.cpu_worker.CPUWorker"
+            else:
+                parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
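Because worker_cls and sd_worker_cls are dotted-path strings, the concrete worker module is only imported when the string is resolved, which is what lets a CPU-only build avoid importing the CUDA worker. A hedged sketch of such a resolver (a generic importlib version, not vLLM's own helper):

# Generic dotted-path resolver; illustrative only, not the helper vLLM uses.
import importlib


def resolve_worker_cls(qualname: str) -> type:
    """Resolve e.g. 'vllm.worker.cpu_worker.CPUWorker' into the class object."""
    module_name, _, class_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)   # import happens only here
    return getattr(module, class_name)


# The CUDA worker module is touched only if the platform actually selected it:
# worker_cls = resolve_worker_cls(parallel_config.worker_cls)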
@@ -106,6 +106,8 @@ class CudaPlatformBase(Platform):
         elif vllm_config.speculative_config:
             parallel_config.worker_cls = \
                 "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+            parallel_config.sd_worker_cls = \
+                "vllm.worker.worker.Worker"
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"
@@ -236,4 +238,4 @@ try:
     if not isinstance(pynvml, _MockModule):
         CudaPlatform.log_warnings()
 except ModuleNotFoundError:
     CudaPlatform.log_warnings()
@@ -20,8 +20,9 @@ except (ModuleNotFoundError, ImportError) as err:
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalKwargs
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-                                      ModelRunner)
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                           ModelRunnerInputBase,
+                                           ModelRunnerWrapperBase)
 
 logger = init_logger(__name__)
 
@@ -33,7 +34,7 @@ debug_advance_input = False
 allow_gpu_advance_step = True
 
 
-class TP1DraftModelRunner(ModelRunner):
+class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding draft model.
     Since the draft model always execute k forward passes consecutively to
     generate k speculative tokens in a single speculative decoding step,
@@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
     any broadcasting inside execute_model).
     """
 
-    def __init__(self, *args, **kwargs):
-        if kwargs.get("return_hidden_states"):
+    def __init__(self, model_runner: ModelRunnerBase):
+        if hasattr(
+                model_runner,
+                "return_hidden_states") and model_runner.return_hidden_states:
             raise ValueError(
                 "return_hidden_states is not supported for TP1DraftModelRunner."
             )
 
-        super().__init__(*args, **kwargs)
+        super().__init__(model_runner)
 
         self.indices_of_seq_with_bonus_tokens = None
 
@@ -73,10 +75,8 @@ class TP1DraftModelRunner(ModelRunner):
             assert seq_group.prompt_logprob_indices == []  # No prompt
             assert seq_group.sample_indices == [i]  # Simple
 
-    def _gpu_advance_step(
-            self, model_input: ModelInputForGPUWithSamplingMetadata,
-            last_output: SamplerOutput
-    ) -> ModelInputForGPUWithSamplingMetadata:
+    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
+                          last_output: SamplerOutput) -> ModelRunnerInputBase:
         # Currently, we expect "decode mode" only
         assert not model_input.is_prompt
 
@@ -168,7 +168,7 @@ class TP1DraftModelRunner(ModelRunner):
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: ModelInputForGPUWithSamplingMetadata,
+        model_input: ModelRunnerInputBase,
         kv_caches: List[torch.Tensor],
         previous_hidden_states: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Set
+from typing import Optional, Set, Union
 
 import torch
 
@@ -75,9 +75,11 @@ class SpeculativeProposer(ABC):
 
 
 class SpeculativeScorer(ABC):
 
-    def __init__(self, scorer_worker: WorkerBase, device: str,
-                 vocab_size: int):
+    def __init__(self, scorer_worker: WorkerBase,
+                 device: Union[torch.device, str], vocab_size: int):
         self._scorer_worker = scorer_worker
+        if isinstance(device, torch.device):
+            device = device.type
         self._device = device
         self._vocab_size = vocab_size
@@ -9,21 +9,22 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerWrapperBase
 
 
-class MedusaWorker(NonLLMProposerWorkerBase, Worker):
+class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
     """Worker for Medusa.
     """
 
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(kwargs.get("vllm_config"))
+        self.init_worker(*args, **kwargs)
 
         # Lazy initialization list.
         self._proposer: Top1Proposer
 
     def init_device(self):
-        super().init_device()
+        self.worker.init_device()
 
         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]
@@ -1,11 +1,12 @@
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
 import msgspec
 import torch
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)
+from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 
@@ -81,8 +82,20 @@ class AsyncMetricsCollector:
         self._rank = rank
         self._copy_stream = torch.cuda.Stream()
 
+    def init_tensors(self,
+                     rank: int,
+                     device_type: Union[torch.device, str] = 'cuda') -> None:
+        self._rank = rank
+        if isinstance(device_type, torch.device):
+            device_type = device_type.type
+        if device_type == 'cuda':
+            self._copy_stream = torch.cuda.Stream()
+
     def maybe_collect_rejsample_metrics(
             self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
+        # currently using cuda.Event, skip for any non_cuda_alike platform
+        if not current_platform.is_cuda_alike():
+            return None
+
         # If a copy was initiated in the previous call, collect and return.
         if self._in_flight_copy is not None:
@@ -5,17 +5,21 @@ from typing import Dict, List, Set, Tuple
 import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+if current_platform.is_cuda_alike():
+    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerWrapperBase
 
 
-class MultiStepWorker(Worker, ProposerWorkerBase):
+class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead
@@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
     """
 
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(kwargs.get("vllm_config"))
+        self.init_worker(*args, **kwargs)
 
         # Lazy initialization list.
         self._proposer: SpeculativeProposer
 
     def init_device(self) -> None:
-        super().init_device()
+        self.worker.init_device()
 
         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]
@@ -51,6 +56,18 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
             True)
 
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        return self.worker.determine_num_available_blocks()
+
+    def get_cache_block_size_bytes(self) -> int:
+        return self.worker.get_cache_block_size_bytes()
+
+    def initialize_cache(self, *args, **kwargs) -> None:
+        self.worker.initialize_cache(*args, **kwargs)
+
+    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
+        return self.worker.execute_model(*args, **kwargs)
+
     @torch.inference_mode()
     def sampler_output(
         self,
@@ -75,7 +92,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
 
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
-        if isinstance(
+        if current_platform.is_cuda_alike() and isinstance(
                 self.model_runner, TP1DraftModelRunner
         ) and self.model_runner.supports_gpu_multi_step(expanded_request):
             # Here we run the draft_model_runner with multi-step prepare
@@ -92,7 +109,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
             for _ in range(sample_len):
-                model_output: List[SamplerOutput] = super().execute_model(
+                model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
@@ -22,6 +22,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
         # Get local_rank/vocab_size from kwargs attribute
         self.local_rank = kwargs["local_rank"]
         self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
+        self.device_type = kwargs.get("device_type", "cuda")
 
         # Lazy initialization list.
         self._proposer: Top1Proposer
@@ -34,7 +35,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
 
     def init_device(self):
-        self.device = torch.device(f"cuda:{self.local_rank}")
+        self.device = torch.device(f"{self.device_type}:{self.local_rank}")
         self.load_model = lambda *args, **kwargs: None
 
         # Current NGramWorker only supports Top1Proposer
@@ -14,12 +14,16 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
+from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                            CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SequenceGroupMetadata,
                            get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+if current_platform.is_cuda_alike():
+    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.medusa_worker import MedusaWorker
@@ -36,8 +40,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
-from vllm.worker.worker import Worker
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
+from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase,
+                                     WorkerWrapperBase)
 
 logger = init_logger(__name__)
 
@@ -53,7 +57,11 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     draft_worker_kwargs = kwargs.copy()
 
     kwargs["model_runner_cls"] = TargetModelRunner
-    target_worker = Worker(*args, **kwargs)
+    target_worker_config = copy.deepcopy(vllm_config)
+    target_worker_config.parallel_config.worker_cls =\
+        target_worker_config.parallel_config.sd_worker_cls
+    target_worker = WorkerWrapperBase(vllm_config=target_worker_config)
+    target_worker.init_worker(*args, **kwargs)
     # Set the disable_logprobs variable in the TargetModelRunner instance
     # as per its value specified in the SpeculativeConfig.
     target_worker.model_runner.disable_logprobs =\
@@ -65,6 +73,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         draft_worker_config.model_config,
         vllm_config.load_config,
     )
+    speculative_config.draft_parallel_config.worker_cls =\
+        draft_worker_config.parallel_config.sd_worker_cls
     draft_worker_config.parallel_config = speculative_config.draft_parallel_config  # noqa
     # TODO allow draft-model specific load config.
 
@@ -125,7 +135,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     @classmethod
     def create_worker(
         cls,
-        scorer_worker: Worker,
+        scorer_worker: WorkerBase,
         draft_worker_kwargs: Dict[str, Any],
         disable_mqa_scorer: bool,
         disable_by_batch_size: Optional[int],
@@ -145,6 +155,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         draft_parallel_config: ParallelConfig = draft_worker_kwargs[
             'vllm_config'].parallel_config
         if ngram_prompt_lookup_max > 0:
+            draft_worker_kwargs[
+                "device_type"] = scorer_worker.device_config.device.type
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
@@ -158,8 +170,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
             if draft_tp == 1:
-                draft_worker_kwargs[
-                    "model_runner_cls"] = TP1DraftModelRunner
+                if current_platform.is_cuda_alike():
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = TP1DraftModelRunner
             else:
                 if draft_model_config.hf_config.model_type == "eagle":
                     raise NotImplementedError(
@@ -306,8 +319,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.scorer_worker.load_model()
         self.proposer_worker.load_model()
 
-        self._metrics.init_gpu_tensors(self.rank)
-        self.spec_decode_sampler.init_gpu_tensors(self.rank)
+        self._metrics.init_tensors(self.rank, device_type=self.device)
+        self.spec_decode_sampler.init_tensors(self.rank,
+                                              device_type=self.device)
 
         scorer_cls: Type[SpeculativeScorer]
         if self.disable_mqa_scorer:
@@ -1111,11 +1125,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         raise NotImplementedError
 
     def start_profile(self):
-        if isinstance(self.scorer_worker, Worker):
+        if isinstance(self.scorer_worker, WorkerBase):
             self.scorer_worker.start_profile()
 
     def stop_profile(self):
-        if isinstance(self.scorer_worker, Worker):
+        if isinstance(self.scorer_worker, WorkerBase):
             self.scorer_worker.stop_profile()
 
 
@@ -1,12 +1,12 @@
 from typing import List, Optional
 
-from vllm.config import VllmConfig
 from vllm.sequence import SequenceGroupMetadata
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-                                      ModelRunner)
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                           ModelRunnerInputBase,
+                                           ModelRunnerWrapperBase)
 
 
-class TargetModelRunner(ModelRunner):
+class TargetModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding target model.
     In speculative decoding, the log probabilities selected finally may not
     be the same ones as selected by the target model sampling. This means
@@ -18,32 +18,21 @@ class TargetModelRunner(ModelRunner):
     requested or not.
     """
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        kv_cache_dtype: Optional[str] = "auto",
-        is_driver_worker: bool = False,
-        return_hidden_states: bool = False,
-    ):
+    def __init__(self, model_runner: ModelRunnerBase):
         # An internal boolean member variable to indicate if token log
         # probabilities are needed or not.
+        super().__init__(model_runner)
         self.disable_logprobs = True
-        super().__init__(
-            vllm_config=vllm_config,
-            kv_cache_dtype=kv_cache_dtype,
-            is_driver_worker=is_driver_worker,
-            return_hidden_states=return_hidden_states,
-        )
 
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> ModelInputForGPUWithSamplingMetadata:
-        model_input: ModelInputForGPUWithSamplingMetadata = super(
-        ).prepare_model_input(seq_group_metadata_list, virtual_engine,
-                              finished_requests_ids)
+        finished_requests_ids: Optional[List[str]] = None,
+    ) -> ModelRunnerInputBase:
+        model_input: ModelRunnerInputBase =\
+            self.model_runner.prepare_model_input(
+                seq_group_metadata_list, virtual_engine, finished_requests_ids)
         # If token log probabilities is disabled then skip generating sampler
         # CPU output. We directly serialize the GPU sampled_token_id tensors
         # as needed. If log probabilities is enabled then synchronize all the
@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
 import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            PromptLogprobs, SequenceGroupMetadata,
                            SequenceOutput)
@@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs):
     Arguments:
         msg (string): message to associate with the range
     """
-    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
-    try:
+    if current_platform.is_cuda_alike():
+        torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
+        try:
+            yield
+        finally:
+            torch.cuda.nvtx.range_pop()
+    else:
         yield
-    finally:
-        torch.cuda.nvtx.range_pop()
 
 
 class Timer:
@@ -80,6 +80,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
     Used by the ModelRunner.
     """
     sampling_metadata: Optional["SamplingMetadata"] = None
+    is_prompt: Optional[bool] = None
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -395,6 +396,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         vllm_config: VllmConfig,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
+        return_hidden_states: bool = False,
         *args,
         **kwargs,
     ):
@@ -403,19 +405,25 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         cache_config = self.cache_config
 
         self.is_driver_worker = is_driver_worker
+        self.return_hidden_states = return_hidden_states
 
         self.device = self.device_config.device
         self.pin_memory = False
 
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
+        ) if needs_attn_backend else None
 
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
@@ -444,6 +452,15 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
 
         return builder.build()  # type: ignore
 
+    # sampler property will be used by spec_decode_worker
+    @property
+    def sampler(self):
+        return self.model.sampler
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+
 
 class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
     _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
@@ -480,9 +497,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             pin_memory=False,
             generators=generators)
 
+        is_prompt = (seq_group_metadata_list[0].is_prompt
+                     if seq_group_metadata_list else None)
         return dataclasses.replace(model_input,
                                    sampling_metadata=sampling_metadata,
-                                   virtual_engine=virtual_engine)
+                                   virtual_engine=virtual_engine,
+                                   is_prompt=is_prompt)
 
     @torch.no_grad()
     def execute_model(
@@ -491,16 +511,22 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        previous_hidden_states: Optional[torch.Tensor] = None,
     ) -> Optional[List[SamplerOutput]]:
         if num_steps > 1:
             raise ValueError(
                 "CPU worker does not support multi-step execution.")
 
         model_executable = self.model
 
         multimodal_kwargs = {}
         if model_input.multi_modal_kwargs is not None:
             multimodal_kwargs = MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs, device=self.device)
+        execute_model_kwargs = {}
+        if previous_hidden_states is not None:
+            execute_model_kwargs.update(
+                {"previous_hidden_states": previous_hidden_states})
 
         with set_forward_context(model_input.attn_metadata, self.vllm_config):
             hidden_states = model_executable(
@@ -509,6 +535,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
                 kv_caches=kv_caches,
                 attn_metadata=model_input.attn_metadata,
                 intermediate_tensors=intermediate_tensors,
+                **execute_model_kwargs,
                 **multimodal_kwargs,
             )
 
@@ -525,4 +552,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
+        if self.return_hidden_states:
+            # we only need to pass hidden states of most recent token
+            if model_input.is_prompt:
+                output.prefill_hidden_states = hidden_states
+            output.hidden_states = hidden_states
         return [output]
+
+    def generate_proposals(self, *args, **kwargs):
+        return self.model.generate_proposals(*args, **kwargs)
@@ -128,6 +128,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         distributed_init_method: str,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
+        model_runner_cls: Optional[Type[CPUModelRunner]] = None,
     ) -> None:
         WorkerBase.__init__(self, vllm_config=vllm_config)
 
@@ -151,6 +152,16 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         else:
             self.local_omp_cpuid = omp_cpuids.split("|")[rank]
 
+        # Return hidden states from target model if the draft model is an
+        # mlp_speculator
+        speculative_config = self.speculative_config
+        model_config = self.model_config
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.model ==
+                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type
+                not in ["medusa", "mlp_speculator", "eagle"]) \
+            else {"return_hidden_states": True}
         ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
         if self.model_config.task == "embedding":
             ModelRunnerClass = CPUEmbeddingModelRunner
@@ -159,7 +170,11 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
             vllm_config=vllm_config,
             kv_cache_dtype=kv_cache_dtype,
-            is_driver_worker=is_driver_worker)
+            is_driver_worker=is_driver_worker,
+            **speculative_args,
+        )
+        if model_runner_cls is not None:
+            self.model_runner = model_runner_cls(self.model_runner)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CPUCacheEngine]
@@ -197,7 +212,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
             if ret:
                 logger.info(ret)
 
+        self.device = torch.device("cpu")
         self.init_distributed_environment()
         # Set random seed.
         set_random_seed(self.model_config.seed)
@@ -297,6 +312,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
         return self.cpu_cache
 
+    @property
+    def vocab_size(self) -> int:
+        return self.model_runner.vocab_size
+
+    @property
+    def max_model_len(self) -> int:
+        return self.model_config.max_model_len
+
     def execute_worker(
         self,
         worker_input: WorkerInput,
@@ -289,3 +289,18 @@ class ModelRunnerBase(ABC, Generic[T]):
             self.generators.pop(request_id, None)
 
         return self.generators
+
+
+class ModelRunnerWrapperBase:
+    """
+    The whole point of this class is to lazily initialize the model_runner.
+    """
+
+    def __init__(
+        self,
+        moderl_runner: ModelRunnerBase,
+    ) -> None:
+        self.model_runner: ModelRunnerBase = moderl_runner
+
+    def __getattr__(self, attr):
+        return getattr(self.model_runner, attr)
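ModelRunnerWrapperBase is what lets TP1DraftModelRunner and TargetModelRunner wrap a platform-provided runner instead of subclassing the GPU ModelRunner: anything the wrapper does not define is forwarded to the wrapped runner via __getattr__, while overridden methods still take precedence. A small self-contained sketch of that delegation behavior (hypothetical classes, not the vLLM ones):

# Self-contained sketch of __getattr__ delegation; BaseRunner/LoggingRunner
# are hypothetical names used only for illustration.
class BaseRunner:
    def prepare(self, batch):
        return {"batch": batch}

    def run(self, batch):
        return f"ran {batch}"


class RunnerWrapper:
    def __init__(self, runner) -> None:
        self.model_runner = runner

    def __getattr__(self, attr):
        # Invoked only when normal attribute lookup fails on the wrapper,
        # so methods defined on subclasses win over the wrapped runner.
        return getattr(self.model_runner, attr)


class LoggingRunner(RunnerWrapper):
    def run(self, batch):
        print("about to run", batch)
        return self.model_runner.run(batch)


runner = LoggingRunner(BaseRunner())
runner.prepare([1, 2])  # delegated to BaseRunner.prepare via __getattr__
runner.run([1, 2])      # uses the override, which then delegates explicitly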
@@ -74,9 +74,7 @@ class Worker(LocalOrDistributedWorkerBase):
             else {"return_hidden_states": True}
 
         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
-        if model_runner_cls is not None:
-            ModelRunnerClass = model_runner_cls
-        elif model_config.task == "embedding":
+        if model_config.task == "embedding":
             ModelRunnerClass = EmbeddingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
@@ -86,6 +84,9 @@ class Worker(LocalOrDistributedWorkerBase):
             is_driver_worker=is_driver_worker,
             **speculative_args,
         )
+        if model_runner_cls is not None:
+            self.model_runner = model_runner_cls(self.model_runner)
+
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
@@ -466,6 +466,9 @@ class WorkerWrapperBase:
             logger.exception(msg)
             raise e
 
+    def __getattr__(self, attr):
+        return getattr(self.worker, attr)
+
 
 def extract_previous_hidden_states(
         data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \