[Neuron] Add multi-LoRA support for Neuron. (#18284)

Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
Author: Satyajith Chilappagari
Date: 2025-05-29 01:41:22 -07:00 (committed by GitHub)
Parent: fd7bb88d72
Commit: 972eddf7c9
6 changed files with 343 additions and 26 deletions
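
In short, this change lets NxDI-backed Neuron deployments serve multiple LoRA adapters: adapters are registered once at LLM construction time through override_neuron_config["lora_modules"], and each request then selects an adapter by name via LoRARequest. A minimal usage sketch distilled from the test below (the adapter name and path are placeholders):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Adapters are registered up front; at request time only the adapter name
# matters, since the ID and path were already resolved by NxDI at init.
llm = LLM(model="meta-llama/Llama-2-7b-hf",
          device="neuron",
          enable_lora=True,
          max_loras=1,
          max_lora_rank=256,
          override_neuron_config={
              "lora_modules": [{
                  "name": "my_adapter",        # placeholder name
                  "path": "/path/to/adapter",  # placeholder path
              }]
          })
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(top_k=1),
                       lora_request=[LoRARequest("my_adapter", 0, " ")])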


@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


def test_llama_single_lora():
    sql_lora_files = snapshot_download(
        repo_id="yard1/llama-2-7b-sql-lora-test")
    llm = LLM(model="meta-llama/Llama-2-7b-hf",
              tensor_parallel_size=2,
              max_num_seqs=4,
              max_model_len=512,
              use_v2_block_manager=True,
              override_neuron_config={
                  "sequence_parallel_enabled": False,
                  "skip_warmup": True,
                  "lora_modules": [{
                      "name": "lora_id_1",
                      "path": sql_lora_files
                  }]
              },
              enable_lora=True,
              max_loras=1,
              max_lora_rank=256,
              device="neuron")
    """For multi-lora requests using NxDI as the backend, only the lora_name
    needs to be specified. The lora_id and lora_path are supplied at the LLM
    class/server initialization, after which the paths are handled by NxDI."""
    lora_req_1 = LoRARequest("lora_id_1", 0, " ")
    prompts = [
        "The president of the United States is",
        "The capital of France is",
    ]
    outputs = llm.generate(prompts,
                           SamplingParams(top_k=1),
                           lora_request=[lora_req_1, lora_req_1])
    expected_outputs = [
        " the head of state and head of government of the United States. "
        "The president direct",
        " a city of contrasts. The city is home to the Eiffel Tower"
    ]
    for expected_output, output in zip(expected_outputs, outputs):
        generated_text = output.outputs[0].text
        assert expected_output == generated_text


def test_llama_multiple_lora():
    sql_lora_files = snapshot_download(
        repo_id="yard1/llama-2-7b-sql-lora-test")
    llm = LLM(model="meta-llama/Llama-2-7b-hf",
              tensor_parallel_size=2,
              max_num_seqs=4,
              max_model_len=512,
              use_v2_block_manager=True,
              override_neuron_config={
                  "sequence_parallel_enabled": False,
                  "skip_warmup": True,
                  "lora_modules": [{
                      "name": "lora_id_1",
                      "path": sql_lora_files
                  }, {
                      "name": "lora_id_2",
                      "path": sql_lora_files
                  }]
              },
              enable_lora=True,
              max_loras=2,
              max_lora_rank=256,
              device="neuron")
    """For multi-lora requests using NxDI as the backend, only the lora_name
    needs to be specified. The lora_id and lora_path are supplied at the LLM
    class/server initialization, after which the paths are handled by NxDI."""
    lora_req_1 = LoRARequest("lora_id_1", 0, " ")
    lora_req_2 = LoRARequest("lora_id_2", 1, " ")
    prompts = [
        "The president of the United States is",
        "The capital of France is",
    ]
    outputs = llm.generate(prompts,
                           SamplingParams(top_k=1),
                           lora_request=[lora_req_1, lora_req_2])
    expected_outputs = [
        " the head of state and head of government of the United States. "
        "The president direct",
        " a city of contrasts. The city is home to the Eiffel Tower"
    ]
    for expected_output, output in zip(expected_outputs, outputs):
        generated_text = output.outputs[0].text
        assert expected_output == generated_text


@ -17,6 +17,8 @@ from neuronx_distributed_inference.models.config import (
     FusedSpecNeuronConfig, OnDeviceSamplingConfig)
 from neuronx_distributed_inference.models.mllama.utils import (
     create_vision_mask)
+from neuronx_distributed_inference.modules.lora_serving import (
+    LoraServingConfig)
 from neuronx_distributed_inference.utils.hf_adapter import (
     load_pretrained_config)
 from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
@ -80,25 +82,26 @@ class NeuronCausalLM(nn.Module):
         # Lazy initialized
         self.model: nn.Module

-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        input_block_ids: torch.Tensor,
-        sampling_params: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                input_block_ids: torch.Tensor,
+                sampling_params: torch.Tensor,
+                prev_hidden: Optional[torch.Tensor] = None,
+                adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
         # sort block ids sequentially for perf/neuron support reasons
         sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
         input_ids = torch.index_select(input_ids, 0, sorted_indices)
         positions = torch.index_select(positions, 0, sorted_indices)
         sampling_params = torch.index_select(sampling_params, 0,
                                              sorted_indices)

         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
                             seq_ids=sorted_input_block_ids,
-                            sampling_params=sampling_params)
+                            sampling_params=sampling_params,
+                            prev_hidden=prev_hidden,
+                            adapter_ids=adapter_ids)
         # on-device sampling
         if self.config.neuron_config.on_device_sampling_config:
             output = output.hidden_states
@ -522,7 +525,8 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
 def _get_default_neuron_config(model_config: ModelConfig,
                                parallel_config: ParallelConfig,
-                               scheduler_config: SchedulerConfig):
+                               scheduler_config: SchedulerConfig,
+                               lora_serving_config: LoraServingConfig):
     """Generate a neuron config based on vllm config args."""
     on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True,
                                                        deterministic=False)
@ -541,7 +545,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
         padding_side="right",
         on_device_sampling_config=on_device_sampling_config,
         sequence_parallel_enabled=True,
-    )
+        lora_serving_config=lora_serving_config)
     return neuron_config
@ -581,7 +585,8 @@ def _get_neuron_config_after_override(default_neuron_config,
 def get_neuron_model(model_config: ModelConfig,
                      parallel_config: ParallelConfig,
-                     scheduler_config: SchedulerConfig) -> nn.Module:
+                     scheduler_config: SchedulerConfig,
+                     lora_serving_config: LoraServingConfig) -> nn.Module:
     """Initializes a neuron-optimized model for inference."""
     model_arch = _get_model_architecture(model_config.hf_config)
     if model_arch == "MllamaForConditionalGeneration":
@ -589,7 +594,7 @@ def get_neuron_model(model_config: ModelConfig,
     else:
         model = NeuronCausalLM(model_config.hf_config)

     default_neuron_config_args = _get_default_neuron_config(
-        model_config, parallel_config, scheduler_config)
+        model_config, parallel_config, scheduler_config, lora_serving_config)
     neuron_config = _get_neuron_config_after_override(
         default_neuron_config_args, model_config.override_neuron_config)

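The plumbing above only threads a LoraServingConfig from vLLM's LoRA settings into NxDI's NeuronConfig. A rough sketch of the object the model runner ends up building (adapter names and paths are placeholders; see _get_nxdi_lora_config further below):

from neuronx_distributed_inference.modules.lora_serving import LoraServingConfig

# Built from vLLM's max_loras / max_lora_rank settings plus the lora_modules
# entries supplied in override_neuron_config.
lora_serving_config = LoraServingConfig(
    max_loras=2,
    max_lora_rank=256,
    target_modules=None,  # optional override, popped from override_neuron_config
    lora_ckpt_paths={
        "lora_id_1": "/path/to/adapter_1",  # placeholder paths
        "lora_id_2": "/path/to/adapter_2",
    },
)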

@ -49,9 +49,6 @@ class NeuronPlatform(Platform):
         if parallel_config.world_size > 1:
             parallel_config.distributed_executor_backend = "uni"

-        assert (vllm_config.lora_config
-                is None), "LoRA is not supported for Neuron backend."
-
         if vllm_config.cache_config and vllm_config.model_config:
             # neuron needs block_size = max_model_len
             vllm_config.cache_config.block_size = \


@ -2,13 +2,15 @@
 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

 import torch
 from torch import nn

 from vllm.config import DeviceConfig, VllmConfig
 from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.neuron import get_neuron_model
@ -36,6 +38,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
     input_block_ids: Optional[torch.Tensor] = None
     sampling_metadata: SamplingMetadata = None
     multi_modal_kwargs: BatchedTensorInputs = None
+    adapter_ids: Optional[str] = None

     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@ -80,6 +83,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
"The model will run without sliding window.")
self.device_config = (self.device_config if self.device_config
is not None else DeviceConfig())
self.lora_config = vllm_config.lora_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
@ -378,6 +382,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             positions=model_input.input_positions,
             input_block_ids=model_input.input_block_ids,
             sampling_params=sampling_params,
+            adapter_ids=model_input.adapter_ids,
             **MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs or {},
                 dtype=self.model_config.dtype,
@ -416,3 +421,28 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
     @property
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
+
+    def remove_all_loras(self):
+        raise NotImplementedError(
+            "LoRAs are not supported for Transformers NeuronX framework")
+
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
+                         lora_mapping: LoRAMapping) -> None:
+        raise NotImplementedError(
+            "LoRAs are not supported for Transformers NeuronX framework")
+
+    def add_lora(self, lora_request: LoRARequest):
+        raise NotImplementedError(
+            "LoRAs are not supported for Transformers NeuronX framework")
+
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError(
+            "LoRAs are not supported for Transformers NeuronX framework")
+
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError(
+            "LoRAs are not supported for Transformers NeuronX framework")
+
+    def list_loras(self) -> Set[int]:
+        raise NotImplementedError(
+            "LoRAs are not supported for Transformers NeuronX framework")


@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """A Neuron worker class."""
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set, Tuple

 import torch.distributed
@ -9,19 +9,19 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.platforms.neuron import NeuronFramework
 from vllm.sequence import ExecuteModelRequest
 from vllm.worker.neuron_model_runner import NeuronModelRunner
-from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
-                                     LoRANotSupportedWorkerBase, WorkerBase,
+from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)

 logger = init_logger(__name__)


-class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
+class NeuronWorker(LocalOrDistributedWorkerBase):
     """A worker class that executes the model on a group of neuron cores.
     """
@ -38,6 +38,7 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.is_driver_worker = is_driver_worker
+        self.lora_config = vllm_config.lora_config

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@ -59,6 +60,9 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"[transformers-neuronx, neuronx-distributed-inference]")
def get_tnx_model_runner(self, vllm_config):
assert (self.lora_config
is None), ("LoRA is not supported for TransformersNeuronX "
"framework.")
from vllm.worker.multi_step_neuron_model_runner import (
MultiStepNeuronModelRunner)
if self.speculative_config is not None:
@ -72,6 +76,8 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         from vllm.worker.neuronx_distributed_model_runner import (
             NeuronxDistributedModelRunner)
         if self.speculative_config is not None:
+            assert (self.lora_config
+                    is None), "LoRA is not supported for Speculative Decoding"
             return MultiStepNeuronxDistributedModelRunner(
                 vllm_config=vllm_config)
         else:
@ -156,3 +162,31 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             1,
             1,
         )
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        if current_platform.use_transformers_neuronx():
+            raise NotImplementedError(
+                f"{type(self)} does not support LoRA with Neuron Framework "
+                f"Transformers NeuronX")
+        return self.model_runner.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        if current_platform.use_transformers_neuronx():
+            raise NotImplementedError(
+                f"{type(self)} does not support LoRA with Neuron Framework "
+                f"Transformers NeuronX")
+        return self.model_runner.remove_lora(lora_id)
+
+    def pin_lora(self, lora_id: int) -> bool:
+        if current_platform.use_transformers_neuronx():
+            raise NotImplementedError(
+                f"{type(self)} does not support LoRA with Neuron Framework "
+                f"Transformers NeuronX")
+        return self.model_runner.pin_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        if current_platform.use_transformers_neuronx():
+            raise NotImplementedError(
+                f"{type(self)} does not support LoRA with Neuron Framework "
+                f"Transformers NeuronX")
+        return self.model_runner.list_loras()


@ -1,17 +1,24 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import List, Optional
+from typing import List, Optional, Set

 import torch
 from neuronx_distributed_inference.modules.generation.sampling import (
     prepare_sampling_params)
+from neuronx_distributed_inference.modules.lora_serving import (
+    LoraCheckpoint, LoraServingConfig)

 from vllm.config import VllmConfig
+from vllm.entrypoints.openai.serving_models import LoRAModulePath
 from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.neuronx_distributed import (
     _get_model_architecture, get_neuron_model)
-from vllm.sequence import IntermediateTensors
+from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
                                              NeuronModelRunner)
@ -25,11 +32,44 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
         vllm_config: VllmConfig,
     ):
         super().__init__(vllm_config)
+        self.lora_checkpoint = None
+        self.model = None
+        self.lora_serving_config = None
+
+    @staticmethod
+    def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]):
+        if not lora_modules:
+            return None
+        return {_.get("name"): _.get("path") for _ in lora_modules}
+
+    def _get_nxdi_lora_config(self):
+        override_neuron_config = self.model_config.override_neuron_config
+        lora_modules = override_neuron_config.pop("lora_modules", None)
+        target_modules = override_neuron_config.pop("target_modules", None)
+        lora_ckpt_paths = self._get_lora_paths_strings(lora_modules)
+
+        if self.lora_config.max_loras < len(lora_ckpt_paths):
+            raise ValueError(
+                f"Number of LoRAs ({len(lora_ckpt_paths)}) exceeds maximum "
+                f"allowed ({self.lora_config.max_loras})")
+
+        return LoraServingConfig(
+            max_loras=self.lora_config.max_loras,
+            max_lora_rank=self.lora_config.max_lora_rank,
+            target_modules=target_modules,
+            lora_ckpt_paths=lora_ckpt_paths,
+        )

     def load_model(self) -> None:
-        self.model = get_neuron_model(self.model_config,
-                                      parallel_config=self.parallel_config,
-                                      scheduler_config=self.scheduler_config)
+        # Update LoRA config
+        if self.lora_config is not None:
+            self.lora_serving_config = self._get_nxdi_lora_config()
+            self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config)
+
+        self.model = get_neuron_model(
+            self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            lora_serving_config=self.lora_serving_config)

     def get_nxd_sampling_params(self, sampling_metadata):
         if self.model.config.neuron_config.on_device_sampling_config:
@ -134,3 +174,116 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
         )

         return [output]
+
+    def _get_lora_adapter_ids(self, seq_group_metadata_list):
+        # set LoRA adapter IDs for multi-lora serving
+        batch_size = len(seq_group_metadata_list)
+        if self.lora_checkpoint is not None:
+            # "0" indicates NxDI to use the base model for inference
+            adapter_ids = ["0"] * batch_size
+            for idx, seq_group_metadata in enumerate(seq_group_metadata_list):
+                if seq_group_metadata.lora_request is not None:
+                    adapter_ids[
+                        idx] = seq_group_metadata.lora_request.lora_name
+
+            # convert adapter_ids from strings to integers
+            adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices(
+                adapter_ids, batch_size)
+        else:
+            adapter_ids = torch.zeros((batch_size), dtype=torch.int32)
+
+        return adapter_ids
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> ModelInputForNeuron:
+        multi_modal_kwargs = None
+        # NOTE: We assume that all sequences in the group are all prompts or
+        # all decodes.
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        # Prepare input tensors.
+        if is_prompt:
+            (input_tokens, input_positions, input_block_ids, seq_lens,
+             multi_modal_kwargs
+             ) = self._prepare_prompt(seq_group_metadata_list)
+        else:
+            (input_tokens, input_positions,
+             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
+            seq_lens = None
+
+        if not self._on_device_sampling_disabled:
+            for seq_group_metadata in seq_group_metadata_list:
+                sampling_params = seq_group_metadata.sampling_params
+                top_k, top_p, temperature = (
+                    self._convert_to_neuron_sampling_params(sampling_params))
+                sampling_params.top_k = top_k
+                sampling_params.top_p = top_p
+                sampling_params.temperature = temperature
+
+        lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)
+
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list,
+            seq_lens,
+            # query_lens is not needed if chunked prefill is not
+            # supported. Since neuron worker doesn't support chunked prefill
+            # just use seq_lens instead.
+            seq_lens,
+            self.device,
+            self.pin_memory,
+            generators=self.get_generators(finished_requests_ids))
+
+        if current_platform.use_transformers_neuronx(
+        ) and not self._on_device_sampling_disabled:
+            # Once the request IDs are changed in current iteration, we will
+            # update the on-device sampling parameters.
+            current_batch_request_ids = [
+                seq_group_meta_data.request_id
+                for seq_group_meta_data in seq_group_metadata_list
+            ]
+            if current_batch_request_ids != self._previous_batch_request_ids:
+                self._update_neuron_sampling_params(seq_group_metadata_list)
+                self._previous_batch_request_ids = current_batch_request_ids
+
+        return ModelInputForNeuron(input_tokens=input_tokens,
+                                   input_positions=input_positions,
+                                   input_block_ids=input_block_ids,
+                                   sampling_metadata=sampling_metadata,
+                                   multi_modal_kwargs=multi_modal_kwargs,
+                                   adapter_ids=lora_adapter_ids)
+
+    def remove_all_loras(self):
+        raise NotImplementedError(
+            "Managing LoRAs is only supported through the "
+            "lora_modules parameter in override_neuron_config")
+
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
+                         lora_mapping: LoRAMapping) -> None:
+        raise NotImplementedError(
+            "Managing LoRAs is only supported through the "
+            "lora_modules parameter in override_neuron_config")
+
+    def add_lora(self, lora_request: LoRARequest):
+        logger.warning(
+            "Adding LoRAs is only supported through the lora_modules "
+            "parameter in override_neuron_config. If you supplied that "
+            "parameter, you can ignore this warning. Ignoring lora "
+            "request: %s", lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError(
+            "Managing LoRAs is only supported through the "
+            "lora_modules parameter in override_neuron_config")
+
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError(
+            "Managing LoRAs is only supported through the "
+            "lora_modules parameter in override_neuron_config")
+
+    def list_loras(self) -> Set[int]:
+        raise NotImplementedError(
+            "Managing LoRAs is only supported through the "
+            "lora_modules parameter in override_neuron_config")