mirror of https://github.com/huggingface/transformers.git
Add Gemma2 GGUF support (#34002)
* initial setup for ggml.py
* initial setup of GGUFGemma2Converter class
* Add gemma2 model to gguf.md doc
* Partial work on GGUF_TENSOR_MAPPING
* initial setup of GGUF_TENSOR_MAPPING for Gemma2
* refactor: rename GemmaConvert class to GemmaConverter for naming consistency
* feat: complete gemma2 tensor mapping implementation
* feat: add initial implementation of GGUFGemmaConverter
* feat: complete GGUFGemmaConverter implementation
* feat: add test code for gemma2
* refactor: minor code cleanup
* refactor: minor code cleanup
* fix: resolve suggestions
* Update tests/quantization/ggml/test_ggml.py

Co-authored-by: Isotr0py <2037008807@qq.com>

---------

Co-authored-by: Isotr0py <2037008807@qq.com>
docs/source/en/gguf.md
@@ -88,6 +88,7 @@ For now the supported model architectures are the architectures that have been v
 - T5
 - Mamba
 - Nemotron
+- Gemma2

 ## Example usage
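For reference, loading one of the Gemma2 GGUF checkpoints exercised by the new tests looks roughly like the snippet below. It follows the pattern of the existing gguf.md example; the repo and file names are the test fixtures added later in this PR, and other Gemma2 GGUF files should work the same way.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bartowski/gemma-2-2b-it-GGUF"   # GGUF repo used in the tests below
gguf_file = "gemma-2-2b-it-Q8_0.gguf"       # one of the quantized files in that repo

# The GGUF tensors are dequantized and renamed to transformers parameter names on load.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```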
src/transformers/convert_slow_tokenizer.py
@@ -1271,7 +1271,7 @@ class XGLMConverter(SpmConverter):
         )


-class GemmaConvert(SpmConverter):
+class GemmaConverter(SpmConverter):
     handle_byte_fallback = True
     SpmExtractor = GemmaSentencePieceExtractor
     # start and end of turn tokens must be marked as special
@@ -1601,7 +1601,7 @@ SLOW_TO_FAST_CONVERTERS = {
     "XGLMTokenizer": XGLMConverter,
     "LlamaTokenizer": LlamaConverter,
     "CodeLlamaTokenizer": LlamaConverter,
-    "GemmaTokenizer": GemmaConvert,
+    "GemmaTokenizer": GemmaConverter,
     "Phi3Tokenizer": LlamaConverter,
 }
src/transformers/integrations/ggml.py
@@ -25,7 +25,7 @@ from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, process
 from tokenizers.models import BPE, Unigram

 from .. import AddedToken
-from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter, T5Converter
+from ..convert_slow_tokenizer import GemmaConverter, GPT2Converter, LlamaConverter, Qwen2Converter, T5Converter
 from ..utils import logging
 from ..utils.logging import tqdm
@@ -262,6 +262,22 @@ GGUF_TENSOR_MAPPING = {
         "output.weight": "lm_head.weight",
         "output_norm": "model.norm",
     },
+    "gemma2": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.layers",
+        "ffn_up": "mlp.up_proj",
+        "ffn_down": "mlp.down_proj",
+        "ffn_gate": "mlp.gate_proj",
+        "ffn_norm": "pre_feedforward_layernorm",
+        "post_attention_norm": "post_attention_layernorm",
+        "post_ffw_norm": "post_feedforward_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_q": "self_attn.q_proj",
+        "attn_v": "self_attn.v_proj",
+        "attn_k": "self_attn.k_proj",
+        "attn_output": "self_attn.o_proj",
+        "output_norm": "model.norm",
+    },
 }
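As with the other architectures in GGUF_TENSOR_MAPPING, each entry maps a fragment of a GGUF tensor name to the corresponding transformers module path; layer indices and the trailing "weight"/"bias" parts pass through unchanged. The snippet below is not the helper transformers actually uses, just an illustrative sketch of the substitution rule (GEMMA2_MAP and rename are made-up names for this example):

```python
# Illustrative sketch only: shows how the table above turns a GGUF tensor name
# into a transformers state-dict key.
GEMMA2_MAP = {
    "token_embd": "model.embed_tokens",
    "blk": "model.layers",
    "ffn_up": "mlp.up_proj",
    "attn_q": "self_attn.q_proj",
}

def rename(gguf_name: str, mapping: dict) -> str:
    # Substitute every known fragment; layer indices and "weight" pass through untouched.
    return ".".join(mapping.get(part, part) for part in gguf_name.split("."))

print(rename("blk.0.ffn_up.weight", GEMMA2_MAP))   # model.layers.0.mlp.up_proj.weight
print(rename("token_embd.weight", GEMMA2_MAP))     # model.embed_tokens.weight
```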
@@ -423,6 +439,18 @@ GGUF_CONFIG_MAPPING = {
         "attention.layer_norm_rms_epsilon": "norm_eps",
         "vocab_size": "vocab_size",
     },
+    "gemma2": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+    },
 }

 GGUF_TOKENIZER_MAPPING = {
@@ -807,6 +835,71 @@ class GGUFT5Converter(T5Converter):
         return tokenizer


+class GGUFGemmaConverter(GemmaConverter):
+    def __init__(self, tokenizer_dict):
+        # set dummy data to avoid unnecessary merges calculation
+        tokenizer_dict["merges"] = ["dummy text"]
+
+        self.proto = GGUFTokenizerSkeleton(tokenizer_dict)
+        self.original_tokenizer = self.proto
+        self.additional_kwargs = {}
+
+    def vocab(self, proto):
+        original_vocab = list(zip(proto.tokens, proto.scores))
+        updated_vocab = []
+
+        for token, score in original_vocab:
+            if token == "<0x09>":
+                updated_vocab.append(("\t", score))
+            elif " " in token and len(token.strip()) == 0:
+                underscores = "▁" * len(token)
+                updated_vocab.append((underscores, score))
+            else:
+                updated_vocab.append((token, score))
+
+        return updated_vocab
+
+    def normalizer(self, proto):
+        return normalizers.Replace(" ", "▁")
+
+    def decoder(self, replacement, add_prefix_space):
+        sequence = [
+            decoders.Replace("▁", " "),
+            decoders.ByteFallback(),
+            decoders.Fuse(),
+        ]
+
+        if add_prefix_space:
+            sequence += [decoders.Strip(content=" ", left=1)]
+        return decoders.Sequence(sequence)
+
+    def converted(self) -> Tokenizer:
+        vocab_scores = self.vocab(self.proto)
+        tokenizer = Tokenizer(
+            Unigram(
+                vocab_scores,
+                unk_id=self.proto.unk_token_id,
+                byte_fallback=self.handle_byte_fallback,
+            )
+        )
+
+        normalizer = self.normalizer(self.proto)
+        if normalizer is not None:
+            tokenizer.normalizer = normalizer
+
+        replacement = "▁"
+        add_prefix_space = True
+        if hasattr(self.original_tokenizer, "add_prefix_space"):
+            add_prefix_space = self.original_tokenizer.add_prefix_space
+
+        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
+        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
+        if pre_tokenizer is not None:
+            tokenizer.pre_tokenizer = pre_tokenizer
+
+        return tokenizer
+
+
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
@@ -820,6 +913,7 @@ GGUF_TO_FAST_CONVERTERS = {
     "t5": GGUFT5Converter,
     "mamba": GGUFGPTConverter,
     "nemotron": GGUFGPTConverter,
+    "gemma2": GGUFGemmaConverter,
 }
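The main quirk GGUFGemmaConverter handles is that Gemma GGUF vocabularies store whitespace literally: the tab character sits behind the byte token "<0x09>" and indentation appears as runs of plain spaces, whereas the fast tokenizer expects SentencePiece's "▁" marker. A standalone sketch of that fix-up (illustrative only, not the vocab() method above):

```python
# Illustrative sketch of the vocab fix-up performed in GGUFGemmaConverter.vocab().
def remap_token(token: str) -> str:
    if token == "<0x09>":                         # tab stored as a byte token
        return "\t"
    if " " in token and len(token.strip()) == 0:  # a run of literal spaces
        return "\u2581" * len(token)              # same-length run of '▁' markers
    return token

assert remap_token("<0x09>") == "\t"
assert remap_token("  ") == "\u2581" * 2
assert remap_token("hello") == "hello"
```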
src/transformers/modeling_gguf_pytorch_utils.py
@@ -238,6 +238,18 @@ class MambaTensorProcessor(TensorProcessor):
         return GGUFTensor(weights, name, {})


+class Gemma2TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    # ref: https://github.com/ggerganov/llama.cpp/blob/d79d8f39b4da6deca4aea8bf130c6034c482b320/convert_hf_to_gguf.py#L3191
+    # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+    def process(self, weights, name, **kwargs):
+        if "norm.weight" in name:
+            weights = weights - 1
+        return GGUFTensor(weights, name, {})
+
+
 TENSOR_PROCESSORS = {
     "llama": LlamaTensorProcessor,
     "qwen2moe": Qwen2MoeTensorProcessor,
@@ -246,6 +258,7 @@ TENSOR_PROCESSORS = {
     "t5encoder": T5TensorProcessor,
     "gpt2": GPT2TensorProcessor,
     "mamba": MambaTensorProcessor,
+    "gemma2": Gemma2TensorProcessor,
 }
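The only tensor-level adjustment Gemma2 needs is on the RMSNorm weights. Per the two references in the processor, transformers' GemmaRMSNorm scales activations by (1 + weight), and the referenced llama.cpp converter appears to bake that +1 into the tensor it writes to GGUF, so loading back into transformers subtracts it again. A small sketch of the round trip under that assumption:

```python
import torch

# Assumes the GGUF file stores the fused (1 + weight) value, as the llama.cpp
# converter referenced above does for Gemma norm weights.
hf_weight = torch.zeros(4)     # freshly initialized GemmaRMSNorm weight in transformers
gguf_weight = hf_weight + 1    # value written into the GGUF file
restored = gguf_weight - 1     # what Gemma2TensorProcessor.process() recovers on load
assert torch.equal(restored, hf_weight)
```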
tests/quantization/ggml/test_ggml.py
@@ -64,6 +64,8 @@ class GgufIntegrationTests(unittest.TestCase):
     mamba_model_id = "jpodivin/mamba-2.8b-hf-GGUF"
     nemotron_original_model_id = "nvidia/Nemotron-Mini-4B-Instruct"
     nemotron_model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"
+    original_gemma2_model_id = "google/gemma-2-2b-it"
+    gemma2_model_id = "bartowski/gemma-2-2b-it-GGUF"

     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -111,6 +113,9 @@ class GgufIntegrationTests(unittest.TestCase):
     fp16_mamba_model_id = "ggml-model-f16.gguf"
     q6_k_nemotron_model_id = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"
     fp16_nemotron_model_id = "Nemotron-Mini-4B-Instruct-f16.gguf"
+    q3_k_gemma2_model_id = "gemma-2-2b-it-Q3_K_L.gguf"
+    q8_0_gemma2_model_id = "gemma-2-2b-it-Q8_0.gguf"
+    fp32_gemma2_model_id = "gemma-2-2b-it-f32.gguf"

     example_text = "Hello"
@@ -833,6 +838,70 @@ class GgufIntegrationTests(unittest.TestCase):
         EXPECTED_TEXT = "'Hello. hotmail.com.'"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

+    def test_gemma2_q3_k(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            self.gemma2_model_id,
+            gguf_file=self.q3_k_gemma2_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.gemma2_model_id, gguf_file=self.q3_k_gemma2_model_id)
+        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
+        out = model.generate(text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello! 👋\n\nI'm trying to create a"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_gemma2_q8_0(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            self.gemma2_model_id,
+            gguf_file=self.q8_0_gemma2_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.gemma2_model_id, gguf_file=self.q8_0_gemma2_model_id)
+        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
+        out = model.generate(text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello! 👋\n\nI'm a large language model"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_gemma2_fp32(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            self.gemma2_model_id,
+            gguf_file=self.fp32_gemma2_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.gemma2_model_id, gguf_file=self.fp32_gemma2_model_id)
+        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
+        out = model.generate(text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello! 👋\n\nI'm a large language model"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_gemma2_weights_conversion_fp32(self):
+        original_model = AutoModelForCausalLM.from_pretrained(
+            self.original_gemma2_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_model = AutoModelForCausalLM.from_pretrained(
+            self.gemma2_model_id,
+            gguf_file=self.fp32_gemma2_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_state_dict = converted_model.state_dict()
+        original_state_dict = original_model.state_dict()
+
+        for layer_name, original_params in original_state_dict.items():
+            if layer_name in converted_state_dict:
+                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
+                torch.testing.assert_close(original_params, converted_state_dict[layer_name])
+            else:
+                raise ValueError(f"Layer {layer_name} is not presented in GGUF model")
+
     def test_tokenization_xnli(self):
         import tqdm
         from datasets import load_dataset