mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	[Neuron] Adding support for adding/ overriding neuron configuration a… (#8062)
Co-authored-by: Harsha Bikki <harbikh@amazon.com>
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							77d9e514a2
						
					
				
				
					commit
					008cf886c9
				
			
							
								
								
									
										50
									
								
								examples/offline_inference_neuron_int8_quantization.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								examples/offline_inference_neuron_int8_quantization.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,50 @@ | ||||
import os

from vllm import LLM, SamplingParams

# Neuron-specific settings are passed through environment variables.
os.environ.update({
    # Context-length buckets: an XLA HLO graph is created for each bucket.
    'NEURON_CONTEXT_LENGTH_BUCKETS': "128,512,1024,2048",
    # Token-generation buckets: an XLA HLO graph is created for each bucket.
    'NEURON_TOKEN_GEN_BUCKETS': "128,512,1024,2048",
    # Quantize the neuron model weights to int8 ("s8"), which is also the
    # default dtype for quantization.
    'NEURON_QUANT_DTYPE': "s8",
})

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling parameters shared by every prompt.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Engine arguments, collected up front so each one can carry its own note.
engine_kwargs = {
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "max_num_seqs": 8,
    # The max_model_len and block_size arguments are required to be the same
    # as the max sequence length when targeting a neuron device.
    # Currently, this is a known limitation in continuous batching support
    # in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    "max_model_len": 2048,
    "block_size": 2048,
    # The device can be automatically detected when the AWS Neuron SDK is
    # installed; it may be left unspecified for automated detection or
    # assigned explicitly, as done here.
    "device": "neuron",
    "quantization": "neuron_quant",
    # Extra neuron settings layered on top of the defaults.
    "override_neuron_config": {
        "cast_logits_dtype": "bfloat16",
    },
    "tensor_parallel_size": 2,
}

# Create the LLM.
llm = LLM(**engine_kwargs)

# Generate texts from the prompts. The result is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print each prompt next to its first completion.
for result in outputs:
    prompt_text = result.prompt
    completion = result.outputs[0].text
    print(f"Prompt: {prompt_text!r}, Generated text: {completion!r}")
| @ -1,8 +1,8 @@ | ||||
| import enum | ||||
| import json | ||||
| from dataclasses import dataclass, field, fields | ||||
| from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple, | ||||
|                     Type, Union) | ||||
| from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, | ||||
|                     Optional, Tuple, Type, Union) | ||||
|  | ||||
| import torch | ||||
| from transformers import PretrainedConfig | ||||
| @ -115,6 +115,10 @@ class ModelConfig: | ||||
|             the model name will be the same as `model`. | ||||
|         limit_mm_per_prompt: Maximum number of data instances per modality  | ||||
|             per prompt. Only applicable for multimodal models. | ||||
|         override_neuron_config: Initialize non-default neuron config or | ||||
|             override default neuron config values that are specific to | ||||
|             Neuron devices. This argument is used to configure neuron | ||||
|             settings that cannot be derived from the vLLM arguments. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
| @ -143,7 +147,7 @@ class ModelConfig: | ||||
|             served_model_name: Optional[Union[str, List[str]]] = None, | ||||
|             limit_mm_per_prompt: Optional[Mapping[str, int]] = None, | ||||
|             use_async_output_proc: bool = True, | ||||
|     ) -> None: | ||||
|             override_neuron_config: Optional[Dict[str, Any]] = None) -> None: | ||||
|         self.model = model | ||||
|         self.tokenizer = tokenizer | ||||
|         self.tokenizer_mode = tokenizer_mode | ||||
| @ -227,6 +231,9 @@ class ModelConfig: | ||||
|             limit_mm_per_prompt) | ||||
|         if not self.skip_tokenizer_init: | ||||
|             self._verify_tokenizer_mode() | ||||
|  | ||||
|         self.override_neuron_config = override_neuron_config if is_neuron( | ||||
|         ) else None | ||||
|         self._verify_embedding_mode() | ||||
|         self._verify_quantization() | ||||
|         self._verify_cuda_graph() | ||||
| @ -275,6 +282,7 @@ class ModelConfig: | ||||
|             "experts_int8" | ||||
|         ] | ||||
|         tpu_supported_quantization = ["tpu_int8"] | ||||
|         neuron_supported_quantization = ["neuron_quant"] | ||||
|         if self.quantization is not None: | ||||
|             self.quantization = self.quantization.lower() | ||||
|  | ||||
| @ -329,6 +337,11 @@ class ModelConfig: | ||||
|                     "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" | ||||
|                     " is not set, enabling VLLM_USE_TRITON_AWQ.") | ||||
|                 envs.VLLM_USE_TRITON_AWQ = True | ||||
|             if is_neuron( | ||||
|             ) and self.quantization not in neuron_supported_quantization: | ||||
|                 raise ValueError( | ||||
|                     f"{self.quantization} quantization is currently not " | ||||
|                     f"supported in Neuron Backend.") | ||||
|  | ||||
|     def _verify_cuda_graph(self) -> None: | ||||
|         if self.max_seq_len_to_capture is None: | ||||
|  | ||||
| @ -2,8 +2,8 @@ import argparse | ||||
| import dataclasses | ||||
| import json | ||||
| from dataclasses import dataclass | ||||
| from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, | ||||
|                     Union) | ||||
| from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, | ||||
|                     Type, Union) | ||||
|  | ||||
| import torch | ||||
|  | ||||
| @ -149,6 +149,7 @@ class EngineArgs: | ||||
|     otlp_traces_endpoint: Optional[str] = None | ||||
|     collect_detailed_traces: Optional[str] = None | ||||
|     disable_async_output_proc: bool = False | ||||
|     override_neuron_config: Optional[Dict[str, Any]] = None | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         if self.tokenizer is None: | ||||
| @ -742,6 +743,16 @@ class EngineArgs: | ||||
|             default=EngineArgs.disable_async_output_proc, | ||||
|             help="Disable async output processing. This may result in " | ||||
|             "lower performance.") | ||||
|         parser.add_argument( | ||||
|             '--override-neuron-config', | ||||
|             type=lambda configs: { | ||||
|                 str(key): value | ||||
|                 for key, value in | ||||
|                 (config.split(':') for config in configs.split(',')) | ||||
|             }, | ||||
|             default=None, | ||||
|             help="override or set neuron device configuration.") | ||||
|  | ||||
|         return parser | ||||
|  | ||||
|     @classmethod | ||||
| @ -802,7 +813,7 @@ class EngineArgs: | ||||
|             served_model_name=self.served_model_name, | ||||
|             limit_mm_per_prompt=self.limit_mm_per_prompt, | ||||
|             use_async_output_proc=not self.disable_async_output_proc, | ||||
|         ) | ||||
|             override_neuron_config=self.override_neuron_config) | ||||
|         cache_config = CacheConfig( | ||||
|             block_size=self.block_size if self.device != "neuron" else | ||||
|             self.max_model_len,  # neuron needs block_size = max_model_len | ||||
|  | ||||
| @ -214,6 +214,7 @@ class LLMEngine: | ||||
|             "Initializing an LLM engine (v%s) with config: " | ||||
|             "model=%r, speculative_config=%r, tokenizer=%r, " | ||||
|             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " | ||||
|             "override_neuron_config=%s, " | ||||
|             "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " | ||||
|             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " | ||||
|             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " | ||||
| @ -232,6 +233,7 @@ class LLMEngine: | ||||
|             model_config.skip_tokenizer_init, | ||||
|             model_config.tokenizer_mode, | ||||
|             model_config.revision, | ||||
|             model_config.override_neuron_config, | ||||
|             model_config.rope_scaling, | ||||
|             model_config.rope_theta, | ||||
|             model_config.tokenizer_revision, | ||||
|  | ||||
| @ -22,6 +22,8 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( | ||||
| from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( | ||||
|     GPTQMarlin24Config) | ||||
| from vllm.model_executor.layers.quantization.marlin import MarlinConfig | ||||
| from vllm.model_executor.layers.quantization.neuron_quant import ( | ||||
|     NeuronQuantConfig) | ||||
| from vllm.model_executor.layers.quantization.qqq import QQQConfig | ||||
| from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig | ||||
| from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig | ||||
| @ -46,6 +48,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { | ||||
|     "bitsandbytes": BitsAndBytesConfig, | ||||
|     "qqq": QQQConfig, | ||||
|     "experts_int8": ExpertsInt8Config, | ||||
|     "neuron_quant": NeuronQuantConfig, | ||||
| } | ||||
|  | ||||
|  | ||||
|  | ||||
							
								
								
									
										67
									
								
								vllm/model_executor/layers/quantization/neuron_quant.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								vllm/model_executor/layers/quantization/neuron_quant.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,67 @@ | ||||
| import os | ||||
| from importlib.util import find_spec | ||||
| from typing import Any, Dict, List, Optional | ||||
|  | ||||
| from torch.nn import Module | ||||
|  | ||||
| from vllm.model_executor.layers.quantization.base_config import ( | ||||
|     QuantizationConfig) | ||||
|  | ||||
| SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] | ||||
|  | ||||
|  | ||||
class NeuronQuantConfig(QuantizationConfig):
    """Quantization config class for the Neuron backend.

    The quantized weight dtype is read from the ``NEURON_QUANT_DTYPE``
    environment variable (default ``"s8"``) and must be one of
    ``SUPPORTED_QUANT_DTYPE_LIST``.
    """

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        """Store the quantization settings.

        Args:
            dequant_dtype: Dtype name forwarded as ``dequant_dtype`` to the
                transformers_neuronx ``QuantizationConfig``.
            quantize_method: Method name forwarded as ``quantize_method`` to
                the transformers_neuronx ``QuantizationConfig``.

        Raises:
            ValueError: If ``NEURON_QUANT_DTYPE`` is set to a value that is
                not in ``SUPPORTED_QUANT_DTYPE_LIST``.
        """
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            # Adjacent f-strings are concatenated verbatim, so each fragment
            # carries its own separating space.
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not "
                f"valid, the quantization datatype should match one of the "
                f"below types: {SUPPORTED_QUANT_DTYPE_LIST}")
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        """Return the name under which this quantization method is known."""
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        """Return the supported quantization dtype names."""
        # Hand out a copy so callers cannot mutate the module-level list.
        return list(SUPPORTED_QUANT_DTYPE_LIST)

    @classmethod
    def get_min_capability(cls) -> int:
        """Always raise; this query does not apply to the Neuron backend."""
        raise NotImplementedError(
            "This function should not be called with Neuron Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return an empty list: no config files are looked up."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        """Build a config from a dict.

        Args:
            config: Mapping with the keys ``"quantize_method"`` and
                ``"dequant_dtype"``.

        Returns:
            A new ``NeuronQuantConfig`` carrying those two values.
        """
        # NOTE(review): get_from_keys is inherited; presumably it raises when
        # a key is missing -- confirm against the base class.
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype,
                   quantize_method=quantize_method)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        """Return the transformers_neuronx quantization config object.

        ``layer`` and ``prefix`` are accepted for interface compatibility and
        are not used.

        Raises:
            NotImplementedError: If ``transformers_neuronx`` is not
                installed.
        """
        if find_spec("transformers_neuronx") is None:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")
        return self.get_quantization_config()

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list: no scaled activation names are reported."""
        return []

    def get_quantization_config(self):
        """Create the transformers_neuronx ``QuantizationConfig``.

        The import is done lazily so this module can be imported without
        transformers_neuronx installed.
        """
        from transformers_neuronx.config import QuantizationConfig
        return QuantizationConfig(quant_dtype=self.quant_dtype,
                                  dequant_dtype=self.dequant_dtype,
                                  quantize_method=self.quantize_method)
| @ -10,6 +10,7 @@ from transformers import PretrainedConfig | ||||
|  | ||||
| from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig | ||||
| from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||||
| from vllm.model_executor.layers.quantization import get_quantization_config | ||||
| from vllm.model_executor.layers.sampler import Sampler, SamplerOutput | ||||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||||
|  | ||||
| @ -81,8 +82,7 @@ class NeuronCasualLM(nn.Module): | ||||
|         neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) | ||||
|  | ||||
|         split_model_dir = f"{model_name_or_path}-split" | ||||
|         if os.path.isdir(os.path.join(model_name_or_path, | ||||
|                                       "pytorch_model.bin")): | ||||
|         if _is_pretrained_neuron_checkpoint(model_name_or_path): | ||||
|             split_model_dir = model_name_or_path | ||||
|         elif not os.path.exists(f"{model_name_or_path}-split"): | ||||
|             hf_model_cls = getattr(transformers, hf_model_cls_name) | ||||
| @ -97,6 +97,23 @@ class NeuronCasualLM(nn.Module): | ||||
|         self.model.to_neuron() | ||||
|  | ||||
|  | ||||
def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    """Tell whether a path holds an already-split neuron checkpoint.

    Two layouts are recognised:

    * old format: the path contains a *directory* named
      ``pytorch_model.bin``;
    * new format: the path contains both ``config.json`` and
      ``generation_config.json`` as files, plus at least one entry whose
      name ends in ``.safetensors``.

    Args:
        model_name_or_path: Candidate checkpoint location.

    Returns:
        True if either layout is detected, False otherwise.
    """
    # Old format: "pytorch_model.bin" is a directory inside the checkpoint.
    legacy_dir = os.path.join(model_name_or_path, "pytorch_model.bin")
    if os.path.isdir(legacy_dir):
        return True

    # New format: every metadata file must be present...
    required_files = ("config.json", "generation_config.json")
    has_metadata = all(
        os.path.isfile(os.path.join(model_name_or_path, name))
        for name in required_files)
    if not has_metadata:
        return False

    # ...together with at least one safetensors entry.
    return any(
        entry.endswith(".safetensors")
        for entry in os.listdir(model_name_or_path))
|  | ||||
|  | ||||
| def _get_model_architecture(config: PretrainedConfig) -> str: | ||||
|     architectures = getattr(config, "architectures", []) | ||||
|     for arch in architectures: | ||||
| @ -119,19 +136,51 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]: | ||||
|     return buckets_list | ||||
|  | ||||
|  | ||||
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    """Assemble the default neuron configuration arguments.

    Args:
        model_config: Supplies ``dtype`` (used to pick the dequantization
            dtype) and ``quantization`` (the quantization method name, or a
            falsy value when quantization is disabled).
        parallel_config: Not used by this function.
        scheduler_config: Supplies ``max_num_seqs``, used as the shared-cache
            batch size for continuous batching.

    Returns:
        A dict of keyword arguments with the keys ``collectives_layout``,
        ``attention_layout``, ``fuse_qkv``, ``quant``,
        ``continuous_batching`` and ``weight_tiling``.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    batching = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)

    # Settings handed to the quantization config's from_config().
    quant_settings = {
        "dequant_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "quantize_method": "vector_dynamic",
    }

    # Only build the neuron quantization object when a method is requested.
    if model_config.quantization:
        quant_config_cls = get_quantization_config(model_config.quantization)
        neuron_quant = quant_config_cls.from_config(
            quant_settings).get_quant_method(None, "")
    else:
        neuron_quant = None

    # TODO: Add Paged attention config to the default neuron arguments.
    return {
        "collectives_layout": LAYOUT_BSH,
        "attention_layout": LAYOUT_BSH,
        "fuse_qkv": True,
        "quant": neuron_quant,
        "continuous_batching": batching,
        # Weight tiling is switched on exactly when quantization is enabled.
        "weight_tiling": bool(model_config.quantization),
    }
