revert changes and use less llama modules

This commit is contained in:
Ben Burtenshaw
2025-10-18 18:31:36 +00:00
parent a572394fec
commit 99801d04cd
2 changed files with 312 additions and 186 deletions

View File

@ -20,11 +20,11 @@
# limitations under the License.
import math
from collections.abc import Callable
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
@ -33,15 +33,40 @@ from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from .configuration_nanochat import NanoChatConfig
class NanoChatRMSNorm(torch.nn.Module):
"""
NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
Overrides __init__ to match NanoChat's API with hidden_size parameter.
"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.eps = eps
self.hidden_size = hidden_size
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x)
def extra_repr(self):
return f"eps={self.eps}"
class NanoChatRotaryEmbedding(nn.Module):
"""Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half."""
"""
NanoChat's Rotary Position Embedding.
Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape
[batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim.
"""
inv_freq: torch.Tensor # fix linting for `register_buffer`
@ -65,38 +90,31 @@ class NanoChatRotaryEmbedding(nn.Module):
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
"""
Returns cos and sin tensors for NanoChat's RoPE.
Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim,
NanoChat keeps only head_dim//2 for memory efficiency.
"""
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
# NanoChat-specific: Don't duplicate freqs - keep as head_dim//2
cos = freqs.cos() * self.attention_scaling
sin = freqs.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2]
return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype)
class NanoChatRMSNorm(torch.nn.Module):
"""
NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
Overrides __init__ to match NanoChat's API with hidden_size parameter.
"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.eps = eps
self.hidden_size = hidden_size
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x)
def extra_repr(self):
return f"eps={self.eps}"
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@ -164,38 +182,34 @@ def eager_attention_forward(
return attn_output, attn_weights
def rotate_half(x):
"""Rotates half the hidden dims of the input.
NanoChat uses a different rotation convention than standard Llama.
Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation.
This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos]
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((x2, -x1), dim=-1)
class NanoChatAttention(nn.Module):
"""
Multi-headed attention from NanoChat with custom QK normalization.
Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE.
Multi-headed attention from NanoChat with custom RoPE and QK normalization.
Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64
"""
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
# Override bias settings for NanoChat
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias)
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_kv_heads
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
def forward(
self,
@ -206,28 +220,46 @@ class NanoChatAttention(nn.Module):
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, seq_len, _ = hidden_states.shape
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
# Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim]
query_states = (
self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
)
key_states = (
self.k_proj(hidden_states)
.view(batch, seq_len, self.num_kv_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
value_states = (
self.v_proj(hidden_states)
.view(batch, seq_len, self.num_kv_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
cos, sin = position_embeddings
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = position_embeddings # [batch, seq_len, 1, head_dim//2]
cos = cos.squeeze(2)
sin = sin.squeeze(2)
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([-sin, -sin], dim=-1)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# NanoChat-specific: Apply QK normalization after RoPE
query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps)
key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps)
# Apply QK normalization (RMSNorm)
query_states = self.query_norm(query_states)
key_states = self.key_norm(key_states)
# Apply KV cache: insert current k,v into cache, get the full view so far
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface = eager_attention_forward
# Use attention interface pattern for vLLM compatibility
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
@ -241,6 +273,7 @@ class NanoChatAttention(nn.Module):
**kwargs,
)
# Reshape and project output
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@ -263,16 +296,13 @@ class NanoChatMLP(nn.Module):
class NanoChatDecoderLayer(GradientCheckpointingLayer):
"""
NanoChat decoder layer with pre-norm architecture.
"""
"""NanoChat decoder layer with pre-norm architecture."""
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
self.self_attn = NanoChatAttention(config, layer_idx)
self.mlp = NanoChatMLP(config)
# Replace Llama's norm layers with NanoChat's weight-less norm
self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -286,15 +316,13 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer):
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _ = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
@ -305,7 +333,7 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
return hidden_states, self_attn_weights
@auto_docstring
@ -343,18 +371,13 @@ class NanoChatPreTrainedModel(PreTrainedModel):
@auto_docstring
class NanoChatModel(NanoChatPreTrainedModel):
"""
NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers
and RMSNorm without learnable weights.
"""
def __init__(self, config: NanoChatConfig):
# Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layers = nn.ModuleList(
[NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@ -364,6 +387,12 @@ class NanoChatModel(NanoChatPreTrainedModel):
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
self.embed_tokens = new_embeddings
@check_model_inputs()
@auto_docstring
def forward(
@ -373,22 +402,25 @@ class NanoChatModel(NanoChatPreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
output_attentions = kwargs.get("output_attentions", False)
output_hidden_states = kwargs.get("output_hidden_states", False)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@ -406,22 +438,38 @@ class NanoChatModel(NanoChatPreTrainedModel):
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
hidden_states = self.initial_norm(hidden_states)
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
# Collect hidden states and attentions if requested
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states, self_attn_weights = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
if output_attentions:
all_self_attns = all_self_attns + (self_attn_weights,)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@ -429,7 +477,6 @@ class NanoChatModel(NanoChatPreTrainedModel):
class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
"""
The NanoChat Model transformer with a language modeling head on top.
Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping.
"""
_tied_weights_keys = ["lm_head.weight"]
@ -437,14 +484,17 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config: NanoChatConfig):
# Call PreTrainedModel.__init__ directly
super().__init__(config)
self.model = NanoChatModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
self.model.set_input_embeddings(new_embeddings)
@can_return_tuple
@auto_docstring
def forward(
@ -477,7 +527,7 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
outputs: BaseModelOutputWithPast = self.model(
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@ -491,8 +541,6 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
# NanoChat-specific: Apply logits soft capping if configured
if self.config.logits_soft_cap is not None:
cap = self.config.logits_soft_cap
logits = cap * torch.tanh(logits / cap)

View File

@ -14,23 +14,24 @@
# limitations under the License.
import math
from collections.abc import Callable
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...modeling_rope_utils import dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from ..llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
@ -40,23 +41,6 @@ from ..llama4.modeling_llama4 import Llama4TextL2Norm
from .configuration_nanochat import NanoChatConfig
def rotate_half(x):
"""Rotates half the hidden dims of the input.
NanoChat uses a different rotation convention than standard Llama.
Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation.
This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos]
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((x2, -x1), dim=-1)
class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
"""Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half."""
pass
class NanoChatRMSNorm(Llama4TextL2Norm):
"""
NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
@ -68,28 +52,64 @@ class NanoChatRMSNorm(Llama4TextL2Norm):
self.hidden_size = hidden_size
class NanoChatAttention(LlamaAttention):
class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
"""
Multi-headed attention from NanoChat with custom QK normalization.
Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE.
NanoChat's Rotary Position Embedding.
Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape
[batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim.
"""
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
"""
Returns cos and sin tensors for NanoChat's RoPE.
Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim,
NanoChat keeps only head_dim//2 for memory efficiency.
"""
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
# NanoChat-specific: Don't duplicate freqs - keep as head_dim//2
cos = freqs.cos() * self.attention_scaling
sin = freqs.sin() * self.attention_scaling
# Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2]
return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype)
class NanoChatAttention(nn.Module):
"""
Multi-headed attention from NanoChat with custom RoPE and QK normalization.
Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64
"""
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__(config, layer_idx)
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.is_causal = True
# Override bias settings for NanoChat
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_kv_heads
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
def forward(
self,
@ -100,27 +120,46 @@ class NanoChatAttention(LlamaAttention):
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, seq_len, _ = hidden_states.shape
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
# Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim]
query_states = (
self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
)
key_states = (
self.k_proj(hidden_states)
.view(batch, seq_len, self.num_kv_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
value_states = (
self.v_proj(hidden_states)
.view(batch, seq_len, self.num_kv_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
cos, sin = position_embeddings
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = position_embeddings # [batch, seq_len, 1, head_dim//2]
cos = cos.squeeze(2)
sin = sin.squeeze(2)
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([-sin, -sin], dim=-1)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# NanoChat-specific: Apply QK normalization after RoPE
query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps)
key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps)
# Apply QK normalization (RMSNorm)
query_states = self.query_norm(query_states)
key_states = self.key_norm(key_states)
# Apply KV cache: insert current k,v into cache, get the full view so far
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface = eager_attention_forward
# Use attention interface pattern for vLLM compatibility
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
@ -134,6 +173,7 @@ class NanoChatAttention(LlamaAttention):
**kwargs,
)
# Reshape and project output
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@ -155,16 +195,14 @@ class NanoChatMLP(nn.Module):
return hidden_states
class NanoChatDecoderLayer(LlamaDecoderLayer):
"""
NanoChat decoder layer with pre-norm architecture.
"""
class NanoChatDecoderLayer(GradientCheckpointingLayer):
"""NanoChat decoder layer with pre-norm architecture."""
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__(config, layer_idx)
super().__init__()
self.config = config
self.self_attn = NanoChatAttention(config, layer_idx)
self.mlp = NanoChatMLP(config)
# Replace Llama's norm layers with NanoChat's weight-less norm
self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -178,15 +216,13 @@ class NanoChatDecoderLayer(LlamaDecoderLayer):
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _ = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
@ -197,13 +233,14 @@ class NanoChatDecoderLayer(LlamaDecoderLayer):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
return hidden_states, self_attn_weights
@auto_docstring
class NanoChatPreTrainedModel(LlamaPreTrainedModel):
config_class = NanoChatConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["NanoChatDecoderLayer"]
_supports_attention_backend = True
_skip_keys_device_placement = ["past_key_values"]
@ -212,16 +249,15 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel):
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": NanoChatDecoderLayer,
"attentions": NanoChatAttention,
}
def _init_weights(self, module: nn.Module) -> None:
super()._init_weights(module)
# NanoChat-specific: scaled initialization for output projection
for name, param in module.named_parameters():
if name == "o_proj.weight":
@ -233,19 +269,14 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel):
@auto_docstring
class NanoChatModel(LlamaModel):
"""
NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers
and RMSNorm without learnable weights.
"""
class NanoChatModel(NanoChatPreTrainedModel):
def __init__(self, config: NanoChatConfig):
# Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__
PreTrainedModel.__init__(self, config)
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layers = nn.ModuleList(
[NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@ -255,6 +286,8 @@ class NanoChatModel(LlamaModel):
self.post_init()
@check_model_inputs()
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -262,25 +295,25 @@ class NanoChatModel(LlamaModel):
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
from ...cache_utils import DynamicCache
from ...masking_utils import create_causal_mask
output_attentions = kwargs.get("output_attentions", False)
output_hidden_states = kwargs.get("output_hidden_states", False)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@ -298,41 +331,65 @@ class NanoChatModel(LlamaModel):
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
hidden_states = self.initial_norm(hidden_states)
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
# Collect hidden states and attentions if requested
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states, self_attn_weights = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
if output_attentions:
all_self_attns = all_self_attns + (self_attn_weights,)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@auto_docstring
class NanoChatForCausalLM(LlamaForCausalLM):
class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
"""
The NanoChat Model transformer with a language modeling head on top.
Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping.
"""
def __init__(self, config: NanoChatConfig):
# Call PreTrainedModel.__init__ directly
PreTrainedModel.__init__(self, config)
self.model = NanoChatModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config: NanoChatConfig):
super().__init__(config)
self.model = NanoChatModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
self.model.set_input_embeddings(new_embeddings)
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -346,7 +403,31 @@ class NanoChatForCausalLM(LlamaForCausalLM):
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
outputs: BaseModelOutputWithPast = self.model(
r"""
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained(model_id"karpathy/nanochat-d32")
>>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")
>>> conversation = [
{"role": "user", "content": "What is the capital of France?"},
]
>>> inputs = tokenizer.apply_chat_template(
conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(device)
>>> with torch.no_grad():
>>> outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
>>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
>>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
```"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@ -360,8 +441,6 @@ class NanoChatForCausalLM(LlamaForCausalLM):
hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
# NanoChat-specific: Apply logits soft capping if configured
if self.config.logits_soft_cap is not None:
cap = self.config.logits_soft_cap
logits = cap * torch.tanh(logits / cap)
@ -384,4 +463,3 @@ __all__ = [
"NanoChatModel",
"NanoChatForCausalLM",
]