mirror of
				https://github.com/huggingface/transformers.git
				synced 2025-10-25 20:55:14 +08:00 
			
		
		
		
	Compare commits
	
		
			35 Commits
		
	
	
		
			v4.56.0
			...
			refactorin
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| cdb8c6b19d | |||
| 709429a141 | |||
| 35576acfcd | |||
| f3fe0b340a | |||
| 3dedb93c45 | |||
| daebeeaf04 | |||
| 45f20f5641 | |||
| eaaf34f303 | |||
| d3ab98e5ae | |||
| d5c00047da | |||
| 8fe406fd17 | |||
| 774a4af6de | |||
| a47468a938 | |||
| 580fbe19e2 | |||
| 0782ffd2c4 | |||
| 3a3510ab73 | |||
| ca181ab402 | |||
| 8752d35aa8 | |||
| 2a654ec763 | |||
| 1aabcc1a73 | |||
| 22ff159e50 | |||
| 1632e0f4bd | |||
| e467d2fede | |||
| 7545c5f766 | |||
| 740e5bd35c | |||
| 022727c480 | |||
| d68766aa7c | |||
| 92b6218e18 | |||
| e08d8eb963 | |||
| eb5c2e27e1 | |||
| 1fa297cf1f | |||
| 0bb0af9ac0 | |||
| bd59e58ca8 | |||
| 564813d72e | |||
| f02e2fb8cc | 
							
								
								
									
										104
									
								
								src/transformers/models/cohere/diff_cohere.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								src/transformers/models/cohere/diff_cohere.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,104 @@ | ||||
