[Hardware][Neuron] Add on-device sampling support for Neuron (#8746)

Co-authored-by: Ashraf Mahgoub <ashymahg@amazon.com>
Author: Chongming Ni
Date: 2024-10-04 16:42:20 -07:00
Committed-by: GitHub
Parent: 27302dd584
Commit: cc90419e89
2 changed files with 128 additions and 13 deletions

File: vllm/model_executor/model_loader/neuron.py

@@ -1,4 +1,5 @@
 """Utilities for selecting and loading neuron models."""
+import copy
 import importlib
 import os
 from typing import Dict, List, Optional, Tuple
@@ -13,6 +14,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SequenceOutput)
 
 TORCH_DTYPE_TO_NEURON_AMP = {
     "auto": "f32",
@@ -37,15 +40,18 @@ _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
 
 class NeuronCasualLM(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 on_device_sampling_disabled: bool = False) -> None:
         super().__init__()
         self.config = config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 logits_as_input=True)
-        self.sampler = Sampler()
+        self.on_device_sampling_disabled = on_device_sampling_disabled
+        if self.on_device_sampling_disabled:
+            # Use default sampler
+            self.sampler = Sampler()
 
         # Lazy initialized
         self.model: nn.Module
@@ -71,8 +77,29 @@ class NeuronCasualLM(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        if self.on_device_sampling_disabled:
+            next_tokens = self.sampler(logits, sampling_metadata)
+            return next_tokens
+
+        # On-device sampling outputs the token ids directly.
+        sampled_token_ids = logits.flatten()
+        next_tokens = []
+        sample_idx = 0
+        for seq_group in sampling_metadata.seq_groups:
+            samples = []
+            for seq_id in seq_group.seq_ids:
+                token_id = sampled_token_ids[sample_idx].item()
+                samples.append(
+                    SequenceOutput(parent_seq_id=seq_id,
+                                   output_token=token_id,
+                                   logprobs={token_id: Logprob(token_id)}))
+                sample_idx += 1
+            next_tokens.append(
+                CompletionSequenceGroupOutput(samples=samples,
+                                              prompt_logprobs=None))
+
+        return SamplerOutput(outputs=next_tokens)
 
     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
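Note on the sample() hunk: with on-device sampling, the tensor handed to sample() already holds one sampled token id per sequence rather than a vocabulary-sized logits row, so the loop only unwraps those ids into vLLM's output structures (it also reuses the token id as a placeholder Logprob value, since the device returns no real logprobs). A minimal stand-in sketch of the unwrap, using a toy tensor and plain lists in place of vLLM's SamplingMetadata:

    import torch

    # Device output: one sampled token id per sequence (shape [batch, 1]).
    sampled = torch.tensor([[1137], [42], [7]])
    # Toy stand-in for sampling_metadata.seq_groups: one seq_id list per group.
    seq_groups = [[0], [1], [2]]

    flat = sampled.flatten()
    sample_idx = 0
    for seq_ids in seq_groups:
        for seq_id in seq_ids:
            token_id = flat[sample_idx].item()
            print(f"seq {seq_id} -> token {token_id}")
            sample_idx += 1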
@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
         quant=neuron_quantization_config_builder(model_config.quantization)
         if model_config.quantization else None,
         continuous_batching=continuous_batching_config,
-        weight_tiling=bool(model_config.quantization))
+        weight_tiling=bool(model_config.quantization),
+        on_device_generation=_get_neuron_on_device_generation_config(
+            model_config))
     return default_neuron_args
 
 
+def _get_neuron_on_device_generation_config(model_config: ModelConfig):
+    if not _is_neuron_on_device_sampling_disabled(model_config):
+        return copy.deepcopy(model_config.neuron_sampling_params)
+    return None
+
+
+def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
+    return not getattr(model_config, "neuron_sampling_params", None)
+
+
 def _get_neuron_config_after_override(default_neuron_config,
                                       overridden_neuron_config):
     from transformers_neuronx.config import NeuronConfig
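The two helpers encode the mode implicitly: the model runner sets model_config.neuron_sampling_params only when on-device sampling is active, so the loader infers "disabled" from the attribute's absence. A minimal illustration with a hypothetical stand-in config object:

    from types import SimpleNamespace

    def is_disabled(model_config) -> bool:
        return not getattr(model_config, "neuron_sampling_params", None)

    cfg = SimpleNamespace()                   # attribute never set
    assert is_disabled(cfg)                   # -> on-device sampling off

    cfg.neuron_sampling_params = object()     # runner populated it
    assert not is_disabled(cfg)               # -> on-device sampling on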
@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
                      scheduler_config: SchedulerConfig) -> nn.Module:
 
     # Create a model instance.
-    model = NeuronCasualLM(model_config.hf_config)
+    model = NeuronCasualLM(
+        model_config.hf_config,
+        _is_neuron_on_device_sampling_disabled(model_config))
 
     default_neuron_config_args = _get_default_neuron_config(
         model_config, parallel_config, scheduler_config)

File: vllm/worker/neuron_model_runner.py

@@ -1,9 +1,11 @@
+import os
 from dataclasses import dataclass
 from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
+from transformers_neuronx.config import GenerationConfig
 
 from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -50,6 +52,9 @@ class ModelInputForNeuron(ModelRunnerInputBase):
 
 class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
 
+    # NEURON has an upper limit on the top_k
+    _MAX_NEURON_SAMPLING_TOP_K = 256
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -76,6 +81,34 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
 
+        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
+        # turn off on-device sampling.
+        self._on_device_sampling_disabled = int(
+            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
+
+        # NEURON needs to update sampling parameters when request IDs change
+        # across batches. This variable stores the previous batch's request
+        # IDs to determine if an update is needed.
+        self._previous_batch_request_ids: List[str] = []
+
+        if not self._on_device_sampling_disabled:
+            logger.warning(
+                "On-device sampling is turned on in Neuron by default; only "
+                "top_k, top_p, and temperature are currently supported "
+                "sampling parameters. To turn off on-device sampling, set "
+                "the environment variable "
+                "NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
+            self.model_config.neuron_sampling_params = GenerationConfig(
+                max_length=self.scheduler_config.max_model_len,
+                do_sample=True,
+                per_batch_line=True,
+                top_k=[self._MAX_NEURON_SAMPLING_TOP_K]
+                * self.scheduler_config.max_num_seqs,
+                top_p=[1.0] * self.scheduler_config.max_num_seqs,
+                temperature=[1.0] * self.scheduler_config.max_num_seqs,
+                dynamic=True,
+                global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
+
     def load_model(self) -> None:
         if find_spec("transformers_neuronx") is not None:
             self.model = get_neuron_model(
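On-device sampling is therefore opt-out via an environment variable. The flag is parsed with int(), so any non-zero integer disables it (a non-integer value such as "true" would raise ValueError), and the GenerationConfig defaults pre-size the per-sequence top_k/top_p/temperature lists to max_num_seqs so they can be rewritten in place on each batch. A small sketch of the toggle, assuming the same variable name:

    import os

    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "1"
    disabled = int(os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
    print("on-device sampling disabled:", bool(disabled))   # -> True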
@@ -215,7 +248,7 @@
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
-            seq_lens = []
+            seq_lens = None
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
@@ -227,12 +260,49 @@
             self.pin_memory,
             generators=self.get_generators(finished_requests_ids))
 
+        if not self._on_device_sampling_disabled:
+            # If the request IDs changed in the current iteration, update
+            # the on-device sampling parameters.
+            current_batch_request_ids = [
+                seq_group_meta_data.request_id
+                for seq_group_meta_data in seq_group_metadata_list
+            ]
+            if current_batch_request_ids != self._previous_batch_request_ids:
+                self._update_neuron_sampling_params(sampling_metadata)
+                self._previous_batch_request_ids = current_batch_request_ids
+
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
                                    sampling_metadata=sampling_metadata,
                                    multi_modal_kwargs=multi_modal_kwargs)
 
+    def _update_neuron_sampling_params(self,
+                                       sampling_metadata: SamplingMetadata):
+        # Update the Neuron sampling parameters (GenerationConfig in Neuron).
+        current_sampling_params = self.model_config.neuron_sampling_params
+        assert current_sampling_params is not None, (
+            f"Failed to update sampling_params, "
+            f"current sampling params is {current_sampling_params}")
+
+        top_k = current_sampling_params.top_k
+        top_p = current_sampling_params.top_p
+        temperature = current_sampling_params.temperature
+        for index, sequence_group_to_sample in enumerate(
+                sampling_metadata.seq_groups):
+            top_k[index] = self._convert_to_neuron_top_k(
+                sequence_group_to_sample.sampling_params.top_k)
+            top_p[index] = sequence_group_to_sample.sampling_params.top_p
+            temperature[index] = \
+                sequence_group_to_sample.sampling_params.temperature
+
+        self.model.model.update_generation_config(current_sampling_params)
+
+    def _convert_to_neuron_top_k(self, top_k: int) -> int:
+        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
+            return self._MAX_NEURON_SAMPLING_TOP_K
+        return top_k
+
     @torch.inference_mode()
     def execute_model(
         self,
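Sampling parameters are pushed to the device only when the batch composition (the list of request IDs) changes, and each request's top_k is clamped into Neuron's supported range: vLLM's top_k == -1 sentinel ("no top-k") and anything above the 256 ceiling both map to 256, while in-range values pass through. A behavior sketch of the clamp:

    _MAX_NEURON_SAMPLING_TOP_K = 256

    def convert_to_neuron_top_k(top_k: int) -> int:
        if top_k < 0 or top_k > _MAX_NEURON_SAMPLING_TOP_K:
            return _MAX_NEURON_SAMPLING_TOP_K
        return top_k

    assert convert_to_neuron_top_k(-1) == 256     # "disabled" -> ceiling
    assert convert_to_neuron_top_k(1000) == 256   # above ceiling -> clamped
    assert convert_to_neuron_top_k(50) == 50      # in range -> unchanged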
@@ -253,9 +323,13 @@
                 device=self.device),
         )
 
-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
-                                           model_input.sampling_metadata)
+        # Compute the logits only if on-device sampling is turned off;
+        # with on-device sampling, the model already outputs token ids.
+        if self._on_device_sampling_disabled:
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+        else:
+            logits = hidden_states
 
         # Sample the next token.
         output = self.model.sample(
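The effect of this dispatch, as a stand-alone sketch (model, hidden_states, and sampling_metadata are stand-ins, not vLLM's actual objects): when on-device sampling is active, the forward pass already returns token ids, so they are fed to sample() unchanged; otherwise logits are computed on the host first.

    def postprocess(model, hidden_states, sampling_metadata,
                    on_device_sampling_disabled: bool):
        if on_device_sampling_disabled:
            # Host-side path: turn hidden states into vocab logits first.
            logits = model.compute_logits(hidden_states, sampling_metadata)
        else:
            # Device already sampled: "logits" here are token ids.
            logits = hidden_states
        return model.sample(logits, sampling_metadata)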