revert changes and use less llama modules

commit 99801d04cd (parent a572394fec)
Author: Ben Burtenshaw
Date: 2025-10-18 18:31:36 +00:00
2 changed files with 312 additions and 186 deletions

File 1 of 2 (View File)
@@ -20,11 +20,11 @@
 # limitations under the License.
 import math
+from collections.abc import Callable
 from typing import Optional, Union
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
@@ -33,15 +33,40 @@ from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.generic import check_model_inputs
 from .configuration_nanochat import NanoChatConfig
+class NanoChatRMSNorm(torch.nn.Module):
+    """
+    NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
+    Overrides __init__ to match NanoChat's API with hidden_size parameter.
+    """
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+        self.hidden_size = hidden_size
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+    def forward(self, x):
+        return self._norm(x.float()).type_as(x)
+    def extra_repr(self):
+        return f"eps={self.eps}"
 class NanoChatRotaryEmbedding(nn.Module):
-    """Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half."""
+    """
+    NanoChat's Rotary Position Embedding.
+    Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape
+    [batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim.
+    """
     inv_freq: torch.Tensor  # fix linting for `register_buffer`
@@ -65,38 +90,31 @@ class NanoChatRotaryEmbedding(nn.Module):
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
+        """
+        Returns cos and sin tensors for NanoChat's RoPE.
+        Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim,
+        NanoChat keeps only head_dim//2 for memory efficiency.
+        """
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-class NanoChatRMSNorm(torch.nn.Module):
-    """
-    NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
-    Overrides __init__ to match NanoChat's API with hidden_size parameter.
-    """
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.eps = eps
-        self.hidden_size = hidden_size
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-    def forward(self, x):
-        return self._norm(x.float()).type_as(x)
-    def extra_repr(self):
-        return f"eps={self.eps}"
+            # NanoChat-specific: Don't duplicate freqs - keep as head_dim//2
+            cos = freqs.cos() * self.attention_scaling
+            sin = freqs.sin() * self.attention_scaling
+        # Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2]
+        return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype)
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
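A minimal standalone sketch of the weight-less RMS normalization this diff standardizes on (the `NanoChatRMSNorm` module above versus the inline `F.rms_norm` call it replaces in attention). It assumes PyTorch 2.4 or newer, where `torch.nn.functional.rms_norm` exists; shapes and the eps value are illustrative.

```python
# Illustrative check, not part of the commit: NanoChatRMSNorm has no learnable weight,
# so it should match torch.nn.functional.rms_norm called without a weight tensor.
import torch
import torch.nn.functional as F

def weightless_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same computation as NanoChatRMSNorm._norm / forward above
    return (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).type_as(x)

x = torch.randn(2, 4, 16)
assert torch.allclose(weightless_rms_norm(x), F.rms_norm(x, (x.size(-1),), eps=1e-6), atol=1e-6)
```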
@@ -164,38 +182,34 @@ def eager_attention_forward(
     return attn_output, attn_weights
-def rotate_half(x):
-    """Rotates half the hidden dims of the input.
-    NanoChat uses a different rotation convention than standard Llama.
-    Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation.
-    This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos]
-    """
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((x2, -x1), dim=-1)
 class NanoChatAttention(nn.Module):
     """
-    Multi-headed attention from NanoChat with custom QK normalization.
-    Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE.
+    Multi-headed attention from NanoChat with custom RoPE and QK normalization.
+    Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64
     """
     def __init__(self, config: NanoChatConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-        self.scaling = self.head_dim**-0.5
-        self.attention_dropout = config.attention_dropout
         self.is_causal = True
-        # Override bias settings for NanoChat
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
-        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias)
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_groups = self.num_heads // self.num_kv_heads
+        self.attention_dropout = config.attention_dropout
+        self.scaling = self.head_dim**-0.5
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
+        self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
     def forward(
         self,
@@ -206,28 +220,46 @@ class NanoChatAttention(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        batch, seq_len, _ = hidden_states.shape
         input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        cos, sin = position_embeddings
+        # Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim]
+        query_states = (
+            self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+        )
+        key_states = (
+            self.k_proj(hidden_states)
+            .view(batch, seq_len, self.num_kv_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+        value_states = (
+            self.v_proj(hidden_states)
+            .view(batch, seq_len, self.num_kv_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
+        cos, sin = position_embeddings  # [batch, seq_len, 1, head_dim//2]
+        cos = cos.squeeze(2)
+        sin = sin.squeeze(2)
+        cos = torch.cat([cos, cos], dim=-1)
+        sin = torch.cat([-sin, -sin], dim=-1)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        # NanoChat-specific: Apply QK normalization after RoPE
-        query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps)
-        key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps)
+        # Apply QK normalization (RMSNorm)
+        query_states = self.query_norm(query_states)
+        key_states = self.key_norm(key_states)
+        # Apply KV cache: insert current k,v into cache, get the full view so far
         if past_key_values is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
-        attention_interface = eager_attention_forward
+        # Use attention interface pattern for vLLM compatibility
+        attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
         attn_output, attn_weights = attention_interface(
@@ -241,6 +273,7 @@ class NanoChatAttention(nn.Module):
             **kwargs,
         )
+        # Reshape and project output
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
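A standalone check of the sign trick used in the new forward: duplicating the half-width cos, negating the duplicated sin, and applying the standard Llama rotation reproduces NanoChat's original [x2, -x1] convention described in the removed docstring (the q1*cos + q2*sin, q2*cos - q1*sin form). Shapes and values below are arbitrary, and `apply_rotary_pos_emb` is inlined as q * cos + rotate_half(q) * sin.

```python
# Illustrative equivalence check, not part of the commit.
import torch

def rotate_half_llama(x):      # standard convention kept by the new code
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rotate_half_nanochat(x):   # convention removed by this commit
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((x2, -x1), dim=-1)

q = torch.randn(1, 2, 5, 8)    # [batch, heads, seq, head_dim]
theta = torch.randn(5, 4)      # per-position angles, head_dim // 2 wide
cos, sin = theta.cos(), theta.sin()

old = q * torch.cat([cos, cos], -1) + rotate_half_nanochat(q) * torch.cat([sin, sin], -1)
new = q * torch.cat([cos, cos], -1) + rotate_half_llama(q) * torch.cat([-sin, -sin], -1)
torch.testing.assert_close(old, new)
```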
@@ -263,16 +296,13 @@ class NanoChatMLP(nn.Module):
 class NanoChatDecoderLayer(GradientCheckpointingLayer):
-    """
-    NanoChat decoder layer with pre-norm architecture.
-    """
+    """NanoChat decoder layer with pre-norm architecture."""
     def __init__(self, config: NanoChatConfig, layer_idx: int):
         super().__init__()
-        self.hidden_size = config.hidden_size
+        self.config = config
         self.self_attn = NanoChatAttention(config, layer_idx)
         self.mlp = NanoChatMLP(config)
-        # Replace Llama's norm layers with NanoChat's weight-less norm
         self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -286,15 +316,13 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer):
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs: Unpack[TransformersKwargs],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states, _ = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_values=past_key_values,
-            use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
@@ -305,7 +333,7 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, self_attn_weights
 @auto_docstring
@@ -343,18 +371,13 @@ class NanoChatPreTrainedModel(PreTrainedModel):
 @auto_docstring
 class NanoChatModel(NanoChatPreTrainedModel):
-    """
-    NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers
-    and RMSNorm without learnable weights.
-    """
     def __init__(self, config: NanoChatConfig):
-        # Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.layers = nn.ModuleList(
             [NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
@@ -364,6 +387,12 @@ class NanoChatModel(NanoChatPreTrainedModel):
         self.post_init()
+    def get_input_embeddings(self):
+        return self.embed_tokens
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+        self.embed_tokens = new_embeddings
     @check_model_inputs()
     @auto_docstring
     def forward(
@@ -373,22 +402,25 @@ class NanoChatModel(NanoChatPreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
+        output_attentions = kwargs.get("output_attentions", False)
+        output_hidden_states = kwargs.get("output_hidden_states", False)
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
         if inputs_embeds is None:
-            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position: torch.Tensor = torch.arange(
+            cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
@@ -406,22 +438,38 @@ class NanoChatModel(NanoChatPreTrainedModel):
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            hidden_states = decoder_layer(
+        hidden_states = self.initial_norm(hidden_states)
+        # Collect hidden states and attentions if requested
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+            hidden_states, self_attn_weights = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
+            if output_attentions:
+                all_self_attns = all_self_attns + (self_attn_weights,)
         hidden_states = self.norm(hidden_states)
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values if use_cache else None,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
         )
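For reference, the collection logic above yields one hidden-state entry per layer input plus one for the final-norm output, i.e. num_hidden_layers + 1 entries; a stand-in sketch (the integer "hidden state" is purely illustrative):

```python
# Illustrative only: size of the all_hidden_states tuple built in the loop above.
num_hidden_layers = 4
hidden_states = 0
all_hidden_states = ()
for _ in range(num_hidden_layers):
    all_hidden_states += (hidden_states,)  # state entering each decoder layer
    hidden_states += 1                     # stand-in for the layer transform
all_hidden_states += (hidden_states,)      # state after the final norm
assert len(all_hidden_states) == num_hidden_layers + 1
```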
@@ -429,7 +477,6 @@
 class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
     """
     The NanoChat Model transformer with a language modeling head on top.
-    Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping.
     """
     _tied_weights_keys = ["lm_head.weight"]
@@ -437,14 +484,17 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
     def __init__(self, config: NanoChatConfig):
-        # Call PreTrainedModel.__init__ directly
         super().__init__(config)
         self.model = NanoChatModel(config)
-        self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+        self.model.set_input_embeddings(new_embeddings)
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -477,7 +527,7 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        outputs: BaseModelOutputWithPast = self.model(
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -491,8 +541,6 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
         hidden_states = outputs.last_hidden_state
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-        # NanoChat-specific: Apply logits soft capping if configured
         if self.config.logits_soft_cap is not None:
             cap = self.config.logits_soft_cap
             logits = cap * torch.tanh(logits / cap)
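Both sides of the diff keep logits soft capping; as a quick illustration, cap * tanh(logits / cap) bounds logits to (-cap, cap) while staying close to the identity near zero. The cap value here is made up; the real one comes from config.logits_soft_cap.

```python
# Illustrative only, not part of the commit.
import torch

cap = 15.0  # hypothetical value standing in for config.logits_soft_cap
logits = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(cap * torch.tanh(logits / cap))  # ~[-15.0, -0.999, 0.0, 0.999, 15.0]
```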

File 2 of 2 (View File)
@@ -14,23 +14,24 @@
 # limitations under the License.
 import math
+from collections.abc import Callable
 from typing import Optional, Union
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from ...activations import ACT2FN
-from ...cache_utils import Cache
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_utils import PreTrainedModel
+from ...modeling_rope_utils import dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
 from ..llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaForCausalLM,
-    LlamaModel,
     LlamaPreTrainedModel,
     LlamaRotaryEmbedding,
     apply_rotary_pos_emb,
@@ -40,23 +41,6 @@ from ..llama4.modeling_llama4 import Llama4TextL2Norm
 from .configuration_nanochat import NanoChatConfig
-def rotate_half(x):
-    """Rotates half the hidden dims of the input.
-    NanoChat uses a different rotation convention than standard Llama.
-    Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation.
-    This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos]
-    """
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((x2, -x1), dim=-1)
-class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
-    """Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half."""
-    pass
 class NanoChatRMSNorm(Llama4TextL2Norm):
     """
     NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
@@ -68,28 +52,64 @@ class NanoChatRMSNorm(Llama4TextL2Norm):
         self.hidden_size = hidden_size
-class NanoChatAttention(LlamaAttention):
+class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
     """
-    Multi-headed attention from NanoChat with custom QK normalization.
-    Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE.
+    NanoChat's Rotary Position Embedding.
+    Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape
+    [batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim.
+    """
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        """
+        Returns cos and sin tensors for NanoChat's RoPE.
+        Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim,
+        NanoChat keeps only head_dim//2 for memory efficiency.
+        """
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            # NanoChat-specific: Don't duplicate freqs - keep as head_dim//2
+            cos = freqs.cos() * self.attention_scaling
+            sin = freqs.sin() * self.attention_scaling
+        # Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2]
+        return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype)
+class NanoChatAttention(nn.Module):
+    """
+    Multi-headed attention from NanoChat with custom RoPE and QK normalization.
+    Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64
     """
     def __init__(self, config: NanoChatConfig, layer_idx: int):
-        super().__init__(config, layer_idx)
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
         self.is_causal = True
-        # Override bias settings for NanoChat
-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias
-        )
-        self.o_proj = nn.Linear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias
-        )
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_groups = self.num_heads // self.num_kv_heads
+        self.attention_dropout = config.attention_dropout
+        self.scaling = self.head_dim**-0.5
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
+        self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
     def forward(
         self,
@@ -100,27 +120,46 @@ class NanoChatAttention(LlamaAttention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        batch, seq_len, _ = hidden_states.shape
         input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        cos, sin = position_embeddings
+        # Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim]
+        query_states = (
+            self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+        )
+        key_states = (
+            self.k_proj(hidden_states)
+            .view(batch, seq_len, self.num_kv_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+        value_states = (
+            self.v_proj(hidden_states)
+            .view(batch, seq_len, self.num_kv_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
+        cos, sin = position_embeddings  # [batch, seq_len, 1, head_dim//2]
+        cos = cos.squeeze(2)
+        sin = sin.squeeze(2)
+        cos = torch.cat([cos, cos], dim=-1)
+        sin = torch.cat([-sin, -sin], dim=-1)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        # NanoChat-specific: Apply QK normalization after RoPE
-        query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps)
-        key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps)
+        # Apply QK normalization (RMSNorm)
+        query_states = self.query_norm(query_states)
+        key_states = self.key_norm(key_states)
+        # Apply KV cache: insert current k,v into cache, get the full view so far
         if past_key_values is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
-        attention_interface = eager_attention_forward
+        # Use attention interface pattern for vLLM compatibility
+        attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
         attn_output, attn_weights = attention_interface(
@@ -134,6 +173,7 @@ class NanoChatAttention(LlamaAttention):
             **kwargs,
         )
+        # Reshape and project output
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
@@ -155,16 +195,14 @@ class NanoChatMLP(nn.Module):
         return hidden_states
-class NanoChatDecoderLayer(LlamaDecoderLayer):
-    """
-    NanoChat decoder layer with pre-norm architecture.
-    """
+class NanoChatDecoderLayer(GradientCheckpointingLayer):
+    """NanoChat decoder layer with pre-norm architecture."""
     def __init__(self, config: NanoChatConfig, layer_idx: int):
-        super().__init__(config, layer_idx)
+        super().__init__()
+        self.config = config
         self.self_attn = NanoChatAttention(config, layer_idx)
         self.mlp = NanoChatMLP(config)
-        # Replace Llama's norm layers with NanoChat's weight-less norm
         self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
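The layer keeps the same pre-norm residual flow on both sides of this hunk (normalize, transform, add back); a compact stand-in sketch:

```python
# Illustrative only: the pre-norm residual pattern NanoChatDecoderLayer.forward implements.
import torch
import torch.nn as nn

def prenorm_block(x, input_norm, attn, post_norm, mlp):
    x = x + attn(input_norm(x))   # attention sub-block with residual
    x = x + mlp(post_norm(x))     # MLP sub-block with residual
    return x

x = torch.randn(2, 3, 8)
out = prenorm_block(x, nn.Identity(), nn.Identity(), nn.Identity(), nn.Identity())
assert out.shape == x.shape
```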
@@ -178,15 +216,13 @@ class NanoChatDecoderLayer(LlamaDecoderLayer):
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs: Unpack[TransformersKwargs],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states, _ = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_values=past_key_values,
-            use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
@@ -197,13 +233,14 @@ class NanoChatDecoderLayer(LlamaDecoderLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, self_attn_weights
 @auto_docstring
 class NanoChatPreTrainedModel(LlamaPreTrainedModel):
     config_class = NanoChatConfig
     base_model_prefix = "model"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["NanoChatDecoderLayer"]
     _supports_attention_backend = True
     _skip_keys_device_placement = ["past_key_values"]
@@ -212,7 +249,6 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel):
     _supports_flex_attn = True
     _can_compile_fullgraph = True
-    _supports_attention_backend = True
     _can_record_outputs = {
         "hidden_states": NanoChatDecoderLayer,
@@ -233,19 +269,14 @@
 @auto_docstring
-class NanoChatModel(LlamaModel):
-    """
-    NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers
-    and RMSNorm without learnable weights.
-    """
+class NanoChatModel(NanoChatPreTrainedModel):
     def __init__(self, config: NanoChatConfig):
-        # Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__
-        PreTrainedModel.__init__(self, config)
+        super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.layers = nn.ModuleList(
             [NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
@@ -255,6 +286,8 @@ class NanoChatModel(LlamaModel):
         self.post_init()
+    @check_model_inputs()
+    @auto_docstring
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -262,25 +295,25 @@ class NanoChatModel(LlamaModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        from ...cache_utils import DynamicCache
-        from ...masking_utils import create_causal_mask
+        output_attentions = kwargs.get("output_attentions", False)
+        output_hidden_states = kwargs.get("output_hidden_states", False)
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
         if inputs_embeds is None:
-            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position: torch.Tensor = torch.arange(
+            cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
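A small worked example of the cache_position arithmetic above, for a hypothetical incremental step after a 5-token prompt:

```python
# Illustrative only: deriving cache_position as in the forward above.
import torch

past_seen_tokens = 5   # e.g. past_key_values.get_seq_length() after the prompt
new_tokens = 3         # inputs_embeds.shape[1] for the current call
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
print(cache_position)  # tensor([5, 6, 7]) -> absolute positions of the new tokens
```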
@@ -298,41 +331,65 @@ class NanoChatModel(LlamaModel):
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            hidden_states = decoder_layer(
+        hidden_states = self.initial_norm(hidden_states)
+        # Collect hidden states and attentions if requested
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+            hidden_states, self_attn_weights = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
-                use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
+            if output_attentions:
+                all_self_attns = all_self_attns + (self_attn_weights,)
         hidden_states = self.norm(hidden_states)
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values if use_cache else None,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
         )
 @auto_docstring
-class NanoChatForCausalLM(LlamaForCausalLM):
+class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
     """
     The NanoChat Model transformer with a language modeling head on top.
-    Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping.
     """
-    def __init__(self, config: NanoChatConfig):
-        # Call PreTrainedModel.__init__ directly
-        PreTrainedModel.__init__(self, config)
-        self.model = NanoChatModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+    def __init__(self, config: NanoChatConfig):
+        super().__init__(config)
+        self.model = NanoChatModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+        self.model.set_input_embeddings(new_embeddings)
+    @can_return_tuple
+    @auto_docstring
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -346,7 +403,31 @@ class NanoChatForCausalLM(LlamaForCausalLM):
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
-        outputs: BaseModelOutputWithPast = self.model(
+        r"""
+        Example:
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+        >>> model = AutoModelForCausalLM.from_pretrained("karpathy/nanochat-d32")
+        >>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")
+        >>> conversation = [
+        ...     {"role": "user", "content": "What is the capital of France?"},
+        ... ]
+        >>> inputs = tokenizer.apply_chat_template(
+        ...     conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+        ... ).to(model.device)
+        >>> with torch.no_grad():
+        ...     outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
+        >>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
+        >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        ```"""
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -360,8 +441,6 @@ class NanoChatForCausalLM(LlamaForCausalLM):
         hidden_states = outputs.last_hidden_state
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-        # NanoChat-specific: Apply logits soft capping if configured
         if self.config.logits_soft_cap is not None:
             cap = self.config.logits_soft_cap
             logits = cap * torch.tanh(logits / cap)
@@ -384,4 +463,3 @@ __all__ = [
     "NanoChatModel",
     "NanoChatForCausalLM",
 ]