Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[Hardware][Neuron] Add on-device sampling support for Neuron (#8746)
Co-authored-by: Ashraf Mahgoub <ashymahg@amazon.com>
vllm/model_executor/model_loader/neuron.py

@@ -1,4 +1,5 @@
 """Utilities for selecting and loading neuron models."""
+import copy
 import importlib
 import os
 from typing import Dict, List, Optional, Tuple
@@ -13,6 +14,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SequenceOutput)
 
 TORCH_DTYPE_TO_NEURON_AMP = {
     "auto": "f32",
@@ -37,15 +40,18 @@ _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
 
 class NeuronCasualLM(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 on_device_sampling_disabled: bool = False) -> None:
         super().__init__()
         self.config = config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 logits_as_input=True)
-        self.sampler = Sampler()
+
+        self.on_device_sampling_disabled = on_device_sampling_disabled
+        if self.on_device_sampling_disabled:
+            # Use default sampler
+            self.sampler = Sampler()
 
         # Lazy initialized
         self.model: nn.Module
@@ -71,8 +77,29 @@ class NeuronCasualLM(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        if self.on_device_sampling_disabled:
+            next_tokens = self.sampler(logits, sampling_metadata)
+            return next_tokens
+
+        # On-device sampling outputs the token ids directly.
+        sampled_token_ids = logits.flatten()
+        next_tokens = []
+        sample_idx = 0
+        for seq_group in sampling_metadata.seq_groups:
+            samples = []
+            for seq_id in seq_group.seq_ids:
+                token_id = sampled_token_ids[sample_idx].item()
+                samples.append(
+                    SequenceOutput(parent_seq_id=seq_id,
+                                   output_token=token_id,
+                                   logprobs={token_id: Logprob(token_id)}))
+                sample_idx += 1
+            next_tokens.append(
+                CompletionSequenceGroupOutput(samples=samples,
+                                              prompt_logprobs=None))
+
+        return SamplerOutput(outputs=next_tokens)
 
     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
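When on-device sampling is active, the "logits" tensor returned by the Neuron model already holds one sampled token id per sequence, and the loop above walks that flat vector in batch order to rebuild per-sequence-group outputs. A minimal, self-contained sketch of that unpacking logic, using plain Python lists and a hypothetical stand-in class instead of vLLM's torch tensors, SequenceOutput, and CompletionSequenceGroupOutput:

# Illustrative sketch, not part of the commit.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeSequenceOutput:  # stand-in for vllm.sequence.SequenceOutput
    parent_seq_id: int
    output_token: int
    logprobs: Dict[int, float]


def unpack_on_device_samples(sampled_token_ids: List[int],
                             seq_ids_per_group: List[List[int]]):
    """Walk the flat token-id list in batch order, one id per sequence."""
    outputs, sample_idx = [], 0
    for seq_ids in seq_ids_per_group:
        group = []
        for seq_id in seq_ids:
            token_id = sampled_token_ids[sample_idx]
            group.append(FakeSequenceOutput(seq_id, token_id, {token_id: 0.0}))
            sample_idx += 1
        outputs.append(group)
    return outputs


# Two sequence groups: the first has one sequence, the second has two.
print(unpack_on_device_samples([11, 22, 33], [[0], [1, 2]]))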
@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
         quant=neuron_quantization_config_builder(model_config.quantization)
         if model_config.quantization else None,
         continuous_batching=continuous_batching_config,
-        weight_tiling=bool(model_config.quantization))
+        weight_tiling=bool(model_config.quantization),
+        on_device_generation=_get_neuron_on_device_generation_config(
+            model_config))
     return default_neuron_args
 
 
+def _get_neuron_on_device_generation_config(model_config: ModelConfig):
+    if not _is_neuron_on_device_sampling_disabled(model_config):
+        return copy.deepcopy(model_config.neuron_sampling_params)
+    return None
+
+
+def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
+    return not getattr(model_config, "neuron_sampling_params", None)
+
+
 def _get_neuron_config_after_override(default_neuron_config,
                                       overridden_neuron_config):
     from transformers_neuronx.config import NeuronConfig
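The helper above treats a missing or falsy `neuron_sampling_params` attribute on the model config as "on-device sampling disabled", so the generation config is only copied into the Neuron build arguments when the runner has populated it. A small sketch of that getattr check, using a stand-in namespace object instead of vLLM's ModelConfig:

# Illustrative sketch, not part of the commit.
from types import SimpleNamespace


def is_on_device_sampling_disabled(model_config) -> bool:
    # Mirrors _is_neuron_on_device_sampling_disabled: falsy attr => disabled.
    return not getattr(model_config, "neuron_sampling_params", None)


print(is_on_device_sampling_disabled(SimpleNamespace()))                    # True
print(is_on_device_sampling_disabled(
    SimpleNamespace(neuron_sampling_params=None)))                          # True
print(is_on_device_sampling_disabled(
    SimpleNamespace(neuron_sampling_params={"top_k": [256]})))              # False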
@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
                      scheduler_config: SchedulerConfig) -> nn.Module:
 
     # Create a model instance.
-    model = NeuronCasualLM(model_config.hf_config)
+    model = NeuronCasualLM(
+        model_config.hf_config,
+        _is_neuron_on_device_sampling_disabled(model_config))
 
     default_neuron_config_args = _get_default_neuron_config(
         model_config, parallel_config, scheduler_config)
vllm/worker/neuron_model_runner.py

@@ -1,9 +1,11 @@
+import os
 from dataclasses import dataclass
 from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
+from transformers_neuronx.config import GenerationConfig
 
 from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -50,6 +52,9 @@ class ModelInputForNeuron(ModelRunnerInputBase):
 
 class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
 
+    # NEURON has an upper limit on the top_k
+    _MAX_NEURON_SAMPLING_TOP_K = 256
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -76,6 +81,34 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
 
+        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
+        # turn off on-device sampling.
+        self._on_device_sampling_disabled = int(
+            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
+
+        # NEURON needs to update sampling parameters when request IDs change
+        # across batches. This variable stores the previous batch's request IDs
+        # to determine if an update is needed.
+        self._previous_batch_request_ids: List[str] = []
+
+        if not self._on_device_sampling_disabled:
+            logger.warning(
+                "On-device sampling is turned on in Neuron by default; only "
+                "top_k, top_p, and temperature are currently supported "
+                "sampling parameters. To turn off on-device sampling, please "
+                "set the environment variable "
+                "NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
+            self.model_config.neuron_sampling_params = GenerationConfig(
+                max_length=self.scheduler_config.max_model_len,
+                do_sample=True,
+                per_batch_line=True,
+                top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
+                    * self.scheduler_config.max_num_seqs,
+                top_p=[1.0] * self.scheduler_config.max_num_seqs,
+                temperature=[1.0] * self.scheduler_config.max_num_seqs,
+                dynamic=True,
+                global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
+
     def load_model(self) -> None:
         if find_spec("transformers_neuronx") is not None:
             self.model = get_neuron_model(
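The constructor above seeds a neutral per-batch-line generation config: one top_k/top_p/temperature slot per scheduler sequence slot, so the Neuron side can later be updated per request without recompiling. A sketch of the resulting shape, using a plain dict with example values in place of transformers_neuronx's GenerationConfig (max_model_len and max_num_seqs are assumptions here):

# Illustrative sketch, not part of the commit.
MAX_NEURON_SAMPLING_TOP_K = 256
max_model_len, max_num_seqs = 2048, 4        # example scheduler settings

default_generation_config = dict(
    max_length=max_model_len,
    do_sample=True,
    per_batch_line=True,                     # one entry per sequence slot
    top_k=[MAX_NEURON_SAMPLING_TOP_K] * max_num_seqs,
    top_p=[1.0] * max_num_seqs,
    temperature=[1.0] * max_num_seqs,
    dynamic=True,                            # allow per-request updates later
    global_top_k=MAX_NEURON_SAMPLING_TOP_K,
)
print(default_generation_config["top_k"])    # [256, 256, 256, 256]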
@@ -215,7 +248,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
-            seq_lens = []
+            seq_lens = None
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
@@ -227,12 +260,49 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             self.pin_memory,
             generators=self.get_generators(finished_requests_ids))
 
+        if not self._on_device_sampling_disabled:
+            # If the request IDs changed in the current iteration, update the
+            # on-device sampling parameters.
+            current_batch_request_ids = [
+                seq_group_meta_data.request_id
+                for seq_group_meta_data in seq_group_metadata_list
+            ]
+            if current_batch_request_ids != self._previous_batch_request_ids:
+                self._update_neuron_sampling_params(sampling_metadata)
+                self._previous_batch_request_ids = current_batch_request_ids
+
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
                                    sampling_metadata=sampling_metadata,
                                    multi_modal_kwargs=multi_modal_kwargs)
 
+    def _update_neuron_sampling_params(self,
+                                       sampling_metadata: SamplingMetadata):
+        # Update Neuron sampling parameters (GenerationConfig in Neuron).
+        current_sampling_params = self.model_config.neuron_sampling_params
+        assert current_sampling_params is not None, (
+            f"Failed to update sampling_params, "
+            f"current sampling params is {current_sampling_params}")
+
+        top_k = current_sampling_params.top_k
+        top_p = current_sampling_params.top_p
+        temperature = current_sampling_params.temperature
+        for index, sequence_group_to_sample in enumerate(
+                sampling_metadata.seq_groups):
+            top_k[index] = self._convert_to_neuron_top_k(
+                sequence_group_to_sample.sampling_params.top_k)
+            top_p[index] = sequence_group_to_sample.sampling_params.top_p
+            temperature[index] = \
+                sequence_group_to_sample.sampling_params.temperature
+
+        self.model.model.update_generation_config(current_sampling_params)
+
+    def _convert_to_neuron_top_k(self, top_k: int) -> int:
+        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
+            return self._MAX_NEURON_SAMPLING_TOP_K
+        return top_k
+
     @torch.inference_mode()
     def execute_model(
         self,
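Before each request's parameters are written into the per-batch-line lists, top_k is clamped to the Neuron ceiling of 256; vLLM's convention of top_k = -1 meaning "no top_k restriction" also collapses to that ceiling, since Neuron has no disabled setting. A standalone sketch of the clamping behaviour:

# Illustrative sketch, not part of the commit.
MAX_NEURON_SAMPLING_TOP_K = 256


def convert_to_neuron_top_k(top_k: int) -> int:
    # Negative ("disabled") and oversized values both become the ceiling.
    if top_k < 0 or top_k > MAX_NEURON_SAMPLING_TOP_K:
        return MAX_NEURON_SAMPLING_TOP_K
    return top_k


assert convert_to_neuron_top_k(-1) == 256    # "no top_k" becomes the ceiling
assert convert_to_neuron_top_k(500) == 256   # clamped
assert convert_to_neuron_top_k(40) == 40     # passed through unchanged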
@@ -253,9 +323,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
                         device=self.device),
         )
 
-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
-                                           model_input.sampling_metadata)
+        # Compute the logits only if on-device sampling is turned off, since
+        # on-device sampling outputs the token ids directly.
+        if self._on_device_sampling_disabled:
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+        else:
+            logits = hidden_states
 
         # Sample the next token.
         output = self.model.sample(