|  | ||||
|  | ||||
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Merge user overrides into the defaults and build a ``NeuronConfig``.

    Args:
        default_neuron_config: Dict of default keyword arguments. It is
            updated in place with the overrides.
        overridden_neuron_config: Dict of overriding keyword arguments, or a
            falsy value (e.g. ``None``) when there is nothing to override.

    Returns:
        A ``NeuronConfig`` constructed from the merged keyword arguments.
    """
    from transformers_neuronx.config import NeuronConfig

    # Overrides win over defaults; a falsy override leaves defaults as-is.
    if overridden_neuron_config:
        default_neuron_config.update(overridden_neuron_config)
    return NeuronConfig(**default_neuron_config)
|  | ||||
|  | ||||
| def get_neuron_model(model_config: ModelConfig, | ||||
|                      parallel_config: ParallelConfig, | ||||
|                      scheduler_config: SchedulerConfig) -> nn.Module: | ||||
|     from transformers_neuronx.config import (ContinuousBatchingConfig, | ||||
|                                              NeuronConfig) | ||||
|  | ||||
|     # Create a model instance. | ||||
|     model = NeuronCasualLM(model_config.hf_config) | ||||
|  | ||||
|     continuous_batching_config = ContinuousBatchingConfig( | ||||
|         batch_size_for_shared_caches=scheduler_config.max_num_seqs) | ||||
|     neuron_config = NeuronConfig( | ||||
|         continuous_batching=continuous_batching_config) | ||||
|     default_neuron_config_args = _get_default_neuron_config( | ||||
|         model_config, parallel_config, scheduler_config) | ||||
|  | ||||
|     neuron_config = _get_neuron_config_after_override( | ||||
|         default_neuron_config_args, model_config.override_neuron_config) | ||||
|  | ||||
|     context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", | ||||
|                                             [scheduler_config.max_model_len]) | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| from dataclasses import dataclass | ||||
| from importlib.util import find_spec | ||||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | ||||
|  | ||||
| import torch | ||||
| @ -76,9 +77,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): | ||||
|         self.model: nn.Module  # initialize after load_model. | ||||
|  | ||||
|     def load_model(self) -> None: | ||||
|         self.model = get_neuron_model(self.model_config, | ||||
|         if find_spec("transformers_neuronx") is not None: | ||||
|             self.model = get_neuron_model( | ||||
|                 self.model_config, | ||||
|                 parallel_config=self.parallel_config, | ||||
|                 scheduler_config=self.scheduler_config) | ||||
|         else: | ||||
|             raise NotImplementedError( | ||||
|                 "Supports only Transformer-NeuronX based models.") | ||||
|  | ||||
|     def _prepare_prompt( | ||||
|         self, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user