Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

[RFC] [Mistral] FP8 format (#10130)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Committed via GitHub

commit d366ccc4e3
parent 870c37481e
@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     mistral_mapping = {
         "layers": "model.layers",
         "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
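The three new entries cover the FP8 scale tensors found in Mistral's quantized checkpoints. As a hedged illustration (the checkpoint tensor names below are assumptions about the Mistral consolidated naming, not copied from a real checkpoint), they enable renames such as:

# Example renames enabled by the new mapping entries (illustrative names,
# not part of the diff).
fp8_scale_renames = {
    "layers.0.attention.wq.qscale_weight":
        "model.layers.0.self_attn.q_proj.weight_scale",
    "layers.0.attention.wq.qscale_act":
        "model.layers.0.self_attn.q_proj.input_scale",
    "layers.0.attention.kv_fake_quantizer.qscale_act":
        "model.layers.0.self_attn.kv_scale",
}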
@@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         modules = name.split(".")
 
         # rotary embeds should be sliced
-        if "wk" in modules:
+        if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_key_value_heads)
-        elif "wq" in modules:
+        elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_attention_heads)
 
-        for item in modules:
-            if item in mapping and mapping[item] not in name:
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = (f"{item}.{next_item}"
+                             if next_item is not None else None)
+
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
                 name = name.replace(item, mapping[item])
 
         return name, loaded_weight
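Two things change here: the rotary permutation now applies only when the trailing component is "weight" (so the new FP8 scale tensors are never permuted), and the remapping loop gains a two-token lookahead so that combined keys such as "kv_fake_quantizer.qscale_act" take precedence over their single-token suffix "qscale_act". A minimal standalone sketch of that loop (the helper name is illustrative, not taken from the diff):

# Standalone sketch of the remapping loop above; the mapping is the subset
# shown in the first hunk and remap_name is an illustrative helper name.
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "qscale_weight": "weight_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
    "wq": "q_proj",
}

def remap_name(name: str) -> str:
    modules = name.split(".")
    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None

        # Prefer the combined "item.next_item" key when the mapping has one.
        combined_item = (f"{item}.{next_item}"
                         if next_item is not None else None)

        if combined_item in mapping:
            name = name.replace(combined_item, mapping[combined_item])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name

print(remap_name("layers.0.attention.wq.qscale_weight"))
# -> model.layers.0.self_attn.q_proj.weight_scale
print(remap_name("layers.0.attention.kv_fake_quantizer.qscale_act"))
# -> model.layers.0.self_attn.kv_scale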
@@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext):
         tokenizer_mode=ctx.model_config.tokenizer_mode)
     mm_encoder = tokenizer.instruct.mm_encoder
 
-    max_image_size = mm_encoder.mm_config.max_image_size
-    image_patch_size = mm_encoder.mm_config.image_patch_size
+    image_config = mm_encoder.mm_config if hasattr(
+        mm_encoder, "mm_config") else mm_encoder.image_config
+
+    max_image_size = image_config.max_image_size
+    image_patch_size = image_config.image_patch_size
 
     return ((max_image_size // image_patch_size)**2)
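Depending on the mistral_common version, the encoder exposes its image settings as either mm_config or image_config, so the lookup now handles both. For the returned value itself, a small arithmetic illustration (the 1024 / 16 values are assumed for illustration, not read from a model config):

# Arithmetic illustration of the returned image-token budget.
max_image_size = 1024
image_patch_size = 16
max_image_tokens = (max_image_size // image_patch_size) ** 2
print(max_image_tokens)  # 4096, i.e. a 64 x 64 grid of patches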
@@ -4,7 +4,7 @@ import enum
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Literal, Optional, Type, Union
 
 import huggingface_hub
 from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
@@ -554,7 +554,8 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
             for key, value in elem.items():
                 key = config_mapping.get(key, key)
                 config_dict[key] = recurse_elems(value)
-            return PretrainedConfig(**config_dict)
+
+            return config_dict
         else:
             return elem
 
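recurse_elems now returns plain dicts instead of wrapping every nested mapping in PretrainedConfig; the wrapping is deferred to the end of load_params_config (see the hunk at line 583 below). A minimal standalone sketch of the helper, using an illustrative subset of the key mapping (the full table lives in vLLM's config.py and is not shown in this diff):

from typing import Any

# Sketch of recurse_elems after this change: keys are renamed via
# config_mapping and nested dicts stay plain dicts.
config_mapping = {
    "dim": "hidden_size",
    "n_layers": "num_hidden_layers",
    "n_heads": "num_attention_heads",
    "norm_eps": "rms_norm_eps",
}

def recurse_elems(elem: Any):
    if isinstance(elem, dict):
        out = {}
        for key, value in elem.items():
            key = config_mapping.get(key, key)
            out[key] = recurse_elems(value)
        return out
    return elem

print(recurse_elems({"dim": 4096, "rope": {"theta": 1e6}}))
# {'hidden_size': 4096, 'rope': {'theta': 1000000.0}}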
@@ -566,12 +567,30 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
     config_dict["max_position_embeddings"] = config_dict.get(
         "max_position_embeddings", 128_000)
 
+    if config_dict.get("quantization") is not None:
+        quantization = config_dict.get("quantization", {})
+        if quantization.get("qformat_weight") == "fp8_e4m3":
+            # This maps to the FP8 static per-tensor quantization scheme
+            quantization_config = {
+                "quant_method": "fp8",
+                "activation_scheme": "static"
+            }
+        else:
+            raise ValueError(
+                f"Found unknown quantization='{quantization}' in config")
+
+        config_dict["quantization_config"] = quantization_config
+
+    config_type: Literal["text",
+                         "multimodal"] = "multimodal" if config_dict.get(
+                             "vision_encoder") is not None else "text"
+
     if config_dict.get("moe") is not None:
         config_dict["architectures"] = ["MixtralForCausalLM"]
     else:
         config_dict["architectures"] = ["MistralForCausalLM"]
 
-    if config_dict.get("vision_encoder") is not None:
+    if config_type == "multimodal":
         multimodal_config = config_dict.pop("vision_encoder")
 
         config_dict = {
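When params.json declares FP8 E4M3 weight quantization, the loader now injects an HF-style quantization_config that selects vLLM's fp8 method with static (per-tensor) activation scales; any other quantization value is rejected. A hedged example of the transformation (only the qformat_weight key appears in the diff; a real quantization section may carry additional fields):

# Given a params.json-style dict declaring fp8_e4m3 weights, derive the
# HF-style quantization_config the loader injects.
config_dict = {"quantization": {"qformat_weight": "fp8_e4m3"}}

quantization = config_dict.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
    config_dict["quantization_config"] = {
        "quant_method": "fp8",
        "activation_scheme": "static",
    }

print(config_dict["quantization_config"])
# {'quant_method': 'fp8', 'activation_scheme': 'static'}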
@@ -583,8 +602,16 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
 
     config_dict.update(kwargs)
 
-    config = recurse_elems(config_dict)
-    return config
+    config_dict = recurse_elems(config_dict)
+
+    # transform to HF config format
+    if config_type == "multimodal":
+        config_dict["text_config"] = PretrainedConfig(
+            **config_dict["text_config"])
+        config_dict["vision_config"] = PretrainedConfig(
+            **config_dict["vision_config"])
+
+    return PretrainedConfig(**config_dict)
 
 
 def get_hf_image_processor_config(
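Because recurse_elems no longer builds PretrainedConfig objects, the final step now wraps the text and vision sub-configs explicitly for multimodal models and returns a single top-level PretrainedConfig. A sketch of the resulting HF-style nesting (the field values are placeholders, not taken from a real Pixtral config):

from transformers import PretrainedConfig

# Sketch of the nesting produced for a multimodal params.json.
config_dict = {
    "architectures": ["MistralForCausalLM"],
    "text_config": {"hidden_size": 5120, "num_hidden_layers": 40},
    "vision_config": {"image_size": 1024, "patch_size": 16},
}

config_dict["text_config"] = PretrainedConfig(**config_dict["text_config"])
config_dict["vision_config"] = PretrainedConfig(**config_dict["vision_config"])
config = PretrainedConfig(**config_dict)

print(config.text_config.hidden_size)   # 5120
print(config.vision_config.patch_size)  # 16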
@@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
 
 
 def find_tokenizer_file(files: List[str]):
-    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
+    file_pattern = re.compile(
+        r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
 
     matched_files = [file for file in files if file_pattern.match(file)]
     if len(matched_files) > 1:
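The widened pattern additionally accepts the multimodal tokenizer file (tokenizer.mm.model.v*). A quick check with illustrative filenames:

import re

# Verify which candidate filenames the new pattern accepts; the concrete
# names below are illustrative examples, not taken from a repository.
file_pattern = re.compile(
    r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")

candidates = [
    "tokenizer.model.v3",     # matched before and after this change
    "tekken.json",            # matched before and after this change
    "tokenizer.mm.model.v3",  # only matched with the new pattern
    "params.json",            # never matched
]
print([f for f in candidates if file_pattern.match(f)])
# ['tokenizer.model.v3', 'tekken.json', 'tokenizer.mm.model.v3']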