[RFC] [Mistral] FP8 format (#10130)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Patrick von Platen
2025-02-08 22:12:53 +01:00
committed by GitHub
parent 870c37481e
commit d366ccc4e3
4 changed files with 55 additions and 12 deletions

View File

@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"qscale_act": "input_scale",
"qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
modules = name.split(".")
# rotary embeds should be sliced
if "wk" in modules:
if "wk" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif "wq" in modules:
elif "wq" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
for item in modules:
if item in mapping and mapping[item] not in name:
num_modules = len(modules)
for i in range(num_modules):
item = modules[i]
next_item = modules[i + 1] if i < num_modules - 1 else None
combined_item = (f"{item}.{next_item}"
if next_item is not None else None)
if combined_item in mapping:
name = name.replace(combined_item, mapping[combined_item])
elif item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])
return name, loaded_weight

View File

@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder
max_image_size = mm_encoder.mm_config.max_image_size
image_patch_size = mm_encoder.mm_config.image_patch_size
image_config = mm_encoder.mm_config if hasattr(
mm_encoder, "mm_config") else mm_encoder.image_config
max_image_size = image_config.max_image_size
image_patch_size = image_config.image_patch_size
return ((max_image_size // image_patch_size)**2)

View File

@ -4,7 +4,7 @@ import enum
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Dict, Literal, Optional, Type, Union
import huggingface_hub
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
@ -554,7 +554,8 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
for key, value in elem.items():
key = config_mapping.get(key, key)
config_dict[key] = recurse_elems(value)
return PretrainedConfig(**config_dict)
return config_dict
else:
return elem
@ -566,12 +567,30 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
config_dict["max_position_embeddings"] = config_dict.get(
"max_position_embeddings", 128_000)
if config_dict.get("quantization") is not None:
quantization = config_dict.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
# This maps to the FP8 static per-tensor quantization scheme
quantization_config = {
"quant_method": "fp8",
"activation_scheme": "static"
}
else:
raise ValueError(
f"Found unknown quantization='{quantization}' in config")
config_dict["quantization_config"] = quantization_config
config_type: Literal["text",
"multimodal"] = "multimodal" if config_dict.get(
"vision_encoder") is not None else "text"
if config_dict.get("moe") is not None:
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
if config_dict.get("vision_encoder") is not None:
if config_type == "multimodal":
multimodal_config = config_dict.pop("vision_encoder")
config_dict = {
@ -583,8 +602,16 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
config_dict.update(kwargs)
config = recurse_elems(config_dict)
return config
config_dict = recurse_elems(config_dict)
# transform to HF config format
if config_type == "multimodal":
config_dict["text_config"] = PretrainedConfig(
**config_dict["text_config"])
config_dict["vision_config"] = PretrainedConfig(
**config_dict["vision_config"])
return PretrainedConfig(**config_dict)
def get_hf_image_processor_config(

View File

@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
def find_tokenizer_file(files: List[str]):
file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
file_pattern = re.compile(
r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
matched_files = [file for file in files if file_pattern.match(file)]
if len(matched_files) > 1: