[MODEL] add Exaone model support (#7819)

Author: Yohan Na
Date: 2024-08-30 15:34:20 +09:00
Committed by: GitHub
Parent: 34a0e96d46
Commit: dc13e99348
6 changed files with 820 additions and 5 deletions

View File

@@ -51,6 +51,10 @@ Decoder-only Language Models
- DeciLM
- :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
-
* - :code:`ExaoneForCausalLM`
- EXAONE-3
- :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
- ✅︎
* - :code:`FalconForCausalLM`
- Falcon
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
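With the table entry above in place, EXAONE-3 can be used like any other supported decoder-only model. A minimal offline-inference sketch (the prompt and sampling values are arbitrary; trust_remote_code is only needed if the tokenizer is not bundled with your transformers version):

from vllm import LLM, SamplingParams

llm = LLM(model="LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True)
params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Explain tensor parallelism in one sentence."], params)
print(outputs[0].outputs[0].text)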

View File

@@ -22,6 +22,7 @@ _GENERATION_MODELS = {
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),

View File

@@ -0,0 +1,617 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py
# Copyright 2024 The LG U+ CTO AI Tech Lab.
# Copyright 2021 The LG AI Research EXAONE Lab
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip
from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
class ExaoneGatedMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.c_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.c_proj(x)
return x
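# A rough functional sketch of what ExaoneGatedMLP computes once the merged
# column-parallel projection is unfused (plain single-GPU PyTorch; w_fc0,
# w_fc1 and w_proj are illustrative weight tensors, not vLLM parameters).
# SiluAndMul activates the first chunk (the c_fc_0 half) and multiplies it
# elementwise by the second chunk (the c_fc_1 half).
import torch.nn.functional as F

def _gated_mlp_reference(x, w_fc0, w_fc1, w_proj):
    # out = c_proj(silu(c_fc_0(x)) * c_fc_1(x))
    return (F.silu(x @ w_fc0.t()) * (x @ w_fc1.t())) @ w_proj.t()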
class ExaoneAttention(nn.Module):
def __init__(
self,
config: ExaoneConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is at least TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# Some configs (e.g. Mistral-Nemo's MistralConfig) define an optional
# head_dim; otherwise fall back to hidden_size // total_num_heads.
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
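# Sketch of the per-rank QKV bookkeeping above, using round example numbers
# (hidden=4096, 32 query heads, 8 KV heads, head_dim=128, tp_size=4 are
# illustrative values, not read from an EXAONE checkpoint):
#   per-rank q heads  = 32 / 4 = 8  -> q_size  = 8 * 128 = 1024
#   per-rank kv heads = 8 / 4  = 2  -> kv_size = 2 * 128 = 256
# qkv_proj therefore emits q_size + 2 * kv_size columns on each rank, and
# the split() in forward() recovers q, k and v before rotary embedding and
# attention. When tp_size exceeds the number of KV heads, each rank instead
# keeps one replicated KV head (the max(1, ...) above).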
class ExaoneBlockAttention(nn.Module):
def __init__(
self,
config: ExaoneConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.attention = ExaoneAttention(
config=config,
hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=bias,
cache_config=cache_config,
prefix=prefix,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
return self.attention(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
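# ExaoneBlockAttention is a thin pass-through wrapper around ExaoneAttention.
# Keeping the extra "attention" level appears to mirror the module nesting of
# the original EXAONE checkpoint (weights named like h.<i>.attn.attention.*),
# so parameter prefixes line up during weight loading without renaming.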
class ExaoneDecoderLayer(nn.Module):
def __init__(
self,
config: ExaoneConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.attn = ExaoneBlockAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
self.mlp = ExaoneGatedMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.activation_function,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
else:
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
hidden_states, residual = self.ln_2(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
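# The two-argument RMSNorm calls above fuse the residual add with the
# normalization. A plain-PyTorch sketch of that contract (reference code,
# not the vLLM kernel):
import torch

def _fused_rms_norm_reference(hidden, residual, weight, eps=1e-6):
    # Returns (normalized output, updated residual), matching
    # hidden_states, residual = self.ln_x(hidden_states, residual).
    residual = residual + hidden
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = (residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype)
    return normed * weight, residual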
class ExaoneModel(nn.Module):
def __init__(
self,
config: ExaoneConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.wte = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.wte = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.wte = PPMissingLayer()
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: ExaoneDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.h",
)
if get_pp_group().is_last_rank:
self.ln_f = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
else:
self.ln_f = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states
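# Pipeline-parallel flow in ExaoneModel.forward(), summarized:
#   * the first PP rank embeds input_ids (or consumes inputs_embeds) and
#     starts with residual = None;
#   * middle ranks receive hidden_states / residual via IntermediateTensors
#     and run only their slice [start_layer, end_layer) of self.h;
#   * the last rank applies the final RMSNorm (ln_f) and returns the hidden
#     states consumed by the LM head in ExaoneForCausalLM below.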
class ExaoneForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"c_fc_0",
"c_fc_1",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"out_proj",
"gate_up_proj",
"c_proj",
"wte",
"lm_head",
]
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"c_fc_0": ("gate_up_proj", 0),
"c_fc_1": ("gate_up_proj", 1),
}
def __init__(
self,
config: ExaoneConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.transformer = ExaoneModel(
config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model",
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.transformer.wte.weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
"residual":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".c_fc_0", 0),
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
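# Illustration of how stacked_params_mapping rewrites checkpoint names before
# the params_dict lookup (the "..." prefixes are placeholders, not literal
# parameter names):
#   "...mlp.c_fc_0.weight"       -> "...mlp.gate_up_proj.weight", shard 0
#   "...mlp.c_fc_1.weight"       -> "...mlp.gate_up_proj.weight", shard 1
#   "...attention.q_proj.weight" -> "...attention.qkv_proj.weight", shard "q"
# Each fused parameter's weight_loader then copies the shard into the proper
# slice of the merged tensor.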
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn
if is_hip():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@@ -11,11 +11,11 @@ from transformers.models.auto.modeling_auto import (
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
EAGLEConfig, InternVLChatConfig,
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, RWConfig,
UltravoxConfig)
EAGLEConfig, ExaoneConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MLPSpeculatorConfig,
MPTConfig, NemotronConfig,
RWConfig, UltravoxConfig)
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
@@ -34,6 +34,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"mlp_speculator": MLPSpeculatorConfig,
"medusa": MedusaConfig,
"eagle": EAGLEConfig,
"exaone": ExaoneConfig,
"internvl_chat": InternVLChatConfig,
"nemotron": NemotronConfig,
"ultravox": UltravoxConfig,

View File

@@ -1,6 +1,7 @@
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.exaone import ExaoneConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
@@ -22,6 +23,7 @@ __all__ = [
"JAISConfig",
"MedusaConfig",
"EAGLEConfig",
"ExaoneConfig",
"MLPSpeculatorConfig",
"NemotronConfig",
"UltravoxConfig",

View File

@@ -0,0 +1,190 @@
# coding=utf-8
# Copied from
# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/configuration_exaone.py
# Copyright 2021 The LG AI Research EXAONE Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exaone model configuration"""
from typing import Dict
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {}
class ExaoneConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of
a :class:`~transformers.ExaoneModel`. It is used to instantiate an EXAONE
model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults yields a configuration
similar to that of the EXAONE model.
Configuration objects inherit from :class:`~transformers.PretrainedConfig`
and can be used to control the model outputs. Read the documentation from
:class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 102400):
Vocabulary size of the EXAONE model. Defines the number of
different tokens that can be represented by the :obj:`inputs_ids`
passed when calling :class:`~transformers.ExaoneModel`.
hidden_size (:obj:`int`, `optional`, defaults to 2048):
Dimensionality of the encoder layers and the pooler layer.
num_layers (:obj:`int`, `optional`, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the
Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to
implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi
Head Attention (MHA); if `num_key_value_heads=1`, the model will use
Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint,
each group key and value head should be constructed by mean-pooling
all the original heads within that group. For more details, check out
[this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not
specified, will default to `num_attention_heads`.
rotary_pct (`float`, *optional*, defaults to 0.25):
percentage of hidden dimensions to allocate to rotary embeddings
intermediate_size (:obj:`int`, `optional`, defaults to 8192):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in
the Transformer encoder.
activation_function (:obj:`str` or :obj:`function`, `optional`,
defaults to :obj:`"silu"`):
The non-linear activation function (function or string) in the
encoder and pooler. If string, :obj:`"gelu"`, :obj:`"relu"`,
:obj:`"selu"` and :obj:`"gelu_new"` are supported.
embed_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout probability for all fully connected layers in the
embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling
:class:`~transformers.EXAONEModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-6):
The epsilon used by the layer normalization layers.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values
attentions (not used by all models).
Only relevant if ``config.is_decoder=True``.
gradient_checkpointing (:obj:`bool`, `optional`,
defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense
of slower backward pass.
Example::
>>> from transformers import ExaoneModel, ExaoneConfig
>>> # Initializing an EXAONE configuration
>>> configuration = ExaoneConfig()
>>> # Initializing a model from the configuration
>>> model = ExaoneModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "exaone"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_hidden_layers": "num_layers"}
def __init__(
self,
vocab_size=102400,
max_position_embeddings=2048,
hidden_size=2048,
num_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
intermediate_size=None,
activation_function="silu",
rotary_pct=0.25,
resid_dropout=0.0,
embed_dropout=0.0,
attention_dropout=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
use_cache=True,
bos_token_id=0,
eos_token_id=2,
tie_word_embeddings=True,
**kwargs,
):
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_layers
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
if intermediate_size:
self.intermediate_size = intermediate_size
else:
self.intermediate_size = hidden_size * 4
self.activation_function = activation_function
self.resid_dropout = resid_dropout
self.embed_dropout = embed_dropout
self.attention_dropout = attention_dropout
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
self.rotary_pct = rotary_pct
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.use_logit_cap = kwargs.pop("use_logit_cap", False)
self.ln_no_scale = kwargs.pop("ln_no_scale", False)
self.use_gated = kwargs.pop("use_gated", False)
self.use_emb_norm = kwargs.pop("use_emb_norm", False)
self.use_rotary_pos = kwargs.pop("use_rotary_pos", False)
self.rotary_type = kwargs.pop("rotary_type", None)
self.scaling_factor = kwargs.pop("scaling_factor", 1)
self.use_absolute_pos = kwargs.pop("use_absolute_pos", True)
self.use_extra_logit = kwargs.pop("use_extra_logit", True)
self.rotary_expand_length = kwargs.pop("rotary_expand_length", None)
self.rotary_base = kwargs.pop("rotary_base", 10000.0)
self.use_qkv_fuse = kwargs.pop("use_qkv_fuse", False)
self.rescale_before_lm_head = kwargs.pop("rescale_before_lm_head",
(rotary_pct == 0.25))
if self.use_rotary_pos:
self.use_absolute_pos = False
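# Minimal sketch of the defaulting behaviour implemented above (values are the
# constructor defaults, not taken from a released EXAONE checkpoint; requires
# only transformers to run):
if __name__ == "__main__":
    cfg = ExaoneConfig(hidden_size=2048, num_layers=32)
    assert cfg.intermediate_size == 2048 * 4                   # 4 * hidden_size when unset
    assert cfg.num_key_value_heads == cfg.num_attention_heads  # MHA unless GQA is requested
    assert cfg.num_hidden_layers == cfg.num_layers             # kept in sync via attribute_map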