| from transformers.models.llama.modeling_llama import * | ||||
| import torch.nn as nn | ||||
| from transformers import CohereConfig | ||||
| from transformers.utils import ModelConverter | ||||
|  | ||||
| CohereConverter = ModelConverter(__file__) | ||||
| # now should the cohere converted be added to all model converters?  | ||||
|  | ||||
| class CohereLayerNorm(nn.Module): | ||||
|     def __init__(self, hidden_size=None, eps=1e-5, bias=False): | ||||
|         """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim""" | ||||
|         super().__init__() | ||||
|         self.weight = nn.Parameter(torch.ones(hidden_size)) | ||||
|         self.variance_epsilon = eps | ||||
|  | ||||
|     def forward(self, hidden_states): | ||||
|         input_dtype = hidden_states.dtype | ||||
|         hidden_states = hidden_states.to(torch.float32) | ||||
|         mean = hidden_states.mean(-1, keepdim=True) | ||||
|         variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) | ||||
|         hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon) | ||||
|         hidden_states = self.weight.to(torch.float32) * hidden_states | ||||
|         return hidden_states.to(input_dtype) | ||||
|  | ||||
| class CohereRotaryEmbedding(LlamaRotaryEmbedding): | ||||
|  | ||||
|     def rotate_half(self, x): | ||||
|         # Split and rotate | ||||
|         x1 = x[..., ::2] | ||||
|         x2 = x[..., 1::2] | ||||
|         rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) | ||||
|         return rot_x | ||||
|  | ||||
|     def forward(self, q, k, position_ids=None, unsqueeze_dim=1): | ||||
|         dtype = q.dtype | ||||
|         q,k  = q.float(), k.float() | ||||
|         cos, sin = self.comput_cos_sin(q, position_ids) | ||||
|         cos = cos.unsqueeze(unsqueeze_dim) | ||||
|         sin = sin.unsqueeze(unsqueeze_dim) | ||||
|         q_embed = (q * cos) + (self.rotate_half(q) * sin) | ||||
|         k_embed = (k * cos) + (self.rotate_half(k) * sin) | ||||
|         return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) | ||||
|  | ||||
| CohereMLP = CohereConverter.register("CohereMLP", LlamaMLP)  | ||||
| CohereAttention = CohereConverter.register("CohereAttention", LlamaAttention)  | ||||
| CohereSdpaAttention = CohereConverter.register("CohereSdpaAttention", LlamaAttention)  | ||||
| CohereFlashAttention2 = CohereConverter.register("CohereFlashAttention2", LlamaAttention)  | ||||
|  | ||||
| COHERE_ATTENTION_CLASSES = {"eager": CohereAttention, "flash_attention_2": CohereFlashAttention2, "sdpa": CohereSdpaAttention} | ||||
|  | ||||
| class CohereDecoderLayer(nn.Module): | ||||
|     def __init__(self, config: CohereConfig, layer_idx: int): | ||||
|         super().__init__() | ||||
|         self.hidden_size = config.hidden_size | ||||
|  | ||||
|         self.self_attn = COHERE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) | ||||
|  | ||||
|         self.mlp = CohereMLP(config) | ||||
|         self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_value: Optional[Tuple[torch.Tensor]] = None, | ||||
|         output_attentions: Optional[bool] = False, | ||||
|         use_cache: Optional[bool] = False, | ||||
|         cache_position: Optional[torch.LongTensor] = None, | ||||
|     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: | ||||
|         residual = hidden_states | ||||
|  | ||||
|         hidden_states = self.input_layernorm(hidden_states) | ||||
|  | ||||
|         # Self Attention | ||||
|         hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( | ||||
|             hidden_states=hidden_states, | ||||
|             attention_mask=attention_mask, | ||||
|             position_ids=position_ids, | ||||
|             past_key_value=past_key_value, | ||||
|             output_attentions=output_attentions, | ||||
|             use_cache=use_cache, | ||||
|             cache_position=cache_position, | ||||
|         ) | ||||
|  | ||||
|         # Fully Connected | ||||
|         hidden_states_mlp = self.mlp(hidden_states) | ||||
|  | ||||
|         # Add everything together (main diff with llama ) | ||||
|         hidden_states = residual + hidden_states_attention + hidden_states_mlp | ||||
|  | ||||
|         outputs = (hidden_states,) | ||||
|  | ||||
|         if output_attentions: | ||||
|             outputs += (self_attn_weights,) | ||||
|  | ||||
|         if use_cache: | ||||
|             outputs += (present_key_value,) | ||||
|  | ||||
|         return outputs | ||||
|  | ||||
| CoherePreTrainedModel = CohereConverter.register("CoherePreTrainedModel", LlamaPreTrainedModel) | ||||
| CohereModel = CohereConverter.register("CohereModel", LlamaModel) | ||||
| CohereForCausalLM = CohereConverter.register("CohereForCausalLM", LlamaForCausalLM) | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										291
									
								
								src/transformers/models/gemma/diff_gemma.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										291
									
								
								src/transformers/models/gemma/diff_gemma.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,291 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ PyTorch Gemma model.""" | ||||
| from transformers.models.llama.modeling_llama import * | ||||
| import torch.nn as nn | ||||
| from transformers.utils import ModelConverter | ||||
|  | ||||
|  | ||||
| import math | ||||
| from typing import List, Optional, Tuple, Union | ||||
|  | ||||
| import torch | ||||
| import torch.utils.checkpoint | ||||
| from torch import nn | ||||
|  | ||||
| from ...activations import ACT2FN | ||||
| from ...cache_utils import Cache | ||||
|  | ||||
| from ...pytorch_utils import ALL_LAYERNORM_LAYERS  | ||||
| from ...utils import ( | ||||
|     logging, | ||||
| ) | ||||
| from .configuration_gemma import GemmaConfig | ||||
|  | ||||
| from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb | ||||
|  | ||||
|  | ||||
| logger = logging.get_logger(__name__) | ||||
|  | ||||
| GemmaConverter = ModelConverter(__file__) | ||||
|  | ||||
| class GemmaRMSNorm(nn.Module): | ||||
|     def __init__(self, dim: int, eps: float = 1e-6): | ||||
|         super().__init__() | ||||
|         self.eps = eps | ||||
|         self.weight = nn.Parameter(torch.zeros(dim)) | ||||
|  | ||||
|     def _norm(self, x): | ||||
|         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         output = self._norm(x.float()) | ||||
|         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) | ||||
|         # See https://github.com/huggingface/transformers/pull/29402 | ||||
|         output = output * (1.0 + self.weight.float()) | ||||
|         return output.type_as(x) | ||||
|  | ||||
|  | ||||
| ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) | ||||
|  | ||||
|  | ||||
| class GemmaRotaryEmbedding(nn.Module): | ||||
|     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): | ||||
|         super().__init__() | ||||
|  | ||||
|         self.dim = dim | ||||
|         self.max_position_embeddings = max_position_embeddings | ||||
|         self.base = base | ||||
|         self.register_buffer("inv_freq", None, persistent=False) | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def forward(self, x, position_ids, seq_len=None): | ||||
|         # x: [bs, num_attention_heads, seq_len, head_size] | ||||
|         if self.inv_freq is None: | ||||
|             self.inv_freq = 1.0 / ( | ||||
|                 self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) | ||||
|             ) | ||||
|         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) | ||||
|         position_ids_expanded = position_ids[:, None, :].float() | ||||
|         # Force float32 since bfloat16 loses precision on long contexts | ||||
|         # See https://github.com/huggingface/transformers/pull/29285 | ||||
|         device_type = x.device.type | ||||
|         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" | ||||
|         with torch.autocast(device_type=device_type, enabled=False): | ||||
|             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) | ||||
|             emb = torch.cat((freqs, freqs), dim=-1) | ||||
|             cos = emb.cos() | ||||
|             sin = emb.sin() | ||||
|         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) | ||||
|  | ||||
|  | ||||
|  | ||||
| class GemmaMLP(nn.Module): | ||||
|     def __init__(self, config): | ||||
|         super().__init__() | ||||
|         self.config = config | ||||
|         self.hidden_size = config.hidden_size | ||||
|         self.intermediate_size = config.intermediate_size | ||||
|         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) | ||||
|         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) | ||||
|         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) | ||||
|         if config.hidden_activation is None: | ||||
|             logger.warning_once( | ||||
|                 "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" | ||||
|                 "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" | ||||
|                 "`config.hidden_activation` if you want to override this behaviour.\n" | ||||
|                 "See https://github.com/huggingface/transformers/pull/29402 for more details." | ||||
|             ) | ||||
|             config.hidden_activation = "gelu_pytorch_tanh" | ||||
|         hidden_activation = config.hidden_activation | ||||
|         self.act_fn = ACT2FN[hidden_activation] | ||||
|  | ||||
|     def forward(self, x): | ||||
|         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) | ||||
|  | ||||
|  | ||||
|  | ||||
| class GemmaAttention(nn.Module): | ||||
|     """Multi-headed attention from 'Attention Is All You Need' paper""" | ||||
|  | ||||
|     # Ignore copy | ||||
|     def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): | ||||
|         super().__init__() | ||||
|         self.config = config | ||||
|         self.layer_idx = layer_idx | ||||
|         if layer_idx is None: | ||||
|             logger.warning_once( | ||||
|                 f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " | ||||
|                 "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " | ||||
|                 "when creating this class." | ||||
|             ) | ||||
|  | ||||
|         self.attention_dropout = config.attention_dropout | ||||
|         self.hidden_size = config.hidden_size | ||||
|         self.num_heads = config.num_attention_heads | ||||
|         self.head_dim = config.head_dim | ||||
|         self.num_key_value_heads = config.num_key_value_heads | ||||
|         self.num_key_value_groups = self.num_heads // self.num_key_value_heads | ||||
|         self.max_position_embeddings = config.max_position_embeddings | ||||
|         self.rope_theta = config.rope_theta | ||||
|         self.is_causal = True | ||||
|  | ||||
|         if self.hidden_size % self.num_heads != 0: | ||||
|             raise ValueError( | ||||
|                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" | ||||
|                 f" and `num_heads`: {self.num_heads})." | ||||
|             ) | ||||
|  | ||||
|         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) | ||||
|         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) | ||||
|         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) | ||||
|         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) | ||||
|         self.rotary_emb = GemmaRotaryEmbedding( | ||||
|             self.head_dim, | ||||
|             max_position_embeddings=self.max_position_embeddings, | ||||
|             base=self.rope_theta, | ||||
|         ) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_value: Optional[Cache] = None, | ||||
|         output_attentions: bool = False, | ||||
|         use_cache: bool = False, | ||||
|         cache_position: Optional[torch.LongTensor] = None, | ||||
|         **kwargs, | ||||
|     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||||
|         bsz, q_len, _ = hidden_states.size() | ||||
|  | ||||
|         query_states = self.q_proj(hidden_states) | ||||
|         key_states = self.k_proj(hidden_states) | ||||
|         value_states = self.v_proj(hidden_states) | ||||
|  | ||||
|         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||||
|         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|  | ||||
|         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) | ||||
|         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) | ||||
|  | ||||
|         if past_key_value is not None: | ||||
|             # sin and cos are specific to RoPE models; cache_position needed for the static cache | ||||
|             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} | ||||
|             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) | ||||
|  | ||||
|         key_states = repeat_kv(key_states, self.num_key_value_groups) | ||||
|         value_states = repeat_kv(value_states, self.num_key_value_groups) | ||||
|  | ||||
|         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) | ||||
|  | ||||
|         if attention_mask is not None:  # no matter the length, we just slice it | ||||
|             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] | ||||
|             attn_weights = attn_weights + causal_mask | ||||
|  | ||||
|         # upcast attention to fp32 | ||||
|         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) | ||||
|         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) | ||||
|         attn_output = torch.matmul(attn_weights, value_states) | ||||
|  | ||||
|         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): | ||||
|             raise ValueError( | ||||
|                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" | ||||
|                 f" {attn_output.size()}" | ||||
|             ) | ||||
|  | ||||
|         attn_output = attn_output.transpose(1, 2).contiguous() | ||||
|  | ||||
|         attn_output = attn_output.view(bsz, q_len, -1) | ||||
|         attn_output = self.o_proj(attn_output) | ||||
|  | ||||
|         if not output_attentions: | ||||
|             attn_weights = None | ||||
|  | ||||
|         return attn_output, attn_weights, past_key_value | ||||
|  | ||||
|  | ||||
| GemmaFlashAttention2 = GemmaConverter.register("GemmaFlashAttention2", LlamaFlashAttention2) | ||||
| GemmaSdpaAttention = GemmaConverter.register("GemmaSdpaAttention", LlamaSdpaAttention) | ||||
|  | ||||
| COHERE_ATTENTION_CLASSES = {"eager": GemmaAttention, "flash_attention_2": GemmaFlashAttention2, "sdpa": GemmaSdpaAttention} | ||||
|  | ||||
| GemmaConverter.register("GemmaDecoderLayer", LlamaDecoderLayer)  | ||||
| GemmaConverter.register("GemmaPreTrainedModel", LlamaPreTrainedModel) | ||||
|  | ||||
| class GemmaModel(LlamaModel): | ||||
|      | ||||
|     def forward( | ||||
|         self, | ||||
|         input_ids: torch.LongTensor = None, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, | ||||
|         inputs_embeds: Optional[torch.FloatTensor] = None, | ||||
|         labels: Optional[torch.LongTensor] = None, | ||||
|         use_cache: Optional[bool] = None, | ||||
|         output_attentions: Optional[bool] = None, | ||||
|         output_hidden_states: Optional[bool] = None, | ||||
|         return_dict: Optional[bool] = None, | ||||
|         cache_position: Optional[torch.LongTensor] = None, | ||||
|     ) -> Union[Tuple, CausalLMOutputWithPast]: | ||||
|         output_attentions = output_attentions | self.config.output_attentions | ||||
|         output_hidden_states = output_hidden_states| self.config.output_hidden_states | ||||
|         return_dict = return_dict | self.config.use_return_dict | ||||
|  | ||||
|         outputs = self.model( | ||||
|             input_ids=input_ids, | ||||
|             attention_mask=attention_mask, | ||||
|             position_ids=position_ids, | ||||
|             past_key_values=past_key_values, | ||||
|             inputs_embeds=inputs_embeds, | ||||
|             use_cache=use_cache, | ||||
|             output_attentions=output_attentions, | ||||
|             output_hidden_states=output_hidden_states, | ||||
|             return_dict=return_dict, | ||||
|             cache_position=cache_position, | ||||
|         ) | ||||
|  | ||||
|         hidden_states = outputs[0] | ||||
|         logits = self.lm_head(hidden_states) | ||||
|         logits = logits.float() | ||||
|         loss = None | ||||
|         if labels is not None: | ||||
|             # Shift so that tokens < n predict n | ||||
|             shift_logits = logits[..., :-1, :].contiguous() | ||||
|             shift_labels = labels[..., 1:].contiguous() | ||||
|             # Flatten the tokens | ||||
|             loss_fct = CrossEntropyLoss() | ||||
|             shift_logits = shift_logits.view(-1, self.config.vocab_size) | ||||
|             shift_labels = shift_labels.view(-1) | ||||
|             # Enable model parallelism | ||||
|             shift_labels = shift_labels.to(shift_logits.device) | ||||
|             loss = loss_fct(shift_logits, shift_labels) | ||||
|  | ||||
|         if not return_dict: | ||||
|             output = (logits,) + outputs[1:] | ||||
|             return (loss,) + output if loss is not None else output | ||||
|  | ||||
|         return CausalLMOutputWithPast( | ||||
|             loss=loss, | ||||
|             logits=logits, | ||||
|             past_key_values=outputs.past_key_values, | ||||
|             hidden_states=outputs.hidden_states, | ||||
|             attentions=outputs.attentions, | ||||
|         ) | ||||
|  | ||||
| GemmaConverter.register("GemmaForCausalLM", LlamaForCausalLM) | ||||
| @ -1,3 +1,14 @@ | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # coding=utf-8 | ||||
| # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| @ -14,68 +25,32 @@ | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ PyTorch Gemma model.""" | ||||
| from transformers.models.llama.modeling_llama import * | ||||
| import torch.nn as nn | ||||
| from transformers.utils import ModelConverter | ||||
|  | ||||
|  | ||||
| import math | ||||
| import warnings | ||||
| from typing import List, Optional, Tuple, Union | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torch.utils.checkpoint | ||||
| from torch import nn | ||||
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | ||||
|  | ||||
| from ...activations import ACT2FN | ||||
| from ...cache_utils import Cache, DynamicCache, StaticCache | ||||
| from ...modeling_attn_mask_utils import ( | ||||
|     AttentionMaskConverter, | ||||
|     _prepare_4d_causal_attention_mask, | ||||
| ) | ||||
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast | ||||
| from ...modeling_utils import PreTrainedModel | ||||
| from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 | ||||
| from ...cache_utils import Cache | ||||
|  | ||||
| from ...pytorch_utils import ALL_LAYERNORM_LAYERS  | ||||
| from ...utils import ( | ||||
|     add_start_docstrings, | ||||
|     add_start_docstrings_to_model_forward, | ||||
|     is_flash_attn_2_available, | ||||
|     is_flash_attn_greater_or_equal_2_10, | ||||
|     logging, | ||||
|     replace_return_docstrings, | ||||
| ) | ||||
| from ...utils.import_utils import is_torch_fx_available | ||||
| from .configuration_gemma import GemmaConfig | ||||
|  | ||||
|  | ||||
| if is_flash_attn_2_available(): | ||||
|     from flash_attn import flash_attn_func, flash_attn_varlen_func | ||||
|     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa | ||||
|  | ||||
|  | ||||
| # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. | ||||
| # It means that the function will not be traced through and simply appear as a node in the graph. | ||||
| if is_torch_fx_available(): | ||||
|     if not is_torch_greater_or_equal_than_1_13: | ||||
|         import torch.fx | ||||
|  | ||||
|     _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) | ||||
| from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb | ||||
|  | ||||
|  | ||||
| logger = logging.get_logger(__name__) | ||||
|  | ||||
| _CONFIG_FOR_DOC = "GemmaConfig" | ||||
|  | ||||
|  | ||||
| def _get_unpad_data(attention_mask): | ||||
|     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) | ||||
|     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() | ||||
|     max_seqlen_in_batch = seqlens_in_batch.max().item() | ||||
|     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) | ||||
|     return ( | ||||
|         indices, | ||||
|         cu_seqlens, | ||||
|         max_seqlen_in_batch, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| class GemmaRMSNorm(nn.Module): | ||||
|     def __init__(self, dim: int, eps: float = 1e-6): | ||||
| @ -127,41 +102,6 @@ class GemmaRotaryEmbedding(nn.Module): | ||||
|         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) | ||||
|  | ||||
|  | ||||
| # Copied from transformers.models.llama.modeling_llama.rotate_half | ||||
| def rotate_half(x): | ||||
|     """Rotates half the hidden dims of the input.""" | ||||
|     x1 = x[..., : x.shape[-1] // 2] | ||||
|     x2 = x[..., x.shape[-1] // 2 :] | ||||
|     return torch.cat((-x2, x1), dim=-1) | ||||
|  | ||||
|  | ||||
| # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb | ||||
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): | ||||
|     """Applies Rotary Position Embedding to the query and key tensors. | ||||
|  | ||||
|     Args: | ||||
|         q (`torch.Tensor`): The query tensor. | ||||
|         k (`torch.Tensor`): The key tensor. | ||||
|         cos (`torch.Tensor`): The cosine part of the rotary embedding. | ||||
|         sin (`torch.Tensor`): The sine part of the rotary embedding. | ||||
|         position_ids (`torch.Tensor`, *optional*): | ||||
|             Deprecated and unused. | ||||
|         unsqueeze_dim (`int`, *optional*, defaults to 1): | ||||
|             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and | ||||
|             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note | ||||
|             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and | ||||
|             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes | ||||
|             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have | ||||
|             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. | ||||
|     Returns: | ||||
|         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. | ||||
|     """ | ||||
|     cos = cos.unsqueeze(unsqueeze_dim) | ||||
|     sin = sin.unsqueeze(unsqueeze_dim) | ||||
|     q_embed = (q * cos) + (rotate_half(q) * sin) | ||||
|     k_embed = (k * cos) + (rotate_half(k) * sin) | ||||
|     return q_embed, k_embed | ||||
|  | ||||
|  | ||||
| class GemmaMLP(nn.Module): | ||||
|     def __init__(self, config): | ||||
| @ -187,18 +127,6 @@ class GemmaMLP(nn.Module): | ||||
|         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) | ||||
|  | ||||
|  | ||||
| # Copied from transformers.models.llama.modeling_llama.repeat_kv | ||||
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: | ||||
|     """ | ||||
|     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, | ||||
|     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) | ||||
|     """ | ||||
|     batch, num_key_value_heads, slen, head_dim = hidden_states.shape | ||||
|     if n_rep == 1: | ||||
|         return hidden_states | ||||
|     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) | ||||
|     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) | ||||
|  | ||||
|  | ||||
| class GemmaAttention(nn.Module): | ||||
|     """Multi-headed attention from 'Attention Is All You Need' paper""" | ||||
| @ -301,7 +229,6 @@ class GemmaAttention(nn.Module): | ||||
|         return attn_output, attn_weights, past_key_value | ||||
|  | ||||
|  | ||||
| # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma | ||||
| class GemmaFlashAttention2(GemmaAttention): | ||||
|     """ | ||||
|     Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays | ||||
| @ -317,7 +244,6 @@ class GemmaFlashAttention2(GemmaAttention): | ||||
|         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). | ||||
|         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() | ||||
|  | ||||
|     # Ignore copy | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
| @ -334,6 +260,7 @@ class GemmaFlashAttention2(GemmaAttention): | ||||
|                 "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " | ||||
|                 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" | ||||
|             ) | ||||
|  | ||||
|         output_attentions = False | ||||
|  | ||||
|         bsz, q_len, _ = hidden_states.size() | ||||
| @ -349,8 +276,8 @@ class GemmaFlashAttention2(GemmaAttention): | ||||
|         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|  | ||||
|         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) | ||||
|         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) | ||||
|         cos, sin = self.rotary_emb(value_states, position_ids) | ||||
|         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) | ||||
|  | ||||
|         if past_key_value is not None: | ||||
|             # sin and cos are specific to RoPE models; cache_position needed for the static cache | ||||
| @ -395,7 +322,7 @@ class GemmaFlashAttention2(GemmaAttention): | ||||
|             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate | ||||
|         ) | ||||
|  | ||||
|         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() | ||||
|         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() | ||||
|         attn_output = self.o_proj(attn_output) | ||||
|  | ||||
|         if not output_attentions: | ||||
| @ -500,8 +427,6 @@ class GemmaFlashAttention2(GemmaAttention): | ||||
|             (max_seqlen_in_batch_q, max_seqlen_in_batch_k), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma | ||||
| class GemmaSdpaAttention(GemmaAttention): | ||||
|     """ | ||||
|     Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from | ||||
| @ -509,7 +434,7 @@ class GemmaSdpaAttention(GemmaAttention): | ||||
|     SDPA API. | ||||
|     """ | ||||
|  | ||||
|     # Ignore copy | ||||
|     # Adapted from GemmaAttention.forward | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
| @ -546,8 +471,8 @@ class GemmaSdpaAttention(GemmaAttention): | ||||
|         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|  | ||||
|         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) | ||||
|         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) | ||||
|         cos, sin = self.rotary_emb(value_states, position_ids) | ||||
|         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) | ||||
|  | ||||
|         if past_key_value is not None: | ||||
|             # sin and cos are specific to RoPE models; cache_position needed for the static cache | ||||
| @ -582,27 +507,21 @@ class GemmaSdpaAttention(GemmaAttention): | ||||
|         ) | ||||
|  | ||||
|         attn_output = attn_output.transpose(1, 2).contiguous() | ||||
|         attn_output = attn_output.view(bsz, q_len, -1) | ||||
|         attn_output = attn_output.view(bsz, q_len, self.hidden_size) | ||||
|  | ||||
|         attn_output = self.o_proj(attn_output) | ||||
|  | ||||
|         return attn_output, None, past_key_value | ||||
|  | ||||
|  | ||||
| GEMMA_ATTENTION_CLASSES = { | ||||
|     "eager": GemmaAttention, | ||||
|     "flash_attention_2": GemmaFlashAttention2, | ||||
|     "sdpa": GemmaSdpaAttention, | ||||
| } | ||||
| COHERE_ATTENTION_CLASSES = {"eager": GemmaAttention, "flash_attention_2": GemmaFlashAttention2, "sdpa": GemmaSdpaAttention} | ||||
|  | ||||
|  | ||||
| # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma | ||||
| class GemmaDecoderLayer(nn.Module): | ||||
|     def __init__(self, config: GemmaConfig, layer_idx: int): | ||||
|         super().__init__() | ||||
|         self.hidden_size = config.hidden_size | ||||
|  | ||||
|         self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) | ||||
|         self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) | ||||
|  | ||||
|         self.mlp = GemmaMLP(config) | ||||
|         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
| @ -671,35 +590,16 @@ class GemmaDecoderLayer(nn.Module): | ||||
|  | ||||
|         return outputs | ||||
|  | ||||
|  | ||||
| GEMMA_START_DOCSTRING = r""" | ||||
|     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the | ||||
|     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads | ||||
|     etc.) | ||||
|  | ||||
|     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. | ||||
|     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage | ||||
|     and behavior. | ||||
|  | ||||
|     Parameters: | ||||
|         config ([`GemmaConfig`]): | ||||
|             Model configuration class with all the parameters of the model. Initializing with a config file does not | ||||
|             load the weights associated with the model, only the configuration. Check out the | ||||
|             [`~PreTrainedModel.from_pretrained`] method to load the model weights. | ||||
| """ | ||||
|  | ||||
|  | ||||
| @add_start_docstrings( | ||||
|     "The bare Gemma Model outputting raw hidden-states without any specific head on top.", | ||||
|     GEMMA_START_DOCSTRING, | ||||
|     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", | ||||
|     LLAMA_START_DOCSTRING, | ||||
| ) | ||||
| class GemmaPreTrainedModel(PreTrainedModel): | ||||
|     config_class = GemmaConfig | ||||
|     base_model_prefix = "model" | ||||
|     supports_gradient_checkpointing = True | ||||
|     _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"] | ||||
|     _no_split_modules = ["GemmaDecoderLayer"] | ||||
|     _skip_keys_device_placement = ["past_key_values", "causal_mask"] | ||||
|     _skip_keys_device_placement = ["past_key_values"] | ||||
|     _supports_flash_attn_2 = True | ||||
|     _supports_sdpa = True | ||||
|     _supports_cache_class = True | ||||
| @ -716,116 +616,36 @@ class GemmaPreTrainedModel(PreTrainedModel): | ||||
|                 module.weight.data[module.padding_idx].zero_() | ||||
|  | ||||
|  | ||||
| GEMMA_INPUTS_DOCSTRING = r""" | ||||
|     Args: | ||||
|         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | ||||
|             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide | ||||
|             it. | ||||
|  | ||||
|             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | ||||
|             [`PreTrainedTokenizer.__call__`] for details. | ||||
|  | ||||
|             [What are input IDs?](../glossary#input-ids) | ||||
|         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): | ||||
|             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | ||||
|  | ||||
|             - 1 for tokens that are **not masked**, | ||||
|             - 0 for tokens that are **masked**. | ||||
|  | ||||
|             [What are attention masks?](../glossary#attention-mask) | ||||
|  | ||||
|             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | ||||
|             [`PreTrainedTokenizer.__call__`] for details. | ||||
|  | ||||
|             If `past_key_values` is used, optionally only the last `input_ids` have to be input (see | ||||
|             `past_key_values`). | ||||
|  | ||||
|             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] | ||||
|             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more | ||||
|             information on the default strategy. | ||||
|  | ||||
|             - 1 indicates the head is **not masked**, | ||||
|             - 0 indicates the head is **masked**. | ||||
|         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||||
|             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | ||||
|             config.n_positions - 1]`. | ||||
|  | ||||
|             [What are position IDs?](../glossary#position-ids) | ||||
|         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): | ||||
|             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention | ||||
|             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` | ||||
|             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. | ||||
|  | ||||
|             Two formats are allowed: | ||||
|             - a [`~cache_utils.Cache`] instance; | ||||
|             - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of | ||||
|             shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy | ||||
|             cache format. | ||||
|  | ||||
|             The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the | ||||
|             legacy cache format will be returned. | ||||
|  | ||||
|             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't | ||||
|             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` | ||||
|             of shape `(batch_size, sequence_length)`. | ||||
|         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): | ||||
|             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This | ||||
|             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the | ||||
|             model's internal embedding lookup matrix. | ||||
|         use_cache (`bool`, *optional*): | ||||
|             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see | ||||
|             `past_key_values`). | ||||
|         output_attentions (`bool`, *optional*): | ||||
|             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | ||||
|             tensors for more detail. | ||||
|         output_hidden_states (`bool`, *optional*): | ||||
|             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | ||||
|             more detail. | ||||
|         return_dict (`bool`, *optional*): | ||||
|             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | ||||
|         cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): | ||||
|             Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, | ||||
|             this tensor is not affected by padding. It is used to update the cache in the correct position and to infer | ||||
|             the complete sequence length. | ||||
| """ | ||||
|  | ||||
|  | ||||
| @add_start_docstrings( | ||||
|     "The bare Gemma Model outputting raw hidden-states without any specific head on top.", | ||||
|     GEMMA_START_DOCSTRING, | ||||
|     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", | ||||
|     LLAMA_START_DOCSTRING, | ||||
| ) | ||||
| # Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma | ||||
| class GemmaModel(GemmaPreTrainedModel): | ||||
| class LlamaModel(LlamaPreTrainedModel): | ||||
|     """ | ||||
|     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`] | ||||
|     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] | ||||
|  | ||||
|     Args: | ||||
|         config: GemmaConfig | ||||
|         config: LlamaConfig | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, config: GemmaConfig): | ||||
|     def __init__(self, config: LlamaConfig): | ||||
|         super().__init__(config) | ||||
|         self.padding_idx = config.pad_token_id | ||||
|         self.vocab_size = config.vocab_size | ||||
|  | ||||
|         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) | ||||
|         self.layers = nn.ModuleList( | ||||
|             [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] | ||||
|             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] | ||||
|         ) | ||||
|         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||||
|         self.gradient_checkpointing = False | ||||
|  | ||||
|         # Initialize weights and apply final processing | ||||
|         self.post_init() | ||||
|  | ||||
|     def get_input_embeddings(self): | ||||
|         return self.embed_tokens | ||||
|  | ||||
|     def set_input_embeddings(self, value): | ||||
|         self.embed_tokens = value | ||||
|      | ||||
|     @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) | ||||
|     # Ignore copy | ||||
|     def forward( | ||||
|         self, | ||||
|         input_ids: torch.LongTensor = None, | ||||
| @ -833,114 +653,56 @@ class GemmaModel(GemmaPreTrainedModel): | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, | ||||
|         inputs_embeds: Optional[torch.FloatTensor] = None, | ||||
|         labels: Optional[torch.LongTensor] = None, | ||||
|         use_cache: Optional[bool] = None, | ||||
|         output_attentions: Optional[bool] = None, | ||||
|         output_hidden_states: Optional[bool] = None, | ||||
|         return_dict: Optional[bool] = None, | ||||
|         cache_position: Optional[torch.LongTensor] = None, | ||||
|     ) -> Union[Tuple, BaseModelOutputWithPast]: | ||||
|         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||||
|         output_hidden_states = ( | ||||
|             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||||
|         ) | ||||
|         use_cache = use_cache if use_cache is not None else self.config.use_cache | ||||
|         return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||||
|     ) -> Union[Tuple, CausalLMOutputWithPast]: | ||||
|         output_attentions = output_attentions | self.config.output_attentions | ||||
|         output_hidden_states = output_hidden_states| self.config.output_hidden_states | ||||
|         return_dict = return_dict | self.config.use_return_dict | ||||
|  | ||||
|         if (input_ids is None) ^ (inputs_embeds is not None): | ||||
|             raise ValueError( | ||||
|                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" | ||||
|             ) | ||||
|  | ||||
|         if self.gradient_checkpointing and self.training and use_cache: | ||||
|             logger.warning_once( | ||||
|                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." | ||||
|             ) | ||||
|             use_cache = False | ||||
|  | ||||
|         if inputs_embeds is None: | ||||
|             inputs_embeds = self.embed_tokens(input_ids) | ||||
|  | ||||
|         return_legacy_cache = False | ||||
|         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs) | ||||
|             return_legacy_cache = True | ||||
|             past_key_values = DynamicCache.from_legacy_cache(past_key_values) | ||||
|  | ||||
|         if cache_position is None: | ||||
|             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 | ||||
|             cache_position = torch.arange( | ||||
|                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device | ||||
|             ) | ||||
|  | ||||
|         if position_ids is None: | ||||
|             position_ids = cache_position.unsqueeze(0) | ||||
|  | ||||
|         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) | ||||
|  | ||||
|         # embed positions | ||||
|         hidden_states = inputs_embeds | ||||
|  | ||||
|         # normalized | ||||
|         # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 | ||||
|         # See https://github.com/huggingface/transformers/pull/29402 | ||||
|         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) | ||||
|         hidden_states = hidden_states * normalizer | ||||
|  | ||||
|         # decoder layers | ||||
|         all_hidden_states = () if output_hidden_states else None | ||||
|         all_self_attns = () if output_attentions else None | ||||
|         next_decoder_cache = None | ||||
|  | ||||
|         for decoder_layer in self.layers: | ||||
|             if output_hidden_states: | ||||
|                 all_hidden_states += (hidden_states,) | ||||
|  | ||||
|             if self.gradient_checkpointing and self.training: | ||||
|                 layer_outputs = self._gradient_checkpointing_func( | ||||
|                     decoder_layer.__call__, | ||||
|                     hidden_states, | ||||
|                     causal_mask, | ||||
|                     position_ids, | ||||
|                     past_key_values, | ||||
|                     output_attentions, | ||||
|                     use_cache, | ||||
|                     cache_position, | ||||
|                 ) | ||||
|             else: | ||||
|                 layer_outputs = decoder_layer( | ||||
|                     hidden_states, | ||||
|                     attention_mask=causal_mask, | ||||
|         outputs = self.model( | ||||
|             input_ids=input_ids, | ||||
|             attention_mask=attention_mask, | ||||
|             position_ids=position_ids, | ||||
|                     past_key_value=past_key_values, | ||||
|                     output_attentions=output_attentions, | ||||
|             past_key_values=past_key_values, | ||||
|             inputs_embeds=inputs_embeds, | ||||
|             use_cache=use_cache, | ||||
|             output_attentions=output_attentions, | ||||
|             output_hidden_states=output_hidden_states, | ||||
|             return_dict=return_dict, | ||||
|             cache_position=cache_position, | ||||
|         ) | ||||
|  | ||||
|             hidden_states = layer_outputs[0] | ||||
|  | ||||
|             if use_cache: | ||||
|                 next_decoder_cache = layer_outputs[2 if output_attentions else 1] | ||||
|  | ||||
|             if output_attentions: | ||||
|                 all_self_attns += (layer_outputs[1],) | ||||
|  | ||||
|         hidden_states = self.norm(hidden_states) | ||||
|  | ||||
|         # add hidden states from the last decoder layer | ||||
|         if output_hidden_states: | ||||
|             all_hidden_states += (hidden_states,) | ||||
|  | ||||
|         next_cache = next_decoder_cache if use_cache else None | ||||
|         if return_legacy_cache: | ||||
|             next_cache = next_cache.to_legacy_cache() | ||||
|         hidden_states = outputs[0] | ||||
|         logits = self.lm_head(hidden_states) | ||||
|         logits = logits.float() | ||||
|         loss = None | ||||
|         if labels is not None: | ||||
|             # Shift so that tokens < n predict n | ||||
|             shift_logits = logits[..., :-1, :].contiguous() | ||||
|             shift_labels = labels[..., 1:].contiguous() | ||||
|             # Flatten the tokens | ||||
|             loss_fct = CrossEntropyLoss() | ||||
|             shift_logits = shift_logits.view(-1, self.config.vocab_size) | ||||
|             shift_labels = shift_labels.view(-1) | ||||
|             # Enable model parallelism | ||||
|             shift_labels = shift_labels.to(shift_logits.device) | ||||
|             loss = loss_fct(shift_logits, shift_labels) | ||||
|  | ||||
|         if not return_dict: | ||||
|             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) | ||||
|         return BaseModelOutputWithPast( | ||||
|             last_hidden_state=hidden_states, | ||||
|             past_key_values=next_cache, | ||||
|             hidden_states=all_hidden_states, | ||||
|             attentions=all_self_attns, | ||||
|             output = (logits,) + outputs[1:] | ||||
|             return (loss,) + output if loss is not None else output | ||||
|  | ||||
|         return CausalLMOutputWithPast( | ||||
|             loss=loss, | ||||
|             logits=logits, | ||||
|             past_key_values=outputs.past_key_values, | ||||
|             hidden_states=outputs.hidden_states, | ||||
|             attentions=outputs.attentions, | ||||
|         ) | ||||
|  | ||||
|     def _update_causal_mask( | ||||
| @ -1030,7 +792,65 @@ class GemmaModel(GemmaPreTrainedModel): | ||||
|         return causal_mask | ||||
|  | ||||
|      | ||||
| # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma | ||||
|     def forward( | ||||
|         self, | ||||
|         input_ids: torch.LongTensor = None, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, | ||||
|         inputs_embeds: Optional[torch.FloatTensor] = None, | ||||
|         labels: Optional[torch.LongTensor] = None, | ||||
|         use_cache: Optional[bool] = None, | ||||
|         output_attentions: Optional[bool] = None, | ||||
|         output_hidden_states: Optional[bool] = None, | ||||
|         return_dict: Optional[bool] = None, | ||||
|         cache_position: Optional[torch.LongTensor] = None, | ||||
|     ) -> Union[Tuple, CausalLMOutputWithPast]: | ||||
|         output_attentions = output_attentions | self.config.output_attentions | ||||
|         output_hidden_states = output_hidden_states| self.config.output_hidden_states | ||||
|         return_dict = return_dict | self.config.use_return_dict | ||||
|  | ||||
|         outputs = self.model( | ||||
|             input_ids=input_ids, | ||||
|             attention_mask=attention_mask, | ||||
|             position_ids=position_ids, | ||||
|             past_key_values=past_key_values, | ||||
|             inputs_embeds=inputs_embeds, | ||||
|             use_cache=use_cache, | ||||
|             output_attentions=output_attentions, | ||||
|             output_hidden_states=output_hidden_states, | ||||
|             return_dict=return_dict, | ||||
|             cache_position=cache_position, | ||||
|         ) | ||||
|  | ||||
|         hidden_states = outputs[0] | ||||
|         logits = self.lm_head(hidden_states) | ||||
|         logits = logits.float() | ||||
|         loss = None | ||||
|         if labels is not None: | ||||
|             # Shift so that tokens < n predict n | ||||
|             shift_logits = logits[..., :-1, :].contiguous() | ||||
|             shift_labels = labels[..., 1:].contiguous() | ||||
|             # Flatten the tokens | ||||
|             loss_fct = CrossEntropyLoss() | ||||
|             shift_logits = shift_logits.view(-1, self.config.vocab_size) | ||||
|             shift_labels = shift_labels.view(-1) | ||||
|             # Enable model parallelism | ||||
|             shift_labels = shift_labels.to(shift_logits.device) | ||||
|             loss = loss_fct(shift_logits, shift_labels) | ||||
|  | ||||
|         if not return_dict: | ||||
|             output = (logits,) + outputs[1:] | ||||
|             return (loss,) + output if loss is not None else output | ||||
|  | ||||
|         return CausalLMOutputWithPast( | ||||
|             loss=loss, | ||||
|             logits=logits, | ||||
|             past_key_values=outputs.past_key_values, | ||||
|             hidden_states=outputs.hidden_states, | ||||
|             attentions=outputs.attentions, | ||||
|         ) | ||||
|  | ||||
| class GemmaForCausalLM(GemmaPreTrainedModel): | ||||
|     _tied_weights_keys = ["lm_head.weight"] | ||||
|  | ||||
| @ -1061,8 +881,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel): | ||||
|     def get_decoder(self): | ||||
|         return self.model | ||||
|  | ||||
|     # Ignore copy | ||||
|     @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) | ||||
|     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) | ||||
|     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) | ||||
|     def forward( | ||||
|         self, | ||||
| @ -1092,16 +911,16 @@ class GemmaForCausalLM(GemmaPreTrainedModel): | ||||
|         ```python | ||||
|         >>> from transformers import AutoTokenizer, GemmaForCausalLM | ||||
|  | ||||
|         >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") | ||||
|         >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") | ||||
|         >>> model = GemmaForCausalLM.from_pretrained("meta-llama/Gemma-2-7b-hf") | ||||
|         >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Gemma-2-7b-hf") | ||||
|  | ||||
|         >>> prompt = "What is your favorite condiment?" | ||||
|         >>> prompt = "Hey, are you conscious? Can you talk to me?" | ||||
|         >>> inputs = tokenizer(prompt, return_tensors="pt") | ||||
|  | ||||
|         >>> # Generate | ||||
|         >>> generate_ids = model.generate(inputs.input_ids, max_length=30) | ||||
|         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | ||||
|         "What is your favorite condiment?" | ||||
|         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." | ||||
|         ```""" | ||||
|         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||||
|         output_hidden_states = ( | ||||
| @ -1124,8 +943,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel): | ||||
|         ) | ||||
|  | ||||
|         hidden_states = outputs[0] | ||||
|         if self.config.pretraining_tp > 1: | ||||
|             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) | ||||
|             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] | ||||
|             logits = torch.cat(logits, dim=-1) | ||||
|         else: | ||||
|             logits = self.lm_head(hidden_states) | ||||
|         logits = logits.float() | ||||
|  | ||||
|         loss = None | ||||
|         if labels is not None: | ||||
|             # Shift so that tokens < n predict n | ||||
| @ -1238,126 +1063,3 @@ class GemmaForCausalLM(GemmaPreTrainedModel): | ||||
|             ) | ||||
|         return reordered_past | ||||
|  | ||||
|  | ||||
| @add_start_docstrings( | ||||
|     """ | ||||
|     The Gemma Model transformer with a sequence classification head on top (linear layer). | ||||
|  | ||||
|     [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models | ||||
|     (e.g. GPT-2) do. | ||||
|  | ||||
|     Since it does classification on the last token, it requires to know the position of the last token. If a | ||||
|     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If | ||||
|     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the | ||||
|     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in | ||||
|     each row of the batch). | ||||
|     """, | ||||
|     GEMMA_START_DOCSTRING, | ||||
| ) | ||||
| # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma | ||||
| class GemmaForSequenceClassification(GemmaPreTrainedModel): | ||||
|     def __init__(self, config): | ||||
|         super().__init__(config) | ||||
|         self.num_labels = config.num_labels | ||||
|         self.model = GemmaModel(config) | ||||
|         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) | ||||
|  | ||||
|         # Initialize weights and apply final processing | ||||
|         self.post_init() | ||||
|  | ||||
|     def get_input_embeddings(self): | ||||
|         return self.model.embed_tokens | ||||
|  | ||||
|     def set_input_embeddings(self, value): | ||||
|         self.model.embed_tokens = value | ||||
|  | ||||
|     @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) | ||||
|     def forward( | ||||
|         self, | ||||
|         input_ids: torch.LongTensor = None, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, | ||||
|         inputs_embeds: Optional[torch.FloatTensor] = None, | ||||
|         labels: Optional[torch.LongTensor] = None, | ||||
|         use_cache: Optional[bool] = None, | ||||
|         output_attentions: Optional[bool] = None, | ||||
|         output_hidden_states: Optional[bool] = None, | ||||
|         return_dict: Optional[bool] = None, | ||||
|     ) -> Union[Tuple, SequenceClassifierOutputWithPast]: | ||||
|         r""" | ||||
|         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | ||||
|             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., | ||||
|             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If | ||||
|             `config.num_labels > 1` a classification loss is computed (Cross-Entropy). | ||||
|         """ | ||||
|         return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||||
|  | ||||
|         transformer_outputs = self.model( | ||||
|             input_ids, | ||||
|             attention_mask=attention_mask, | ||||
|             position_ids=position_ids, | ||||
|             past_key_values=past_key_values, | ||||
|             inputs_embeds=inputs_embeds, | ||||
|             use_cache=use_cache, | ||||
|             output_attentions=output_attentions, | ||||
|             output_hidden_states=output_hidden_states, | ||||
|             return_dict=return_dict, | ||||
|         ) | ||||
|         hidden_states = transformer_outputs[0] | ||||
|         logits = self.score(hidden_states) | ||||
|  | ||||
|         if input_ids is not None: | ||||
|             batch_size = input_ids.shape[0] | ||||
|         else: | ||||
|             batch_size = inputs_embeds.shape[0] | ||||
|  | ||||
|         if self.config.pad_token_id is None and batch_size != 1: | ||||
|             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") | ||||
|         if self.config.pad_token_id is None: | ||||
|             sequence_lengths = -1 | ||||
|         else: | ||||
|             if input_ids is not None: | ||||
|                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility | ||||
|                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 | ||||
|                 sequence_lengths = sequence_lengths % input_ids.shape[-1] | ||||
|                 sequence_lengths = sequence_lengths.to(logits.device) | ||||
|             else: | ||||
|                 sequence_lengths = -1 | ||||
|  | ||||
|         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] | ||||
|  | ||||
|         loss = None | ||||
|         if labels is not None: | ||||
|             labels = labels.to(logits.device) | ||||
|             if self.config.problem_type is None: | ||||
|                 if self.num_labels == 1: | ||||
|                     self.config.problem_type = "regression" | ||||
|                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): | ||||
|                     self.config.problem_type = "single_label_classification" | ||||
|                 else: | ||||
|                     self.config.problem_type = "multi_label_classification" | ||||
|  | ||||
|             if self.config.problem_type == "regression": | ||||
|                 loss_fct = MSELoss() | ||||
|                 if self.num_labels == 1: | ||||
|                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) | ||||
|                 else: | ||||
|                     loss = loss_fct(pooled_logits, labels) | ||||
|             elif self.config.problem_type == "single_label_classification": | ||||
|                 loss_fct = CrossEntropyLoss() | ||||
|                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) | ||||
|             elif self.config.problem_type == "multi_label_classification": | ||||
|                 loss_fct = BCEWithLogitsLoss() | ||||
|                 loss = loss_fct(pooled_logits, labels) | ||||
|         if not return_dict: | ||||
|             output = (pooled_logits,) + transformer_outputs[1:] | ||||
|             return ((loss,) + output) if loss is not None else output | ||||
|  | ||||
|         return SequenceClassifierOutputWithPast( | ||||
|             loss=loss, | ||||
|             logits=pooled_logits, | ||||
|             past_key_values=transformer_outputs.past_key_values, | ||||
|             hidden_states=transformer_outputs.hidden_states, | ||||
|             attentions=transformer_outputs.attentions, | ||||
|         ) | ||||
|  | ||||
							
								
								
									
										1474
									
								
								src/transformers/models/llama/diff_llama.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1474
									
								
								src/transformers/models/llama/diff_llama.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										87
									
								
								src/transformers/models/persimmon/diff_persimmon.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								src/transformers/models/persimmon/diff_persimmon.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,87 @@ | ||||
| from transformers.models.llama.modeling_llama import * | ||||
| import torch.nn as nn | ||||
| from .configuration_persimmon import PersimmonConfig | ||||
| from transformers.utils import ModelConverter | ||||
|  | ||||
| PersimmonConverter = ModelConverter(__file__) | ||||
|  | ||||
| PersimmonConverter.register("PersimmonRotaryEmbedding", LlamaRotaryEmbedding) | ||||
| PersimmonConverter.register("PersimmonMLP", LlamaMLP)  | ||||
|  | ||||
| class PersimmonAttention(LlamaAttention): | ||||
|     """Multi-headed attention from 'Attention Is All You Need' paper""" | ||||
|  | ||||
|     def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None): | ||||
|         super().__init__() | ||||
|         ... # copy before? add the line? how to best support this | ||||
|         self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) | ||||
|         self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) | ||||
|         self.qk_layernorm = config.qk_layernorm | ||||
|  | ||||
|         if self.qk_layernorm: | ||||
|             self.q_layernorm = nn.LayerNorm( | ||||
|                 config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True | ||||
|             ) | ||||
|             self.k_layernorm = nn.LayerNorm( | ||||
|                 config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True | ||||
|             ) | ||||
|         self.attention_dropout = nn.Dropout(config.attention_dropout) | ||||
|         self._init_rope() | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_value: Optional[Cache] = None, | ||||
|         output_attentions: bool = False, | ||||
|         use_cache: bool = False, | ||||
|     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||||
|         bsz, q_len, _ = hidden_states.size() | ||||
|  | ||||
|         # [batch_size, seq_length, 3 x hidden_size] | ||||
|         fused_qkv = self.query_key_value(hidden_states) | ||||
|  | ||||
|         # 3 x [batch_size, seq_length, num_heads, head_dim] | ||||
|         (query_states, key_states, value_states) = self._split_heads(fused_qkv) | ||||
|  | ||||
|         if self.qk_layernorm: | ||||
|             query_states = self.q_layernorm(query_states) | ||||
|             key_states = self.k_layernorm(key_states) | ||||
|  | ||||
|         # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim] | ||||
|         query_states = query_states.transpose(1, 2) | ||||
|         value_states = value_states.transpose(1, 2) | ||||
|         key_states = key_states.transpose(1, 2) | ||||
|  | ||||
|         os, sin = self.rotary_emb(value_states, seq_len=None) | ||||
|  | ||||
|         # Partial rotary embedding | ||||
|         query_rot, query_pass = ( | ||||
|             query_states[..., : self.rotary_emb.dim], | ||||
|             query_states[..., self.rotary_emb.dim :], | ||||
|         ) | ||||
|         key_rot, key_pass = ( | ||||
|             key_states[..., : self.rotary_emb.dim], | ||||
|             key_states[..., self.rotary_emb.dim :], | ||||
|         ) | ||||
|         # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] | ||||
|         query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids) | ||||
|  | ||||
|         # [batch_size, seq_length, num_heads, head_dim] | ||||
|         query_states = torch.cat((query_rot, query_pass), dim=-1) | ||||
|         key_states = torch.cat((key_rot, key_pass), dim=-1) | ||||
|         ... # TODO copy the rest of the function? if we do this it's unusable | ||||
|  | ||||
|  | ||||
|  | ||||
| PersimmonSdpaAttention = PersimmonConverter.register("PersimmonSdpaAttention", LlamaAttention)  | ||||
| PersimmonFlashAttention2 = PersimmonConverter.register("PersimmonFlashAttention2", LlamaAttention)  | ||||
|  | ||||
| COHERE_ATTENTION_CLASSES = {"eager": PersimmonAttention, "flash_attention_2": PersimmonFlashAttention2, "sdpa": PersimmonSdpaAttention} | ||||
|  | ||||
| PersimmonConverter.register("PersimmonDecoderLayer", LlamaDecoderLayer)  | ||||
| PersimmonConverter.register("PersimmonPreTrainedModel", LlamaPreTrainedModel) | ||||
|  | ||||
| PersimmonConverter.register("PersimmonModel", LlamaModel) | ||||
| PersimmonConverter.register("PersimmonForCausalLM", LlamaForCausalLM) | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										395
									
								
								src/transformers/models/stablelm/diff_stablelm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										395
									
								
								src/transformers/models/stablelm/diff_stablelm.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,395 @@ | ||||
| from typing import Tuple | ||||
| from transformers.models.llama.configuration_llama import LlamaConfig | ||||
| from transformers.models.llama.modeling_llama import * | ||||
| import torch.nn as nn | ||||
| from transformers import StableLmConfig | ||||
| from transformers.utils import ModelConverter | ||||
|  | ||||
| StableLmConverter = ModelConverter(__file__) | ||||
|  | ||||
| StableLmRMSNorm = StableLmConverter.register("StableLmRMSNorm", LlamaRMSNorm) | ||||
| StarcoderRotaryEmbedding = StableLmConverter.register("StarcoderRotaryEmbedding", LlamaRotaryEmbedding) | ||||
| StableLmMLP = StableLmConverter.register("StableLmMLP", LlamaMLP) | ||||
|  | ||||
|  | ||||
| class StableLmLayerNormPerHead(nn.Module): | ||||
|     def __init__(self, dim, num_heads, eps=1e-5, bias=False): | ||||
|         super().__init__() | ||||
|         self.dim = dim | ||||
|         self.num_heads = num_heads | ||||
|         self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)]) | ||||
|  | ||||
|     def forward(self, hidden_states: torch.Tensor): | ||||
|         # Split along the num_heads axis to get per-head inputs | ||||
|         # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads | ||||
|         states_per_heads = torch.split(hidden_states, 1, dim=1) | ||||
|         # Normalize and merge the heads back together | ||||
|         return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1) | ||||
|  | ||||
| class StableLmAttention(LlamaAttention): | ||||
|     def __init__(self, config: LlamaConfig, layer_idx: int | None = None): | ||||
|         super().__init__(config, layer_idx) # here call to super means | ||||
|                                             # we should copy super | ||||
|         self.qk_layernorm = config.qk_layernorm | ||||
|         self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps) | ||||
|         self.k_layernorm = StableLmLayerNormPerHead( | ||||
|             self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps | ||||
|         ) | ||||
|         self.attention_dropout = nn.Dropout(config.attention_dropout) | ||||
|  | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_value: Optional[Cache] = None, | ||||
|         output_attentions: bool = False, | ||||
|         use_cache: bool = False, | ||||
|     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||||
|         bsz, q_len, _ = hidden_states.size() | ||||
|  | ||||
|         query_states = self.q_proj(hidden_states) | ||||
|         key_states = self.k_proj(hidden_states) | ||||
|         value_states = self.v_proj(hidden_states) | ||||
|  | ||||
|         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||||
|         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|  | ||||
|         if self.qk_layernorm: | ||||
|             query_states = self.q_layernorm(query_states) | ||||
|             key_states = self.k_layernorm(key_states) | ||||
|  | ||||
|         kv_seq_len = key_states.shape[-2] | ||||
|         if past_key_value is not None: | ||||
|             if self.layer_idx is None: | ||||
|                 raise ValueError( | ||||
|                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " | ||||
|                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " | ||||
|                     "with a layer index." | ||||
|                 ) | ||||
|             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) | ||||
|         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) | ||||
|  | ||||
|         # Partial rotary embedding | ||||
|         query_rot, query_pass = ( | ||||
|             query_states[..., : self.rotary_emb.dim], | ||||
|             query_states[..., self.rotary_emb.dim :], | ||||
|         ) | ||||
|         key_rot, key_pass = ( | ||||
|             key_states[..., : self.rotary_emb.dim], | ||||
|             key_states[..., self.rotary_emb.dim :], | ||||
|         ) | ||||
|         # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] | ||||
|         query_rot, key_rot = self.rotary_emb(query_rot, key_rot, cos, sin, position_ids) | ||||
|  | ||||
|         # [batch_size, seq_length, num_heads, head_dim] | ||||
|         query_states = torch.cat((query_rot, query_pass), dim=-1) | ||||
|         key_states = torch.cat((key_rot, key_pass), dim=-1) | ||||
|  | ||||
|         if past_key_value is not None: | ||||
|             # Specific to RoPE models with partial rotation | ||||
|             cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} | ||||
|             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) | ||||
|  | ||||
|         # Repeat k/v heads if n_kv_heads < n_heads | ||||
|         key_states = repeat_kv(key_states, self.num_key_value_groups) | ||||
|         value_states = repeat_kv(value_states, self.num_key_value_groups) | ||||
|  | ||||
|         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) | ||||
|  | ||||
|         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): | ||||
|             raise ValueError( | ||||
|                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" | ||||
|                 f" {attn_weights.size()}" | ||||
|             ) | ||||
|  | ||||
|         if attention_mask is not None: | ||||
|             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): | ||||
|                 raise ValueError( | ||||
|                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" | ||||
|                 ) | ||||
|             attn_weights = attn_weights + attention_mask | ||||
|  | ||||
|         # upcast attention to fp32 | ||||
|         attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) | ||||
|         attn_weights = self.attention_dropout(attn_weights) | ||||
|  | ||||
|         attn_output = torch.matmul(attn_weights, value_states) | ||||
|  | ||||
|         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): | ||||
|             raise ValueError( | ||||
|                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" | ||||
|                 f" {attn_output.size()}" | ||||
|             ) | ||||
|  | ||||
|         attn_output = attn_output.transpose(1, 2).contiguous() | ||||
|         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) | ||||
|  | ||||
|         attn_output = self.o_proj(attn_output) | ||||
|  | ||||
|         if not output_attentions: | ||||
|             attn_weights = None | ||||
|  | ||||
|         return attn_output, attn_weights, past_key_value | ||||
|  | ||||
| class StableLmSdpaAttention(StableLmAttention): | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_value: Optional[Cache] = None, | ||||
|         output_attentions: bool = False, | ||||
|         use_cache: bool = False, | ||||
|     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||||
|         if output_attentions: | ||||
|             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. | ||||
|             logger.warning_once( | ||||
|                 "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " | ||||
|                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' | ||||
|             ) | ||||
|             return super().forward( | ||||
|                 hidden_states=hidden_states, | ||||
|                 attention_mask=attention_mask, | ||||
|                 position_ids=position_ids, | ||||
|                 past_key_value=past_key_value, | ||||
|                 output_attentions=output_attentions, | ||||
|                 use_cache=use_cache, | ||||
|             ) | ||||
|  | ||||
|         bsz, q_len, _ = hidden_states.size() | ||||
|  | ||||
|         query_states = self.q_proj(hidden_states) | ||||
|         key_states = self.k_proj(hidden_states) | ||||
|         value_states = self.v_proj(hidden_states) | ||||
|  | ||||
|         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||||
|         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|  | ||||
|         if self.qk_layernorm: | ||||
|             query_states = self.q_layernorm(query_states) | ||||
|             key_states = self.k_layernorm(key_states) | ||||
|  | ||||
|         kv_seq_len = key_states.shape[-2] | ||||
|         if past_key_value is not None: | ||||
|             if self.layer_idx is None: | ||||
|                 raise ValueError( | ||||
|                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " | ||||
|                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " | ||||
|                     "with a layer index." | ||||
|                 ) | ||||
|             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) | ||||
|         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) | ||||
|  | ||||
|         # Partial rotary embedding | ||||
|         query_rot, query_pass = ( | ||||
|             query_states[..., : self.rotary_emb.dim], | ||||
|             query_states[..., self.rotary_emb.dim :], | ||||
|         ) | ||||
|         key_rot, key_pass = ( | ||||
|             key_states[..., : self.rotary_emb.dim], | ||||
|             key_states[..., self.rotary_emb.dim :], | ||||
|         ) | ||||
|         # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] | ||||
|         query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) | ||||
|  | ||||
|         # [batch_size, seq_length, num_heads, head_dim] | ||||
|         query_states = torch.cat((query_rot, query_pass), dim=-1) | ||||
|         key_states = torch.cat((key_rot, key_pass), dim=-1) | ||||
|  | ||||
|         if past_key_value is not None: | ||||
|             # Specific to RoPE models with partial rotation | ||||
|             cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} | ||||
|             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) | ||||
|  | ||||
|         # Repeat k/v heads if n_kv_heads < n_heads | ||||
|         key_states = repeat_kv(key_states, self.num_key_value_groups) | ||||
|         value_states = repeat_kv(value_states, self.num_key_value_groups) | ||||
|  | ||||
|         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, | ||||
|         # Reference: https://github.com/pytorch/pytorch/issues/112577. | ||||
|         if query_states.device.type == "cuda" and attention_mask is not None: | ||||
|             query_states = query_states.contiguous() | ||||
|             key_states = key_states.contiguous() | ||||
|             value_states = value_states.contiguous() | ||||
|  | ||||
|         attn_output = torch.nn.functional.scaled_dot_product_attention( | ||||
|             query_states, | ||||
|             key_states, | ||||
|             value_states, | ||||
|             attn_mask=attention_mask, | ||||
|             dropout_p=self.attention_dropout.p if self.training else 0.0, | ||||
|             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. | ||||
|             is_causal=self.is_causal and attention_mask is None and q_len > 1, | ||||
|         ) | ||||
|  | ||||
|         attn_output = attn_output.transpose(1, 2).contiguous() | ||||
|         attn_output = attn_output.view(bsz, q_len, self.hidden_size) | ||||
|  | ||||
|         attn_output = self.o_proj(attn_output) | ||||
|  | ||||
|         return attn_output, None, past_key_value | ||||
|  | ||||
| class StableLmFlashAttention2(LlamaFlashAttention2): | ||||
|     """ | ||||
|     StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays | ||||
|     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of | ||||
|     flash attention and deal with padding tokens in case the input contains any of them. | ||||
|     """ | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         attention_mask: Optional[torch.LongTensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_value: Optional[Cache] = None, | ||||
|         output_attentions: bool = False, | ||||
|         use_cache: bool = False, | ||||
|         **kwargs, | ||||
|     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||||
|         output_attentions = False | ||||
|  | ||||
|         bsz, q_len, _ = hidden_states.size() | ||||
|  | ||||
|         query_states = self.q_proj(hidden_states) | ||||
|         key_states = self.k_proj(hidden_states) | ||||
|         value_states = self.v_proj(hidden_states) | ||||
|  | ||||
|         # Flash attention requires the input to have the shape | ||||
|         # batch_size x seq_length x head_dim x hidden_dim | ||||
|         # therefore we just need to keep the original shape | ||||
|         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||||
|         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|  | ||||
|         if self.qk_layernorm: | ||||
|             query_states = self.q_layernorm(query_states) | ||||
|             key_states = self.k_layernorm(key_states) | ||||
|  | ||||
|         kv_seq_len = key_states.shape[-2] | ||||
|         if past_key_value is not None: | ||||
|             if self.layer_idx is None: | ||||
|                 raise ValueError( | ||||
|                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " | ||||
|                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " | ||||
|                     "with a layer index." | ||||
|                 ) | ||||
|             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) | ||||
|         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) | ||||
|  | ||||
|         # Partial rotary embedding | ||||
|         query_rot, query_pass = ( | ||||
|             query_states[..., : self.rotary_emb.dim], | ||||
|             query_states[..., self.rotary_emb.dim :], | ||||
|         ) | ||||
|         key_rot, key_pass = ( | ||||
|             key_states[..., : self.rotary_emb.dim], | ||||
|             key_states[..., self.rotary_emb.dim :], | ||||
|         ) | ||||
|         query_rot, key_rot = self.rotary_emb(query_rot, key_rot, cos, sin, position_ids) | ||||
|  | ||||
|         # [batch_size, seq_length, num_heads, head_dim] | ||||
|         query_states = torch.cat((query_rot, query_pass), dim=-1) | ||||
|         key_states = torch.cat((key_rot, key_pass), dim=-1) | ||||
|  | ||||
|         if past_key_value is not None: | ||||
|             cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} | ||||
|             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) | ||||
|  | ||||
|         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache | ||||
|         # to be able to avoid many of these transpose/reshape/view. | ||||
|         query_states = query_states.transpose(1, 2) | ||||
|         key_states = key_states.transpose(1, 2) | ||||
|         value_states = value_states.transpose(1, 2) | ||||
|  | ||||
|         dropout_rate = self.attention_dropout.p if self.training else 0.0 | ||||
|  | ||||
|         attn_output = self._flash_attention_forward( | ||||
|             query_states, | ||||
|             key_states, | ||||
|             value_states, | ||||
|             attention_mask, | ||||
|             q_len, | ||||
|             dropout=dropout_rate, | ||||
|         ) | ||||
|  | ||||
|         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() | ||||
|         attn_output = self.o_proj(attn_output) | ||||
|  | ||||
|         if not output_attentions: | ||||
|             attn_weights = None | ||||
|  | ||||
|         return attn_output, attn_weights, past_key_value | ||||
|  | ||||
|  | ||||
| StableLm_ATTENTION_CLASSES = {"eager": StableLmAttention, "flash_attention_2": StableLmFlashAttention2, "sdpa": StableLmSdpaAttention} | ||||
|  | ||||
| class StableLmDecoderLayer(nn.Module): | ||||
|     def __init__(self, config: StableLmConfig, layer_idx: int): | ||||
|         super().__init__() | ||||
|         self.use_parallel_residual = config.use_parallel_residual | ||||
|         self.hidden_size = config.hidden_size | ||||
|         self.self_attn = StableLm_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) | ||||
|         self.mlp = StableLmMLP(config) | ||||
|         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | ||||
|         self.post_attention_layernorm = None | ||||
|         if not self.use_parallel_residual: | ||||
|             self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | ||||
|         self.dropout = nn.Dropout(config.hidden_dropout) | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_value: Optional[Tuple[torch.Tensor]] = None, | ||||
|         output_attentions: Optional[bool] = False, | ||||
|         use_cache: Optional[bool] = False, | ||||
|     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: | ||||
|         residual = hidden_states | ||||
|  | ||||
|         hidden_states = self.input_layernorm(hidden_states) | ||||
|  | ||||
|         # Self Attention | ||||
|         self_attn_output, self_attn_weights, present_key_value = self.self_attn( | ||||
|             hidden_states=hidden_states, | ||||
|             attention_mask=attention_mask, | ||||
|             position_ids=position_ids, | ||||
|             past_key_value=past_key_value, | ||||
|             output_attentions=output_attentions, | ||||
|             use_cache=use_cache, | ||||
|         ) | ||||
|  | ||||
|         # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward | ||||
|         if self.use_parallel_residual: | ||||
|             # x = x + attn(ln1(x)) + mlp(ln1(x)) | ||||
|             # Fully Connected | ||||
|             mlp_output = self.mlp(hidden_states) | ||||
|             mlp_output = self.dropout(mlp_output) | ||||
|             hidden_states = residual + self_attn_output + mlp_output | ||||
|         else: | ||||
|             # x = x + attn(ln1(x)) | ||||
|             # x = x + mlp(ln2(x)) | ||||
|             residual = residual + self_attn_output | ||||
|             # Fully Connected | ||||
|             mlp_output = self.mlp(self.post_attention_layernorm(residual)) | ||||
|             mlp_output = self.dropout(mlp_output) | ||||
|             hidden_states = residual + mlp_output | ||||
|  | ||||
|         outputs = (hidden_states,) | ||||
|  | ||||
|         if output_attentions: | ||||
|             outputs += (self_attn_weights,) | ||||
|  | ||||
|         if use_cache: | ||||
|             outputs += (present_key_value,) | ||||
|  | ||||
|         return outputs | ||||
|  | ||||
| StableLmPreTrainedModel = StableLmConverter.register("StableLmPreTrainedModel", LlamaPreTrainedModel) | ||||
| StableLmdModel = StableLmConverter.register("StableLmdModel", LlamaModel) | ||||
| StableLmForCausalLM = StableLmConverter.register("StableLmForCausalLM", LlamaForCausalLM) | ||||
| StableLmForSequenceClassification = StableLmConverter.register("StableLmForSequenceClassification", LlamaForSequenceClassification) | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										166
									
								
								src/transformers/models/starcoder2/diff_starcoder2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								src/transformers/models/starcoder2/diff_starcoder2.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,166 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. | ||||
| # | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from typing import List, Tuple | ||||
| from torch import FloatTensor, LongTensor, Tensor | ||||
| from transformers.modeling_outputs import BaseModelOutputWithPast | ||||
| from transformers.models.llama.configuration_llama import LlamaConfig | ||||
| from transformers.models.llama.modeling_llama import * | ||||
| import torch.nn as nn | ||||
| from transformers import Starcoder2Config | ||||
| from transformers.utils import ModelConverter | ||||
|  | ||||
| Starcoder2Converter = ModelConverter(__file__) | ||||
|  | ||||
| Starcoder2RMSNorm = Starcoder2Converter.register("Starcoder2RMSNorm", LlamaRMSNorm) | ||||
| StarcoderRotaryEmbedding = Starcoder2Converter.register("StarcoderRotaryEmbedding", LlamaRotaryEmbedding) | ||||
|  | ||||
| class Starcoder2MLP(nn.Module): | ||||
|     def __init__(self, config: Starcoder2Config): | ||||
|         super().__init__() | ||||
|         embed_dim = config.hidden_size | ||||
|         self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias) | ||||
|         self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias) | ||||
|         self.act = ACT2FN[config.hidden_act] | ||||
|         self.residual_dropout = config.residual_dropout | ||||
|  | ||||
|     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: | ||||
|         hidden_states = self.c_fc(hidden_states) | ||||
|         hidden_states = self.act(hidden_states) | ||||
|         hidden_states = self.c_proj(hidden_states) | ||||
|         hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training) | ||||
|         return hidden_states | ||||
|  | ||||
| # TODO either we support this, or we don't allow call to super? | ||||
| # if part of the super is used, then we are fucked. Let's restrict this to init? | ||||
|  | ||||
| # TODO if a class is not registered, the original should be copied with replaces? | ||||
| # Copied form where? No. | ||||
| # But then how do we check the architecture etc. | ||||
|  | ||||
| # TODO do we support multiple inheritance?  | ||||
| # This will depend on whether we usually copy from more than one module | ||||
| # Mixtral for example?  | ||||
|  | ||||
| class Starcoder2Attention(LlamaAttention): | ||||
|     def __init__(self, config: LlamaConfig, layer_idx: int | None = None): | ||||
|         super().__init__(config, layer_idx) # here call to super means | ||||
|         self.attention_dropout = config.attention_dropout | ||||
|  | ||||
|     def forward( | ||||
|         self, | ||||
|         hidden_states: torch.Tensor, | ||||
|         attention_mask: Optional[torch.Tensor] = None, | ||||
|         position_ids: Optional[torch.LongTensor] = None, | ||||
|         past_key_value: Optional[Cache] = None, | ||||
|         output_attentions: bool = False, | ||||
|         use_cache: bool = False, | ||||
|         **kwargs, | ||||
|     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||||
|         if "padding_mask" in kwargs: | ||||
|             warnings.warn( | ||||
|                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" | ||||
|             ) | ||||
|         bsz, q_len, _ = hidden_states.size() | ||||
|  | ||||
|         query_states = self.q_proj(hidden_states) | ||||
|         key_states = self.k_proj(hidden_states) | ||||
|         value_states = self.v_proj(hidden_states) | ||||
|  | ||||
|         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||||
|         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) | ||||
|  | ||||
|         kv_seq_len = key_states.shape[-2] | ||||
|         if past_key_value is not None: | ||||
|             if self.layer_idx is None: | ||||
|                 raise ValueError( | ||||
|                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " | ||||
|                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " | ||||
|                     "with a layer index." | ||||
|                 ) | ||||
|             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) | ||||
|         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) | ||||
|         query_states, key_states = self.rotary_emb(query_states, key_states, cos, sin, position_ids) | ||||
|  | ||||
|         if past_key_value is not None: | ||||
|             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models | ||||
|             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) | ||||
|  | ||||
|         # repeat k/v heads if n_kv_heads < n_heads | ||||
|         key_states = repeat_kv(key_states, self.num_key_value_groups) | ||||
|         value_states = repeat_kv(value_states, self.num_key_value_groups) | ||||
|  | ||||
|         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) | ||||
|  | ||||
|         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): | ||||
|             raise ValueError( | ||||
|                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" | ||||
|                 f" {attn_weights.size()}" | ||||
|             ) | ||||
|  | ||||
|         if attention_mask is not None: | ||||
|             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): | ||||
|                 raise ValueError( | ||||
|                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" | ||||
|                 ) | ||||
|  | ||||
|             attn_weights = attn_weights + attention_mask | ||||
|  | ||||
|         # upcast attention to fp32 | ||||
|         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) | ||||
|         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) | ||||
|         attn_output = torch.matmul(attn_weights, value_states) | ||||
|  | ||||
|         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): | ||||
|             raise ValueError( | ||||
|                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" | ||||
|                 f" {attn_output.size()}" | ||||
|             ) | ||||
|  | ||||
|         attn_output = attn_output.transpose(1, 2).contiguous() | ||||
|         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) | ||||
|  | ||||
|         attn_output = self.o_proj(attn_output) | ||||
|         attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) | ||||
|  | ||||
|         if not output_attentions: | ||||
|             attn_weights = None | ||||
|  | ||||
|         return attn_output, attn_weights, past_key_value | ||||
|  | ||||
| Starcoder2SdpaAttention = Starcoder2Converter.register("Starcoder2SdpaAttention", LlamaAttention)  | ||||
| Starcoder2FlashAttention2 = Starcoder2Converter.register("Starcoder2FlashAttention2", LlamaAttention)  | ||||
|  | ||||
| STARCODER2_ATTENTION_CLASSES = {"eager": Starcoder2Attention, "flash_attention_2": Starcoder2FlashAttention2, "sdpa": Starcoder2SdpaAttention} | ||||
|  | ||||
|  | ||||
| Starcoder2DecoderLayer = Starcoder2Converter.register("Starcoder2DecoderLayer", LlamaDecoderLayer)  | ||||
| Starcoder2PreTrainedModel = Starcoder2Converter.register("Starcoder2PreTrainedModel", LlamaPreTrainedModel) | ||||
|  | ||||
| class Starcoder2Model(LlamaModel): | ||||
|     def __init__(self, config): | ||||
|         super().__init__(config) | ||||
|         self.embedding_dropout = config.embedding_dropout | ||||
|  | ||||
|     def forward(self, input_ids: LongTensor = None, attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_values: List[FloatTensor] | None = None, inputs_embeds: FloatTensor | None = None, use_cache: bool | None = None, output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, cache_position: LongTensor | None = None) -> Tuple | BaseModelOutputWithPast: | ||||
|         if inputs_embeds is None:  | ||||
|             inputs_embeds = self.embed_tokens(input_ids) | ||||
|         hidden_states = inputs_embeds | ||||
|         hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) | ||||
|         return super().forward(None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position) | ||||
|  | ||||
| Starcoder2ForCausalLM = Starcoder2Converter.register("Starcoder2ForCausalLM", LlamaForCausalLM) | ||||
| Starcoder2ForSequenceClassification = Starcoder2Converter.register("Starcoder2ForSequenceClassification", LlamaForSequenceClassification) | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -215,7 +215,7 @@ from .peft_utils import ( | ||||
|     check_peft_version, | ||||
|     find_adapter_config_file, | ||||
| ) | ||||
|  | ||||
| from .model_converter import ModelConverter | ||||
|  | ||||
| WEIGHTS_NAME = "pytorch_model.bin" | ||||
| WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" | ||||
|  | ||||
							
								
								
									
										57
									
								
								src/transformers/utils/model_converter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								src/transformers/utils/model_converter.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,57 @@ | ||||
| """ | ||||
| running this script on `src/transformers/models/**_diff.py` should produce the equivalent single model single files | ||||
| 1. Iterate though `**_diff.py` files | ||||
| 2. How to handle the imports? | ||||
|     a. `model_type` should always be present? | ||||
|     b. `ConfigClass` should be defined as well? | ||||
| 3. Copy each class and function one by one. | ||||
|     a. if there is a class registered for this file like `@__file__.register(MyNewClass, OldClass)` | ||||
|     then copy the content of `OldClass`, replacing all names of `Old` with `MyNew`. | ||||
|     Also copy the decorators that are on top of this class. | ||||
|     b. if there is inheritance, copy non-overloaded functions from base, and overloaded from non base. | ||||
| 4. Register things? | ||||
| new = type("new_class", (torch.nn.Linear,),{}) | ||||
| new__main__.new_class | ||||
| new(10,10) | ||||
| new_class(in_features=10, out_features=10, bias=True) | ||||
| CohereConverter = ModelConverter(__file__) | ||||
| CohereMLP = CohereConverter.register("CohereMLP", LlamaMLP) | ||||
| CohereMLP | ||||
| <class 'transformers.models.cohere.modeling_cohere.CohereMLP'> | ||||
| CohereMLP(LlamaConfig()) | ||||
| CohereMLP( | ||||
|   (gate_proj): Linear(in_features=4096, out_features=11008, bias=False) | ||||
|   (up_proj): Linear(in_features=4096, out_features=11008, bias=False) | ||||
|   (down_proj): Linear(in_features=11008, out_features=4096, bias=False) | ||||
|   (act_fn): SiLU() | ||||
| ) | ||||
| >>> CohereMLP(LlamaConfig())(torch.ones(1,1,4096)) | ||||
| How to deal with submodules? | ||||
| CohereSdpaAttention( | ||||
|   (q_proj): Linear(in_features=4096, out_features=4096, bias=False) | ||||
|   (k_proj): Linear(in_features=4096, out_features=4096, bias=False) | ||||
|   (v_proj): Linear(in_features=4096, out_features=4096, bias=False) | ||||
|   (o_proj): Linear(in_features=4096, out_features=4096, bias=False) | ||||
|   (rotary_emb): LlamaRotaryEmbedding() | ||||
| ) | ||||
| """ | ||||
| import regex as re | ||||
| class ModelConverter: | ||||
|  | ||||
|     def __init__(self, file): | ||||
|         self.diff_file = file | ||||
|         self.model_name = re.search(r'models/(.*?)/diff', self.diff_file).group(1) | ||||
|         self.modeling_file = file.replace("diff", "modeling") | ||||
|         self.registered_classes = {} | ||||
|         self.modules_to_import = [] | ||||
|     def register(self, new_class, old_class): | ||||
|         # registering. Returns the old class to be usable with a new name | ||||
|         self.registered_classes[new_class] = old_class | ||||
|         self.modules_to_import.append([old_class, old_class.__module__]) | ||||
|         new_class = type(new_class, (old_class,), {}) | ||||
|         base_model_name = re.search(r'models\.(.*?)\.modeling', old_class.__module__).group(1) | ||||
|         new_class.__module__ = re.sub(base_model_name, self.model_name, old_class.__module__) | ||||
|         return new_class | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         return f"ModelConverter({self.diff_file}, {self.model_name}, {self.registered_classes})" | ||||
							
								
								
									
										148
									
								
								utils/convert_diff_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								utils/convert_diff_model.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,148 @@ | ||||
| import importlib | ||||
| import argparse | ||||
| import glob | ||||
| from transformers import MODEL_NAMES_MAPPING | ||||
| import regex as re | ||||
| import inspect | ||||
| # pattern = re.compile(r'(class|def|XXXConverter\.register)\s+[\w.()]+\s*:(\s*(?:[^class|def|XXXConverter\.register]|\n)+)', re.MULTILINE) | ||||
| # For each and every diff files we should import all packages from the modules that are imported. | ||||
| # pattern = r"((    [\s\S]*?)\n\n(?=    \S))|((    [\s\S]*?)(?=\Z))" is super important | ||||
|  | ||||
| # TODO in order to get everything from LLAMA we need to copy each line from Llama | ||||
| # only updating the attention classes.  | ||||
| # Need to keep the order correct | ||||
| # TODO the imports here are not correctly done. We should dynamically import what we need | ||||
| APACHE_LICENCE = """# Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ | ||||
| # 1. all the imports from the original file should be copied until end of header? __HEADER__ | ||||
| # with open(CohereConverter.original_file, 'r') as file, open("result.py", "w+") as modeling: | ||||
| #         pass | ||||
| # TODO also copy and import from all modules in CohereConverter.modules_to_import to be able to use inspect | ||||
| def replace_super_calls_in_method(method_body, parent_method_body, method_name): | ||||
|     # Indent parent method body to match the child's method indentation | ||||
|     indent = re.match(r'(\s*)def ', method_body).group(1) | ||||
|     indented_parent_method_body = "\n".join([indent + line if line.strip() else line for line in parent_method_body.split('\n')]) | ||||
|     method_name = method_name.strip() | ||||
|     # Handle super().method_name(args) and return super().method_name(args) | ||||
|     super_call_pattern = re.compile(r'(\s*)return super\(\)\.' + method_name + r'\((.*?)\)') | ||||
|     method_body = super_call_pattern.sub(r'\1return (\2\n' + indented_parent_method_body + r'\1)', method_body) | ||||
|  | ||||
|     super_call_pattern_no_return = re.compile(r'(\s*)super\(\)\.' + method_name + r'\((.*?)\)') | ||||
|     method_body = super_call_pattern_no_return.sub(r'\1\2\n' + indented_parent_method_body, method_body) | ||||
|  | ||||
|     return method_body | ||||
| # 2. Write all the classes. Use the `CohereConverter` class for this. | ||||
|  | ||||
|  | ||||
| def create_single_model_file(converter): | ||||
|     if hasattr(converter, "diff_file"): | ||||
|         model_identifier = converter.diff_file.split("diff_") | ||||
|         # temporarily add the source to the path in order to load everything? | ||||
|         # 1. Import all modules from the registered classes | ||||
|         modules = set([ _class.__module__ for _class in converter.registered_classes.values()]) or set() | ||||
|         for module in modules | {re.sub(r'.*src/(.*)\.py', r'\1', converter.diff_file).replace('/', '.')}: | ||||
|             modeling_ = importlib.import_module(module) | ||||
|             globals().update({k: getattr(modeling_, k) for k in modeling_.__dict__.keys()}) | ||||
|  | ||||
|         with open(converter.diff_file, 'r') as file, open(f"{model_identifier[0]}modeling_{model_identifier[1]}", "w+") as modeling: | ||||
|             modeling.write(APACHE_LICENCE) | ||||
|             function_set = {} | ||||
|             for line in file: | ||||
|                     if "Converter.register" in line: # TODO use map() to map lines to this | ||||
|                         # write the code of the original model | ||||
|                         class_to_use, old_class = re.search(r'Converter\.register\(\"(.*?)\", (.*?)\)', line).groups() | ||||
|                         model_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', class_to_use)[0] | ||||
|                         old_model_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', old_class)[0] | ||||
|                         # import all necessary modules from the path: | ||||
|  | ||||
|                         source_code = inspect.getsource(converter.registered_classes[class_to_use]).replace(old_class, class_to_use) | ||||
|                         source_code = source_code.replace(old_model_identifier_camel, model_identifier_camel) | ||||
|                         modeling.write(source_code) | ||||
|                         modeling.write("\n") | ||||
|  | ||||
|                     elif match:=re.match(r"class (\w+)\((\w+)\):", line): | ||||
|                         class_name, parent_class = match.groups() | ||||
|                         pattern = re.compile( r"(\ {4}(?:[\S\s\ \n]*?)(?=\n\ ^[\) ]|\n\n    (?:def|@)|\Z))", re.MULTILINE) | ||||
|                         parent_class_def = inspect.getsource(eval(parent_class)) | ||||
|                         modeling.write(parent_class_def.split('\n')[0].replace(parent_class,class_name)+"\n") | ||||
|  | ||||
|                         function_name_pattern = r"(?=    def ([\S]*)\()" | ||||
|                         function_body_pattern = r"(\ {4}(?:[\S\s\ \n]*?)(?=\n\ ^[\) ]|\n\n    (?:def|@)|\Z))" | ||||
|  | ||||
|                         pattern = re.compile(function_body_pattern) | ||||
|                         matches = pattern.finditer(parent_class_def) | ||||
|                         parent_function_set = {} | ||||
|                         for match in matches: | ||||
|                             full_function = match.group() | ||||
|                             if "def" in full_function: | ||||
|                                 parent_function_set[full_function.split("def")[1].split("(")[0]] = full_function | ||||
|                             else: | ||||
|                                 parent_function_set[full_function] = full_function | ||||
|  | ||||
|                         parent_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', parent_class)[0] | ||||
|                         child_identifier_camel = re.findall(r'[A-Z][a-z0-9]*', class_name)[0] | ||||
|                         print(f"`{class_name}` -> `{parent_class}`") | ||||
|  | ||||
|                         child_function_set = parent_function_set.copy() | ||||
|                         class_def = inspect.getsource(eval(class_name)) | ||||
|                         matches = pattern.finditer(class_def) | ||||
|                         for match in matches: | ||||
|                             # TODO handle call to super! | ||||
|                             full_function = match.group() | ||||
|                             if "def" in full_function: | ||||
|                                 function_name = full_function.split("def")[1].split("(")[0] | ||||
|  | ||||
|                                 if (f"super()." in full_function or f"return super()." in full_function) and parent_identifier_camel != child_identifier_camel: | ||||
|                                     replaced_function = replace_super_calls_in_method(full_function, | ||||
|                                                                                       parent_function_set.get(function_name, | ||||
|                                                                                                               ""), | ||||
|                                                                                       function_name) | ||||
|                                     child_function_set[function_name] = replaced_function | ||||
|                                 else: | ||||
|                                     child_function_set[function_name] = full_function | ||||
|                             else: | ||||
|                                 child_function_set[full_function] = full_function | ||||
|  | ||||
|                         modeling.write("\n".join(child_function_set.values())) # TODO we wrote the code, next lines shall be ignored | ||||
|                         modeling.write("\n") | ||||
|  | ||||
|                     elif "= ModelConverter(__file__)" in line: | ||||
|                         pass # don't write the converter to the result file | ||||
|                     elif line not in "".join(function_set.values()) or line=="\n": | ||||
|                         modeling.write(line) | ||||
|  | ||||
|  | ||||
| def dynamically_import_object(module_path, object_name): | ||||
|     try: | ||||
|         module = importlib.import_module(module_path) | ||||
|         obj = getattr(module, object_name) | ||||
|         return obj | ||||
|     except (ImportError, AttributeError) as e: | ||||
|         print(f"Failed to import object '{object_name}' from module '{module_path}'") | ||||
|         print(e) | ||||
|  | ||||
|  | ||||
| # 3. Apply ruff fix to remove unused imports | ||||
| # 4. Run a tiny test to import from this new file. | ||||
| if __name__ == '__main__': | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument("--files_to_parse", default="all", help="A list of `diff_xxxx` files that should be converted to single model file") | ||||
|     args = parser.parse_args() | ||||
|     if args.files_to_parse == "all": | ||||
|         args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True) | ||||
|     for file_name in args.files_to_parse: | ||||
|         print(f"Converting {file_name} to a single model single file format") | ||||
|         module_path = file_name.replace("/",".").replace(".py","").replace("src.","") | ||||
|         model_name = MODEL_NAMES_MAPPING[module_path.split('_')[-1]] | ||||
|         converter = dynamically_import_object(module_path, f"{model_name}Converter") | ||||
|         create_single_model_file(converter) | ||||
		Reference in New Issue
	
	Block a user
	