mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
revert changes and use less llama modules
This commit is contained in:
@ -20,11 +20,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
@ -33,15 +33,40 @@ from ...masking_utils import create_causal_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
||||
from ...utils.generic import check_model_inputs
|
||||
from .configuration_nanochat import NanoChatConfig
|
||||
|
||||
|
||||
class NanoChatRMSNorm(torch.nn.Module):
|
||||
"""
|
||||
NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
|
||||
Overrides __init__ to match NanoChat's API with hidden_size parameter.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self._norm(x.float()).type_as(x)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"eps={self.eps}"
|
||||
|
||||
|
||||
class NanoChatRotaryEmbedding(nn.Module):
|
||||
"""Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half."""
|
||||
"""
|
||||
NanoChat's Rotary Position Embedding.
|
||||
Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape
|
||||
[batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim.
|
||||
"""
|
||||
|
||||
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
||||
|
||||
@ -65,38 +90,31 @@ class NanoChatRotaryEmbedding(nn.Module):
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
"""
|
||||
Returns cos and sin tensors for NanoChat's RoPE.
|
||||
|
||||
Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim,
|
||||
NanoChat keeps only head_dim//2 for memory efficiency.
|
||||
"""
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
# NanoChat-specific: Don't duplicate freqs - keep as head_dim//2
|
||||
cos = freqs.cos() * self.attention_scaling
|
||||
sin = freqs.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
# Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2]
|
||||
return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype)
|
||||
|
||||
|
||||
class NanoChatRMSNorm(torch.nn.Module):
|
||||
"""
|
||||
NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
|
||||
Overrides __init__ to match NanoChat's API with hidden_size parameter.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self._norm(x.float()).type_as(x)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"eps={self.eps}"
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
@ -164,38 +182,34 @@ def eager_attention_forward(
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input.
|
||||
|
||||
NanoChat uses a different rotation convention than standard Llama.
|
||||
Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation.
|
||||
This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos]
|
||||
"""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((x2, -x1), dim=-1)
|
||||
|
||||
|
||||
class NanoChatAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from NanoChat with custom QK normalization.
|
||||
Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE.
|
||||
Multi-headed attention from NanoChat with custom RoPE and QK normalization.
|
||||
|
||||
Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64
|
||||
"""
|
||||
|
||||
def __init__(self, config: NanoChatConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
# Override bias settings for NanoChat
|
||||
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
|
||||
self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -206,28 +220,46 @@ class NanoChatAttention(nn.Module):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
batch, seq_len, _ = hidden_states.shape
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
# Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim]
|
||||
query_states = (
|
||||
self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(batch, seq_len, self.num_kv_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(batch, seq_len, self.num_kv_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = position_embeddings # [batch, seq_len, 1, head_dim//2]
|
||||
cos = cos.squeeze(2)
|
||||
sin = sin.squeeze(2)
|
||||
cos = torch.cat([cos, cos], dim=-1)
|
||||
sin = torch.cat([-sin, -sin], dim=-1)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# NanoChat-specific: Apply QK normalization after RoPE
|
||||
query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps)
|
||||
key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps)
|
||||
# Apply QK normalization (RMSNorm)
|
||||
query_states = self.query_norm(query_states)
|
||||
key_states = self.key_norm(key_states)
|
||||
|
||||
# Apply KV cache: insert current k,v into cache, get the full view so far
|
||||
if past_key_values is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attention_interface = eager_attention_forward
|
||||
# Use attention interface pattern for vLLM compatibility
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
@ -241,6 +273,7 @@ class NanoChatAttention(nn.Module):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Reshape and project output
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
@ -263,16 +296,13 @@ class NanoChatMLP(nn.Module):
|
||||
|
||||
|
||||
class NanoChatDecoderLayer(GradientCheckpointingLayer):
|
||||
"""
|
||||
NanoChat decoder layer with pre-norm architecture.
|
||||
"""
|
||||
"""NanoChat decoder layer with pre-norm architecture."""
|
||||
|
||||
def __init__(self, config: NanoChatConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.config = config
|
||||
self.self_attn = NanoChatAttention(config, layer_idx)
|
||||
self.mlp = NanoChatMLP(config)
|
||||
# Replace Llama's norm layers with NanoChat's weight-less norm
|
||||
self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -286,15 +316,13 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
@ -305,7 +333,7 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
return hidden_states, self_attn_weights
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -343,18 +371,13 @@ class NanoChatPreTrainedModel(PreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class NanoChatModel(NanoChatPreTrainedModel):
|
||||
"""
|
||||
NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers
|
||||
and RMSNorm without learnable weights.
|
||||
"""
|
||||
|
||||
def __init__(self, config: NanoChatConfig):
|
||||
# Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.layers = nn.ModuleList(
|
||||
[NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
@ -364,6 +387,12 @@ class NanoChatModel(NanoChatPreTrainedModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
||||
self.embed_tokens = new_embeddings
|
||||
|
||||
@check_model_inputs()
|
||||
@auto_docstring
|
||||
def forward(
|
||||
@ -373,22 +402,25 @@ class NanoChatModel(NanoChatPreTrainedModel):
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = kwargs.get("output_attentions", False)
|
||||
output_hidden_states = kwargs.get("output_hidden_states", False)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache(config=self.config)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position: torch.Tensor = torch.arange(
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
@ -406,22 +438,38 @@ class NanoChatModel(NanoChatPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
hidden_states = decoder_layer(
|
||||
# Collect hidden states and attentions if requested
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, self_attn_weights = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns = all_self_attns + (self_attn_weights,)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
@ -429,7 +477,6 @@ class NanoChatModel(NanoChatPreTrainedModel):
|
||||
class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
|
||||
"""
|
||||
The NanoChat Model transformer with a language modeling head on top.
|
||||
Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping.
|
||||
"""
|
||||
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
@ -437,14 +484,17 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config: NanoChatConfig):
|
||||
# Call PreTrainedModel.__init__ directly
|
||||
super().__init__(config)
|
||||
self.model = NanoChatModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
||||
self.model.set_input_embeddings(new_embeddings)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
@ -477,7 +527,7 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@ -491,8 +541,6 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
|
||||
hidden_states = outputs.last_hidden_state
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
# NanoChat-specific: Apply logits soft capping if configured
|
||||
if self.config.logits_soft_cap is not None:
|
||||
cap = self.config.logits_soft_cap
|
||||
logits = cap * torch.tanh(logits / cap)
|
||||
|
@ -14,23 +14,24 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_rope_utils import dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
||||
from ...utils.generic import check_model_inputs
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
LlamaRotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
@ -40,23 +41,6 @@ from ..llama4.modeling_llama4 import Llama4TextL2Norm
|
||||
from .configuration_nanochat import NanoChatConfig
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input.
|
||||
|
||||
NanoChat uses a different rotation convention than standard Llama.
|
||||
Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation.
|
||||
This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos]
|
||||
"""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((x2, -x1), dim=-1)
|
||||
|
||||
|
||||
class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
"""Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half."""
|
||||
pass
|
||||
|
||||
|
||||
class NanoChatRMSNorm(Llama4TextL2Norm):
|
||||
"""
|
||||
NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
|
||||
@ -68,28 +52,64 @@ class NanoChatRMSNorm(Llama4TextL2Norm):
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
|
||||
class NanoChatAttention(LlamaAttention):
|
||||
class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
"""
|
||||
Multi-headed attention from NanoChat with custom QK normalization.
|
||||
Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE.
|
||||
NanoChat's Rotary Position Embedding.
|
||||
Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape
|
||||
[batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim.
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
"""
|
||||
Returns cos and sin tensors for NanoChat's RoPE.
|
||||
|
||||
Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim,
|
||||
NanoChat keeps only head_dim//2 for memory efficiency.
|
||||
"""
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
# NanoChat-specific: Don't duplicate freqs - keep as head_dim//2
|
||||
cos = freqs.cos() * self.attention_scaling
|
||||
sin = freqs.sin() * self.attention_scaling
|
||||
|
||||
# Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2]
|
||||
return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype)
|
||||
|
||||
|
||||
class NanoChatAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from NanoChat with custom RoPE and QK normalization.
|
||||
|
||||
Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64
|
||||
"""
|
||||
|
||||
def __init__(self, config: NanoChatConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.is_causal = True
|
||||
# Override bias settings for NanoChat
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias
|
||||
)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
|
||||
self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -100,27 +120,46 @@ class NanoChatAttention(LlamaAttention):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
batch, seq_len, _ = hidden_states.shape
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
# Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim]
|
||||
query_states = (
|
||||
self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(batch, seq_len, self.num_kv_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(batch, seq_len, self.num_kv_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
|
||||
cos, sin = position_embeddings # [batch, seq_len, 1, head_dim//2]
|
||||
cos = cos.squeeze(2)
|
||||
sin = sin.squeeze(2)
|
||||
cos = torch.cat([cos, cos], dim=-1)
|
||||
sin = torch.cat([-sin, -sin], dim=-1)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# NanoChat-specific: Apply QK normalization after RoPE
|
||||
query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps)
|
||||
key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps)
|
||||
# Apply QK normalization (RMSNorm)
|
||||
query_states = self.query_norm(query_states)
|
||||
key_states = self.key_norm(key_states)
|
||||
|
||||
# Apply KV cache: insert current k,v into cache, get the full view so far
|
||||
if past_key_values is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attention_interface = eager_attention_forward
|
||||
# Use attention interface pattern for vLLM compatibility
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
@ -134,6 +173,7 @@ class NanoChatAttention(LlamaAttention):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Reshape and project output
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
@ -155,16 +195,14 @@ class NanoChatMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NanoChatDecoderLayer(LlamaDecoderLayer):
|
||||
"""
|
||||
NanoChat decoder layer with pre-norm architecture.
|
||||
"""
|
||||
class NanoChatDecoderLayer(GradientCheckpointingLayer):
|
||||
"""NanoChat decoder layer with pre-norm architecture."""
|
||||
|
||||
def __init__(self, config: NanoChatConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.self_attn = NanoChatAttention(config, layer_idx)
|
||||
self.mlp = NanoChatMLP(config)
|
||||
# Replace Llama's norm layers with NanoChat's weight-less norm
|
||||
self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -178,15 +216,13 @@ class NanoChatDecoderLayer(LlamaDecoderLayer):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
@ -197,13 +233,14 @@ class NanoChatDecoderLayer(LlamaDecoderLayer):
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
return hidden_states, self_attn_weights
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class NanoChatPreTrainedModel(LlamaPreTrainedModel):
|
||||
config_class = NanoChatConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["NanoChatDecoderLayer"]
|
||||
_supports_attention_backend = True
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
@ -212,16 +249,15 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel):
|
||||
_supports_flex_attn = True
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
|
||||
_can_record_outputs = {
|
||||
"hidden_states": NanoChatDecoderLayer,
|
||||
"attentions": NanoChatAttention,
|
||||
}
|
||||
|
||||
|
||||
def _init_weights(self, module: nn.Module) -> None:
|
||||
super()._init_weights(module)
|
||||
|
||||
|
||||
# NanoChat-specific: scaled initialization for output projection
|
||||
for name, param in module.named_parameters():
|
||||
if name == "o_proj.weight":
|
||||
@ -233,19 +269,14 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel):
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class NanoChatModel(LlamaModel):
|
||||
"""
|
||||
NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers
|
||||
and RMSNorm without learnable weights.
|
||||
"""
|
||||
|
||||
class NanoChatModel(NanoChatPreTrainedModel):
|
||||
def __init__(self, config: NanoChatConfig):
|
||||
# Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__
|
||||
PreTrainedModel.__init__(self, config)
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.layers = nn.ModuleList(
|
||||
[NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
@ -255,6 +286,8 @@ class NanoChatModel(LlamaModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@check_model_inputs()
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@ -262,25 +295,25 @@ class NanoChatModel(LlamaModel):
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
from ...cache_utils import DynamicCache
|
||||
from ...masking_utils import create_causal_mask
|
||||
output_attentions = kwargs.get("output_attentions", False)
|
||||
output_hidden_states = kwargs.get("output_hidden_states", False)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache(config=self.config)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position: torch.Tensor = torch.arange(
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
@ -298,41 +331,65 @@ class NanoChatModel(LlamaModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
hidden_states = decoder_layer(
|
||||
# Collect hidden states and attentions if requested
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, self_attn_weights = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns = all_self_attns + (self_attn_weights,)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class NanoChatForCausalLM(LlamaForCausalLM):
|
||||
class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
|
||||
"""
|
||||
The NanoChat Model transformer with a language modeling head on top.
|
||||
Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping.
|
||||
"""
|
||||
|
||||
def __init__(self, config: NanoChatConfig):
|
||||
# Call PreTrainedModel.__init__ directly
|
||||
PreTrainedModel.__init__(self, config)
|
||||
self.model = NanoChatModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config: NanoChatConfig):
|
||||
super().__init__(config)
|
||||
self.model = NanoChatModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
||||
self.model.set_input_embeddings(new_embeddings)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@ -346,7 +403,31 @@ class NanoChatForCausalLM(LlamaForCausalLM):
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
r"""
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(model_id"karpathy/nanochat-d32")
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")
|
||||
|
||||
>>> conversation = [
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
]
|
||||
|
||||
>>> inputs = tokenizer.apply_chat_template(
|
||||
conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
>>> outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
|
||||
|
||||
>>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
|
||||
>>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
```"""
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@ -360,8 +441,6 @@ class NanoChatForCausalLM(LlamaForCausalLM):
|
||||
hidden_states = outputs.last_hidden_state
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
# NanoChat-specific: Apply logits soft capping if configured
|
||||
if self.config.logits_soft_cap is not None:
|
||||
cap = self.config.logits_soft_cap
|
||||
logits = cap * torch.tanh(logits / cap)
|
||||
@ -384,4 +463,3 @@ __all__ = [
|
||||
"NanoChatModel",
|
||||
"NanoChatForCausalLM",
|
||||
]
|
||||
|
||||
|
Reference in New Issue
Block a user