Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Model] Support GGUF models newly added in transformers 4.46.0 (#9685)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -3,27 +3,20 @@ from huggingface_hub import hf_hub_download
 from vllm import LLM, SamplingParams
 
 
-def run_gguf_inference(model_path):
-    PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"  # noqa: E501
-    system_message = "You are a friendly chatbot who always responds in the style of a pirate."  # noqa: E501
+def run_gguf_inference(model_path, tokenizer):
     # Sample prompts.
     prompts = [
         "How many helicopters can a human eat in one sitting?",
         "What's the future of AI?",
     ]
-    prompts = [
-        PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
-        for prompt in prompts
-    ]
+    prompts = [[{"role": "user", "content": prompt}] for prompt in prompts]
     # Create a sampling params object.
     sampling_params = SamplingParams(temperature=0, max_tokens=128)
 
     # Create an LLM.
-    llm = LLM(model=model_path,
-              tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-              gpu_memory_utilization=0.95)
+    llm = LLM(model=model_path, tokenizer=tokenizer)
 
-    outputs = llm.generate(prompts, sampling_params)
+    outputs = llm.chat(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
         prompt = output.prompt
@@ -32,7 +25,8 @@ def run_gguf_inference(model_path):
 
 
 if __name__ == "__main__":
-    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
-    filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
+    repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
+    filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
+    tokenizer = "microsoft/Phi-3-medium-4k-instruct"
     model = hf_hub_download(repo_id, filename=filename)
-    run_gguf_inference(model)
+    run_gguf_inference(model, tokenizer)
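For context, a minimal end-to-end sketch of how the updated example is meant to be run. It only condenses the diff above: the single-file GGUF checkpoint comes from the Hub, the tokenizer (and its chat template) still comes from the original Phi-3 repository, and llm.chat applies that template instead of the old hand-written prompt string. Assumes a vLLM build with GGUF support and network access to the Hub.

from huggingface_hub import hf_hub_download
from vllm import LLM, SamplingParams

# Values taken from the example above.
repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
tokenizer = "microsoft/Phi-3-medium-4k-instruct"

gguf_path = hf_hub_download(repo_id, filename=filename)  # local path to the .gguf file
llm = LLM(model=gguf_path, tokenizer=tokenizer)

conversations = [[{"role": "user", "content": "What's the future of AI?"}]]
outputs = llm.chat(conversations, SamplingParams(temperature=0, max_tokens=128))
print(outputs[0].outputs[0].text)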
@@ -4,6 +4,7 @@ Note: To pass the test, quantization higher than Q4 should be used
 """
 
 import os
+from typing import List, NamedTuple, Type
 
 import pytest
 from huggingface_hub import hf_hub_download
@@ -11,6 +12,7 @@ from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
 
+from ....conftest import VllmRunner
 from ...utils import check_logprobs_close
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -18,31 +20,74 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 MAX_MODEL_LEN = 1024
 
 
+class GGUFTestConfig(NamedTuple):
+    original_model: str
+    gguf_repo: str
+    gguf_filename: str
+
+    @property
+    def gguf_model(self):
+        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)
+
+
+LLAMA_CONFIG = GGUFTestConfig(
+    original_model="meta-llama/Llama-3.2-1B-Instruct",
+    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
+    gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
+)
+
+QWEN2_CONFIG = GGUFTestConfig(
+    original_model="Qwen/Qwen2.5-1.5B-Instruct",
+    gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
+    gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
+)
+
+PHI3_CONFIG = GGUFTestConfig(
+    original_model="microsoft/Phi-3.5-mini-instruct",
+    gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
+    gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
+)
+
+GPT2_CONFIG = GGUFTestConfig(
+    original_model="openai-community/gpt2-large",
+    gguf_repo="QuantFactory/gpt2-large-GGUF",
+    gguf_filename="gpt2-large.Q4_K_M.gguf",
+)
+
+STABLELM_CONFIG = GGUFTestConfig(
+    original_model="stabilityai/stablelm-3b-4e1t",
+    gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
+    gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
+)
+
+STARCODER_CONFIG = GGUFTestConfig(
+    original_model="bigcode/starcoder2-3b",
+    gguf_repo="QuantFactory/starcoder2-3b-GGUF",
+    gguf_filename="starcoder2-3b.Q6_K.gguf",
+)
+
+MODELS = [
+    LLAMA_CONFIG,
+    QWEN2_CONFIG,
+    PHI3_CONFIG,
+    GPT2_CONFIG,
+    STABLELM_CONFIG,
+    # STARCODER_CONFIG, # broken
+]
+
+
 @pytest.mark.skipif(not is_quant_method_supported("gguf"),
                     reason="gguf is not supported on this GPU type.")
-@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
-     "qwen2-1_5b-instruct-q4_k_m.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
-     "Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
-])
+@pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("tp_size", [1, 2])
 def test_models(
-    num_gpus_available,
-    vllm_runner,
-    example_prompts,
-    original_model,
-    gguf_id,
-    gguf_path,
+    num_gpus_available: int,
+    vllm_runner: Type[VllmRunner],
+    example_prompts: List[str],
+    model: GGUFTestConfig,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
@@ -51,28 +96,26 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
-    gguf_model = hf_hub_download(gguf_id, filename=gguf_path)
-
-    tokenizer = AutoTokenizer.from_pretrained(original_model)
-    messages = [[{
-        'role': 'user',
-        'content': prompt
-    }] for prompt in example_prompts]
-    example_prompts = tokenizer.apply_chat_template(messages,
-                                                    tokenize=False,
-                                                    add_generation_prompt=True)
+    tokenizer = AutoTokenizer.from_pretrained(model.original_model)
+    if tokenizer.chat_template is not None:
+        messages = [[{
+            'role': 'user',
+            'content': prompt
+        }] for prompt in example_prompts]
+        example_prompts = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True)
 
     # Run unquantized model.
-    with vllm_runner(model_name=original_model,
+    with vllm_runner(model_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as original_model:
-
         original_outputs = original_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)
 
     # Run gguf model.
-    with vllm_runner(model_name=gguf_model,
+    with vllm_runner(model_name=model.gguf_model,
+                     tokenizer_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as gguf_model:
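Extending the suite to another architecture now only takes one more GGUFTestConfig entry; a hedged sketch, where the repo id and filename are placeholders rather than part of this change:

# Hypothetical entry; the repo id and filename are placeholders.
NEW_CONFIG = GGUFTestConfig(
    original_model="some-org/some-model-instruct",
    gguf_repo="some-org/some-model-instruct-GGUF",
    gguf_filename="some-model-instruct.Q4_K_M.gguf",
)
# Listing NEW_CONFIG in MODELS is enough for test_models to pick it up; the
# gguf_model property only downloads the file (via hf_hub_download) when the
# test body actually accesses it.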
@@ -447,8 +447,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
         if is_gguf_weight_type:
-            param.data[loaded_shard_id].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[loaded_shard_id].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.shard_weight_type = {
+                    i: loaded_weight.item()
+                    for i, _ in enumerate(self.output_sizes)
+                }
             return
 
         if is_gguf_weight:
@@ -459,15 +465,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
-            return
+            if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 2:
+                    self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -811,10 +817,16 @@ class QKVParallelLinear(ColumnParallelLinear):
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
-        if is_gguf_weight_type and loaded_shard_id is not None:
+        if is_gguf_weight_type:
             idx_map = {"q": 0, "k": 1, "v": 2}
-            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.shard_weight_type = {
+                    k: loaded_weight.item()
+                    for k in idx_map
+                }
             return
 
         if is_gguf_weight:
@@ -825,15 +837,15 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
-            return
+            if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 3:
+                    self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
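A toy restatement of the branch these linear-layer hunks add, outside of the vLLM classes: when the GGUF weight-type scalar arrives with a shard id, only that shard's entry is recorded; when it arrives without one, presumably because the source checkpoint already stores the projection as a single fused tensor, every sub-shard of the merged layer gets the same type, keyed by index for merged column layers and by "q"/"k"/"v" for QKV layers.

# Illustrative only, not vLLM code.
def record_shard_weight_type(shard_weight_type, weight_type, shard_keys,
                             loaded_shard_id=None):
    # shard_keys is range(len(output_sizes)) for a merged column layer,
    # or ("q", "k", "v") for a QKV layer.
    if loaded_shard_id is not None:
        shard_weight_type[loaded_shard_id] = weight_type
    else:
        for key in shard_keys:
            shard_weight_type[key] = weight_type


types = {}
record_shard_weight_type(types, weight_type=2, shard_keys=("q", "k", "v"))
assert types == {"q": 2, "k": 2, "v": 2}  # 2 stands in for a GGML type id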
@@ -198,7 +198,10 @@ class GPT2Model(nn.Module):
         assert not config.scale_attn_by_inverse_layer_idx
         assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.wte")
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -259,7 +262,9 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
             self.lm_head = self.transformer.wte
         else:
             self.lm_head = ParallelLMHead(self.config.vocab_size,
-                                          self.config.hidden_size)
+                                          self.config.hidden_size,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.lm_head")
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
@@ -304,7 +309,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
+            if name.startswith("lm_head"):
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue
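The GPT-2 loader change swaps the exact "lm_head.weight" match for a prefix check, presumably because a GGUF checkpoint can surface the tied head under more than one parameter name (for instance a quantized weight plus its type tag), and every such name should be skipped since the head is tied to wte. The names below are purely illustrative:

# Illustrative filtering only; "lm_head.qweight" is a hypothetical name.
names = ["lm_head.weight", "lm_head.qweight", "transformer.wte.weight"]
new_skipped = [n for n in names if n.startswith("lm_head")]
old_skipped = [n for n in names if "lm_head.weight" in n]
assert new_skipped == ["lm_head.weight", "lm_head.qweight"]
assert old_skipped == ["lm_head.weight"]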
@@ -156,7 +156,8 @@ class LlamaAttention(nn.Module):
         )
 
         is_neox_style = True
-        if quant_config is not None and quant_config.get_name() == "gguf":
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type == "llama":
             is_neox_style = False
 
         self.rotary_emb = get_rope(
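The rotary-embedding tweak, restated as a small helper for clarity (illustrative, not vLLM code): previously any GGUF-quantized model built through this attention class switched to the non-NeoX interleaved rope layout; now only checkpoints whose config.model_type is "llama" do, presumably so that other model types reusing this class keep the default NeoX style.

# Illustrative restatement of the condition above.
def rope_is_neox_style(quant_config, model_type: str) -> bool:
    is_gguf = quant_config is not None and quant_config.get_name() == "gguf"
    return not (is_gguf and model_type == "llama")


assert rope_is_neox_style(None, "llama") is True  # unquantized llama keeps NeoX
# With a GGUF quant_config, only model_type == "llama" switches to interleaved.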
@@ -22,7 +22,7 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
-from transformers import PretrainedConfig
+from transformers import StableLmConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
@@ -50,8 +50,9 @@ from .utils import (is_pp_missing_parameter,
 class StablelmMLP(nn.Module):
 
     def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 config: StableLmConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -59,10 +60,13 @@ class StablelmMLP(nn.Module):
         self.gate_up_proj = MergedColumnParallelLinear(
             config.hidden_size, [config.intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(config.intermediate_size,
                                            config.hidden_size,
-                                           bias=False)
+                                           bias=False,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -75,7 +79,7 @@ class StablelmMLP(nn.Module):
 class StablelmAttention(nn.Module):
 
     def __init__(self,
-                 config: PretrainedConfig,
+                 config: StableLmConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = "") -> None:
@@ -116,11 +120,13 @@ class StablelmAttention(nn.Module):
                                           self.total_num_heads,
                                           self.total_num_key_value_heads,
                                           self.qkv_bias,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
         self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                         self.hidden_size,
                                         bias=False,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
@@ -154,7 +160,7 @@ class StablelmDecoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: StableLmConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -164,7 +170,7 @@ class StablelmDecoderLayer(nn.Module):
                                            cache_config,
                                            quant_config,
                                            prefix=f"{prefix}.self_attn")
-        self.mlp = StablelmMLP(config, quant_config)
+        self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")
         norm_eps = getattr(config, "norm_eps",
                            getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -210,6 +216,8 @@ class StableLMEpochModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens",
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
@@ -270,7 +278,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
                                            prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
-                                      quant_config=quant_config)
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.lm_head")
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
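The StableLM edits (and the Starcoder2 edits that follow) are mostly plumbing: each quantizable submodule now receives quant_config plus a fully qualified prefix, which quantization backends can use to identify individual layers by name. A hedged illustration of the kind of names this produces, assuming the usual model.layers.<i> layout:

# Illustrative only: the prefixes the plumbing above produces for one layer.
layer_prefix = "model.layers.0"
quantized_submodules = [
    f"{layer_prefix}.self_attn.qkv_proj",
    f"{layer_prefix}.self_attn.o_proj",
    f"{layer_prefix}.mlp.gate_up_proj",
    f"{layer_prefix}.mlp.down_proj",
]
print(quantized_submodules)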
@@ -88,12 +88,14 @@ class Starcoder2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=self.use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=self.use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -129,19 +131,22 @@ class Starcoder2MLP(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.c_fc = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             bias=config.use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_proj = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=config.use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.act = get_act_fn(config.hidden_act)
 
@@ -165,7 +170,9 @@ class Starcoder2DecoderLayer(nn.Module):
                                               cache_config,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.self_attn")
-        self.mlp = Starcoder2MLP(config, quant_config=quant_config)
+        self.mlp = Starcoder2MLP(config,
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.mlp")
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.norm_epsilon)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
@@ -213,8 +220,11 @@ class Starcoder2Model(nn.Module):
         self.vocab_size = config.vocab_size
 
         # TODO: consider padding_idx (currently removed)
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens")
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Starcoder2DecoderLayer(
@@ -279,6 +289,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
                 org_num_embeddings=config.vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config=quant_config,
+                prefix=f"{prefix}.lm_head",
             )
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)