From 99801d04cd9cae2c9b614708f11444e290fb6ac1 Mon Sep 17 00:00:00 2001 From: Ben Burtenshaw Date: Sat, 18 Oct 2025 18:31:36 +0000 Subject: [PATCH] revert changes and use less llama modules --- .../models/nanochat/modeling_nanochat.py | 228 +++++++++------ .../models/nanochat/modular_nanochat.py | 270 +++++++++++------- 2 files changed, 312 insertions(+), 186 deletions(-) diff --git a/src/transformers/models/nanochat/modeling_nanochat.py b/src/transformers/models/nanochat/modeling_nanochat.py index 3609a2d9552..66ac9bc850e 100644 --- a/src/transformers/models/nanochat/modeling_nanochat.py +++ b/src/transformers/models/nanochat/modeling_nanochat.py @@ -20,11 +20,11 @@ # limitations under the License. import math +from collections.abc import Callable from typing import Optional, Union import torch import torch.nn as nn -import torch.nn.functional as F from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache @@ -33,15 +33,40 @@ from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import check_model_inputs from .configuration_nanochat import NanoChatConfig +class NanoChatRMSNorm(torch.nn.Module): + """ + NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization). + Overrides __init__ to match NanoChat's API with hidden_size parameter. + """ + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.eps = eps + self.hidden_size = hidden_size + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self._norm(x.float()).type_as(x) + + def extra_repr(self): + return f"eps={self.eps}" + + class NanoChatRotaryEmbedding(nn.Module): - """Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half.""" + """ + NanoChat's Rotary Position Embedding. + Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape + [batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim. + """ inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -65,38 +90,31 @@ class NanoChatRotaryEmbedding(nn.Module): @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): + """ + Returns cos and sin tensors for NanoChat's RoPE. + + Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim, + NanoChat keeps only head_dim//2 for memory efficiency. + """ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling + # NanoChat-specific: Don't duplicate freqs - keep as head_dim//2 + cos = freqs.cos() * self.attention_scaling + sin = freqs.sin() * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + # Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2] + return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype) -class NanoChatRMSNorm(torch.nn.Module): - """ - NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization). - Overrides __init__ to match NanoChat's API with hidden_size parameter. - """ - - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.eps = eps - self.hidden_size = hidden_size - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - return self._norm(x.float()).type_as(x) - - def extra_repr(self): - return f"eps={self.eps}" +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -164,38 +182,34 @@ def eager_attention_forward( return attn_output, attn_weights -def rotate_half(x): - """Rotates half the hidden dims of the input. - - NanoChat uses a different rotation convention than standard Llama. - Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation. - This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos] - """ - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((x2, -x1), dim=-1) - - class NanoChatAttention(nn.Module): """ - Multi-headed attention from NanoChat with custom QK normalization. - Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE. + Multi-headed attention from NanoChat with custom RoPE and QK normalization. + + Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64 """ def __init__(self, config: NanoChatConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout self.is_causal = True - # Override bias settings for NanoChat - self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias) - self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias) - self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias) - self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_kv_heads + + self.attention_dropout = config.attention_dropout + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias) + self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, @@ -206,28 +220,46 @@ class NanoChatAttention(nn.Module): cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + batch, seq_len, _ = hidden_states.shape input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + # Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim] + query_states = ( + self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + ) + key_states = ( + self.k_proj(hidden_states) + .view(batch, seq_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + value_states = ( + self.v_proj(hidden_states) + .view(batch, seq_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) - cos, sin = position_embeddings + # Apply Rotary Embeddings to queries and keys to get relative positional encoding + cos, sin = position_embeddings # [batch, seq_len, 1, head_dim//2] + cos = cos.squeeze(2) + sin = sin.squeeze(2) + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([-sin, -sin], dim=-1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # NanoChat-specific: Apply QK normalization after RoPE - query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps) - key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps) + # Apply QK normalization (RMSNorm) + query_states = self.query_norm(query_states) + key_states = self.key_norm(key_states) + # Apply KV cache: insert current k,v into cache, get the full view so far if past_key_values is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface = eager_attention_forward + # Use attention interface pattern for vLLM compatibility + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - from ...modeling_utils import ALL_ATTENTION_FUNCTIONS - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -241,6 +273,7 @@ class NanoChatAttention(nn.Module): **kwargs, ) + # Reshape and project output attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights @@ -263,16 +296,13 @@ class NanoChatMLP(nn.Module): class NanoChatDecoderLayer(GradientCheckpointingLayer): - """ - NanoChat decoder layer with pre-norm architecture. - """ + """NanoChat decoder layer with pre-norm architecture.""" def __init__(self, config: NanoChatConfig, layer_idx: int): super().__init__() - self.hidden_size = config.hidden_size + self.config = config self.self_attn = NanoChatAttention(config, layer_idx) self.mlp = NanoChatMLP(config) - # Replace Llama's norm layers with NanoChat's weight-less norm self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -286,15 +316,13 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer): cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, _ = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - position_ids=position_ids, past_key_values=past_key_values, - use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, @@ -305,7 +333,7 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer): hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return hidden_states + return hidden_states, self_attn_weights @auto_docstring @@ -343,18 +371,13 @@ class NanoChatPreTrainedModel(PreTrainedModel): @auto_docstring class NanoChatModel(NanoChatPreTrainedModel): - """ - NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers - and RMSNorm without learnable weights. - """ - def __init__(self, config: NanoChatConfig): - # Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__ super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layers = nn.ModuleList( [NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) @@ -364,6 +387,12 @@ class NanoChatModel(NanoChatPreTrainedModel): self.post_init() + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: + self.embed_tokens = new_embeddings + @check_model_inputs() @auto_docstring def forward( @@ -373,22 +402,25 @@ class NanoChatModel(NanoChatPreTrainedModel): position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: + output_attentions = kwargs.get("output_attentions", False) + output_hidden_states = kwargs.get("output_hidden_states", False) + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: torch.Tensor = torch.arange( + cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -406,22 +438,38 @@ class NanoChatModel(NanoChatPreTrainedModel): hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) + hidden_states = self.initial_norm(hidden_states) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = decoder_layer( + # Collect hidden states and attentions if requested + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states, self_attn_weights = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, + use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) + if output_attentions: + all_self_attns = all_self_attns + (self_attn_weights,) + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, ) @@ -429,7 +477,6 @@ class NanoChatModel(NanoChatPreTrainedModel): class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin): """ The NanoChat Model transformer with a language modeling head on top. - Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping. """ _tied_weights_keys = ["lm_head.weight"] @@ -437,14 +484,17 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin): _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config: NanoChatConfig): - # Call PreTrainedModel.__init__ directly super().__init__(config) self.model = NanoChatModel(config) - self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.post_init() + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: + self.model.set_input_embeddings(new_embeddings) + @can_return_tuple @auto_docstring def forward( @@ -477,7 +527,7 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin): >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - outputs: BaseModelOutputWithPast = self.model( + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -491,8 +541,6 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin): hidden_states = outputs.last_hidden_state slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) - - # NanoChat-specific: Apply logits soft capping if configured if self.config.logits_soft_cap is not None: cap = self.config.logits_soft_cap logits = cap * torch.tanh(logits / cap) diff --git a/src/transformers/models/nanochat/modular_nanochat.py b/src/transformers/models/nanochat/modular_nanochat.py index 62ff6d22304..a8dbce941eb 100644 --- a/src/transformers/models/nanochat/modular_nanochat.py +++ b/src/transformers/models/nanochat/modular_nanochat.py @@ -14,23 +14,24 @@ # limitations under the License. import math +from collections.abc import Callable from typing import Optional, Union import torch import torch.nn as nn -import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import check_model_inputs from ..llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, LlamaPreTrainedModel, LlamaRotaryEmbedding, apply_rotary_pos_emb, @@ -40,23 +41,6 @@ from ..llama4.modeling_llama4 import Llama4TextL2Norm from .configuration_nanochat import NanoChatConfig -def rotate_half(x): - """Rotates half the hidden dims of the input. - - NanoChat uses a different rotation convention than standard Llama. - Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation. - This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos] - """ - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((x2, -x1), dim=-1) - - -class NanoChatRotaryEmbedding(LlamaRotaryEmbedding): - """Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half.""" - pass - - class NanoChatRMSNorm(Llama4TextL2Norm): """ NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization). @@ -68,28 +52,64 @@ class NanoChatRMSNorm(Llama4TextL2Norm): self.hidden_size = hidden_size -class NanoChatAttention(LlamaAttention): +class NanoChatRotaryEmbedding(LlamaRotaryEmbedding): """ - Multi-headed attention from NanoChat with custom QK normalization. - Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE. + NanoChat's Rotary Position Embedding. + Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape + [batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim. + """ + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + """ + Returns cos and sin tensors for NanoChat's RoPE. + + Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim, + NanoChat keeps only head_dim//2 for memory efficiency. + """ + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + # NanoChat-specific: Don't duplicate freqs - keep as head_dim//2 + cos = freqs.cos() * self.attention_scaling + sin = freqs.sin() * self.attention_scaling + + # Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2] + return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype) + + +class NanoChatAttention(nn.Module): + """ + Multi-headed attention from NanoChat with custom RoPE and QK normalization. + + Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64 """ def __init__(self, config: NanoChatConfig, layer_idx: int): - super().__init__(config, layer_idx) + super().__init__() + self.config = config + self.layer_idx = layer_idx self.is_causal = True - # Override bias settings for NanoChat - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias - ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_kv_heads + + self.attention_dropout = config.attention_dropout + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias) + self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, @@ -100,27 +120,46 @@ class NanoChatAttention(LlamaAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + batch, seq_len, _ = hidden_states.shape input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + # Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim] + query_states = ( + self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + ) + key_states = ( + self.k_proj(hidden_states) + .view(batch, seq_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + value_states = ( + self.v_proj(hidden_states) + .view(batch, seq_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) - cos, sin = position_embeddings + # Apply Rotary Embeddings to queries and keys to get relative positional encoding + cos, sin = position_embeddings # [batch, seq_len, 1, head_dim//2] + cos = cos.squeeze(2) + sin = sin.squeeze(2) + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([-sin, -sin], dim=-1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # NanoChat-specific: Apply QK normalization after RoPE - query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps) - key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps) + # Apply QK normalization (RMSNorm) + query_states = self.query_norm(query_states) + key_states = self.key_norm(key_states) + # Apply KV cache: insert current k,v into cache, get the full view so far if past_key_values is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface = eager_attention_forward + # Use attention interface pattern for vLLM compatibility + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - from ...modeling_utils import ALL_ATTENTION_FUNCTIONS attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -134,6 +173,7 @@ class NanoChatAttention(LlamaAttention): **kwargs, ) + # Reshape and project output attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights @@ -155,16 +195,14 @@ class NanoChatMLP(nn.Module): return hidden_states -class NanoChatDecoderLayer(LlamaDecoderLayer): - """ - NanoChat decoder layer with pre-norm architecture. - """ +class NanoChatDecoderLayer(GradientCheckpointingLayer): + """NanoChat decoder layer with pre-norm architecture.""" def __init__(self, config: NanoChatConfig, layer_idx: int): - super().__init__(config, layer_idx) + super().__init__() + self.config = config self.self_attn = NanoChatAttention(config, layer_idx) self.mlp = NanoChatMLP(config) - # Replace Llama's norm layers with NanoChat's weight-less norm self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -178,15 +216,13 @@ class NanoChatDecoderLayer(LlamaDecoderLayer): cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, _ = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - position_ids=position_ids, past_key_values=past_key_values, - use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, @@ -197,13 +233,14 @@ class NanoChatDecoderLayer(LlamaDecoderLayer): hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return hidden_states + return hidden_states, self_attn_weights @auto_docstring class NanoChatPreTrainedModel(LlamaPreTrainedModel): config_class = NanoChatConfig base_model_prefix = "model" + supports_gradient_checkpointing = True _no_split_modules = ["NanoChatDecoderLayer"] _supports_attention_backend = True _skip_keys_device_placement = ["past_key_values"] @@ -212,16 +249,15 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel): _supports_flex_attn = True _can_compile_fullgraph = True - _supports_attention_backend = True - + _can_record_outputs = { "hidden_states": NanoChatDecoderLayer, "attentions": NanoChatAttention, } - + def _init_weights(self, module: nn.Module) -> None: super()._init_weights(module) - + # NanoChat-specific: scaled initialization for output projection for name, param in module.named_parameters(): if name == "o_proj.weight": @@ -233,19 +269,14 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel): @auto_docstring -class NanoChatModel(LlamaModel): - """ - NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers - and RMSNorm without learnable weights. - """ - +class NanoChatModel(NanoChatPreTrainedModel): def __init__(self, config: NanoChatConfig): - # Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__ - PreTrainedModel.__init__(self, config) + super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layers = nn.ModuleList( [NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) @@ -255,6 +286,8 @@ class NanoChatModel(LlamaModel): self.post_init() + @check_model_inputs() + @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -262,25 +295,25 @@ class NanoChatModel(LlamaModel): position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: - from ...cache_utils import DynamicCache - from ...masking_utils import create_causal_mask + output_attentions = kwargs.get("output_attentions", False) + output_hidden_states = kwargs.get("output_hidden_states", False) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: torch.Tensor = torch.arange( + cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -298,41 +331,65 @@ class NanoChatModel(LlamaModel): hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) + hidden_states = self.initial_norm(hidden_states) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = decoder_layer( + # Collect hidden states and attentions if requested + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states, self_attn_weights = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, + use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) + if output_attentions: + all_self_attns = all_self_attns + (self_attn_weights,) + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, ) @auto_docstring -class NanoChatForCausalLM(LlamaForCausalLM): +class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin): """ The NanoChat Model transformer with a language modeling head on top. - Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping. """ - def __init__(self, config: NanoChatConfig): - # Call PreTrainedModel.__init__ directly - PreTrainedModel.__init__(self, config) - self.model = NanoChatModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + def __init__(self, config: NanoChatConfig): + super().__init__(config) + self.model = NanoChatModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: + self.model.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -346,7 +403,31 @@ class NanoChatForCausalLM(LlamaForCausalLM): logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: - outputs: BaseModelOutputWithPast = self.model( + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained(model_id"karpathy/nanochat-d32") + + >>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32") + + >>> conversation = [ + {"role": "user", "content": "What is the capital of France?"}, + ] + + >>> inputs = tokenizer.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(device) + + >>> with torch.no_grad(): + >>> outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False) + + >>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :] + >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True) + ```""" + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -360,8 +441,6 @@ class NanoChatForCausalLM(LlamaForCausalLM): hidden_states = outputs.last_hidden_state slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) - - # NanoChat-specific: Apply logits soft capping if configured if self.config.logits_soft_cap is not None: cap = self.config.logits_soft_cap logits = cap * torch.tanh(logits / cap) @@ -384,4 +463,3 @@ __all__ = [ "NanoChatModel", "NanoChatForCausalLM", ] -