mirror of https://github.com/huggingface/transformers.git
revert changes and use less llama modules
@@ -20,11 +20,11 @@
 # limitations under the License.
 
 import math
+from collections.abc import Callable
 from typing import Optional, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
@@ -33,15 +33,40 @@ from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.generic import check_model_inputs
 from .configuration_nanochat import NanoChatConfig
 
 
+class NanoChatRMSNorm(torch.nn.Module):
+    """
+    NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
+    Overrides __init__ to match NanoChat's API with hidden_size parameter.
+    """
+
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.eps = eps
+        self.hidden_size = hidden_size
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        return self._norm(x.float()).type_as(x)
+
+    def extra_repr(self):
+        return f"eps={self.eps}"
+
+
 class NanoChatRotaryEmbedding(nn.Module):
-    """Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half."""
+    """
+    NanoChat's Rotary Position Embedding.
+    Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape
+    [batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim.
+    """
 
     inv_freq: torch.Tensor  # fix linting for `register_buffer`
 
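For readers skimming the diff, here is a minimal standalone sketch (not part of the commit) of what the weight-less NanoChatRMSNorm added above computes: each vector is divided by its root-mean-square, with no learnable scale. The `WeightlessRMSNorm` name is illustrative only, and the cross-check against `torch.nn.functional.rms_norm` assumes PyTorch >= 2.4.

```python
import torch


class WeightlessRMSNorm(torch.nn.Module):
    """Illustrative restatement of the weight-less RMS normalization above (not the library class)."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # plays the role of config.rms_norm_eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), computed in float32 and cast back to the input dtype
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return norm.type_as(x)


x = torch.randn(2, 4, 8)
module = WeightlessRMSNorm()
assert sum(p.numel() for p in module.parameters()) == 0  # no learnable weight
assert torch.allclose(module(x), torch.nn.functional.rms_norm(x, (8,), eps=1e-6), atol=1e-6)
```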
@@ -65,38 +90,31 @@ class NanoChatRotaryEmbedding(nn.Module):
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
+        """
+        Returns cos and sin tensors for NanoChat's RoPE.
+
+        Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim,
+        NanoChat keeps only head_dim//2 for memory efficiency.
+        """
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            # NanoChat-specific: Don't duplicate freqs - keep as head_dim//2
+            cos = freqs.cos() * self.attention_scaling
+            sin = freqs.sin() * self.attention_scaling
 
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        # Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2]
+        return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype)
 
 
-class NanoChatRMSNorm(torch.nn.Module):
-    """
-    NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
-    Overrides __init__ to match NanoChat's API with hidden_size parameter.
-    """
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
 
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.eps = eps
-        self.hidden_size = hidden_size
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        return self._norm(x.float()).type_as(x)
-
-    def extra_repr(self):
-        return f"eps={self.eps}"
-
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -164,38 +182,34 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input.
-
-    NanoChat uses a different rotation convention than standard Llama.
-    Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation.
-    This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos]
-    """
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((x2, -x1), dim=-1)
-
-
 class NanoChatAttention(nn.Module):
     """
-    Multi-headed attention from NanoChat with custom QK normalization.
-    Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE.
+    Multi-headed attention from NanoChat with custom RoPE and QK normalization.
+
+    Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64
     """
 
     def __init__(self, config: NanoChatConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
-        self.scaling = self.head_dim**-0.5
-        self.attention_dropout = config.attention_dropout
         self.is_causal = True
-        # Override bias settings for NanoChat
-        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias)
-        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
-        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
-        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias)
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_groups = self.num_heads // self.num_kv_heads
+
+        self.attention_dropout = config.attention_dropout
+        self.scaling = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
+        self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
 
     def forward(
         self,
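A small illustrative sketch (ours, not from the commit) of the grouped-query-attention bookkeeping behind `num_kv_heads` and `num_key_value_groups` in the __init__ above: keys and values are projected with fewer heads, and each key/value head serves `num_key_value_groups` query heads. The `repeat_kv` helper below is a local stand-in and is not claimed to match the library's internal implementation.

```python
import torch

hidden_size, num_heads, num_kv_heads, seq_len = 32, 8, 2, 5
head_dim = hidden_size // num_heads
num_key_value_groups = num_heads // num_kv_heads  # here: 4 query heads share each KV head


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [batch, num_kv_heads, seq, head_dim] -> [batch, num_kv_heads * n_rep, seq, head_dim]
    return x.repeat_interleave(n_rep, dim=1)


q = torch.randn(1, num_heads, seq_len, head_dim)
k = torch.randn(1, num_kv_heads, seq_len, head_dim)
assert repeat_kv(k, num_key_value_groups).shape == q.shape  # KV heads broadcast up to the query heads
```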
@@ -206,28 +220,46 @@ class NanoChatAttention(nn.Module):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        batch, seq_len, _ = hidden_states.shape
         input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
 
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        # Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim]
+        query_states = (
+            self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+        )
+        key_states = (
+            self.k_proj(hidden_states)
+            .view(batch, seq_len, self.num_kv_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+        value_states = (
+            self.v_proj(hidden_states)
+            .view(batch, seq_len, self.num_kv_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
 
-        cos, sin = position_embeddings
+        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
+        cos, sin = position_embeddings  # [batch, seq_len, 1, head_dim//2]
+        cos = cos.squeeze(2)
+        sin = sin.squeeze(2)
+        cos = torch.cat([cos, cos], dim=-1)
+        sin = torch.cat([-sin, -sin], dim=-1)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        # NanoChat-specific: Apply QK normalization after RoPE
-        query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps)
-        key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps)
+        # Apply QK normalization (RMSNorm)
+        query_states = self.query_norm(query_states)
+        key_states = self.key_norm(key_states)
 
+        # Apply KV cache: insert current k,v into cache, get the full view so far
         if past_key_values is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        attention_interface = eager_attention_forward
+        # Use attention interface pattern for vLLM compatibility
+        attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
-
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
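The cos/sin handling in the forward above is the subtle part of this revert: the rotary embedding emits only head_dim//2 frequencies, and the attention duplicates cos while duplicating and negating sin so that the standard Llama-style rotation reproduces the [x2, -x1] convention described in the removed rotate_half docstring. A self-contained numerical check of that equivalence, using local helpers rather than the library's apply_rotary_pos_emb:

```python
import torch


def rotate_half_llama(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_half_nanochat(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x2, -x1), dim=-1)


head_dim, seq_len = 8, 5
q = torch.randn(1, 1, seq_len, head_dim)
angles = torch.rand(1, seq_len, head_dim // 2)  # arbitrary angles standing in for position * inv_freq
cos = torch.cat([angles.cos(), angles.cos()], dim=-1).unsqueeze(1)  # duplicated to full head_dim
sin = torch.cat([angles.sin(), angles.sin()], dim=-1).unsqueeze(1)

# NanoChat's convention applied directly vs. the Llama-style rotation fed with negated sin
expected = q * cos + rotate_half_nanochat(q) * sin
got = q * cos + rotate_half_llama(q) * (-sin)
assert torch.allclose(got, expected, atol=1e-6)
```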
@@ -241,6 +273,7 @@ class NanoChatAttention(nn.Module):
             **kwargs,
         )
 
+        # Reshape and project output
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
@@ -263,16 +296,13 @@ class NanoChatMLP(nn.Module):
 
 
 class NanoChatDecoderLayer(GradientCheckpointingLayer):
-    """
-    NanoChat decoder layer with pre-norm architecture.
-    """
+    """NanoChat decoder layer with pre-norm architecture."""
 
     def __init__(self, config: NanoChatConfig, layer_idx: int):
         super().__init__()
-        self.hidden_size = config.hidden_size
+        self.config = config
         self.self_attn = NanoChatAttention(config, layer_idx)
         self.mlp = NanoChatMLP(config)
-        # Replace Llama's norm layers with NanoChat's weight-less norm
         self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -286,15 +316,13 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer):
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs: Unpack[TransformersKwargs],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states, _ = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_values=past_key_values,
-            use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
@@ -305,7 +333,7 @@ class NanoChatDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, self_attn_weights
 
 
 @auto_docstring
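Taken together, the two decoder-layer hunks above implement a standard pre-norm residual block (normalize, run the sub-module, add the residual, twice). A compact sketch of that control flow, with `attn` and `mlp` standing in for the real sub-modules:

```python
def prenorm_layer(x, input_norm, attn, post_attn_norm, mlp):
    x = x + attn(input_norm(x))     # attention sub-block
    x = x + mlp(post_attn_norm(x))  # feed-forward sub-block
    return x
```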
@@ -343,18 +371,13 @@ class NanoChatPreTrainedModel(PreTrainedModel):
 
 @auto_docstring
 class NanoChatModel(NanoChatPreTrainedModel):
-    """
-    NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers
-    and RMSNorm without learnable weights.
-    """
-
     def __init__(self, config: NanoChatConfig):
-        # Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.layers = nn.ModuleList(
             [NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
@@ -364,6 +387,12 @@ class NanoChatModel(NanoChatPreTrainedModel):
 
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+        self.embed_tokens = new_embeddings
+
     @check_model_inputs()
     @auto_docstring
     def forward(
@@ -373,22 +402,25 @@ class NanoChatModel(NanoChatPreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
+        output_attentions = kwargs.get("output_attentions", False)
+        output_hidden_states = kwargs.get("output_hidden_states", False)
+
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if inputs_embeds is None:
-            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position: torch.Tensor = torch.arange(
+            cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
 
@@ -406,22 +438,38 @@ class NanoChatModel(NanoChatPreTrainedModel):
 
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        hidden_states = self.initial_norm(hidden_states)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            hidden_states = decoder_layer(
+        # Collect hidden states and attentions if requested
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            hidden_states, self_attn_weights = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
+                use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
 
+            if output_attentions:
+                all_self_attns = all_self_attns + (self_attn_weights,)
+
         hidden_states = self.norm(hidden_states)
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values if use_cache else None,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
         )
 
 
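A small worked example (ours) of the cache_position arithmetic used by the forward above: on the prefill call nothing is cached yet, so positions start at 0; on a later single-token decode step they continue from however many tokens the KV cache already holds. Plain integers stand in for past_key_values.get_seq_length().

```python
import torch

# Prefill: nothing cached yet, prompt of 6 tokens
past_seen_tokens, num_new_tokens = 0, 6
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)
print(cache_position)  # tensor([0, 1, 2, 3, 4, 5])

# Decode step: 6 tokens already cached, one new token
past_seen_tokens, num_new_tokens = 6, 1
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)
print(cache_position)  # tensor([6])
```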
@@ -429,7 +477,6 @@ class NanoChatModel(NanoChatPreTrainedModel):
 class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
     """
     The NanoChat Model transformer with a language modeling head on top.
-    Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping.
     """
 
     _tied_weights_keys = ["lm_head.weight"]
@@ -437,14 +484,17 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 
     def __init__(self, config: NanoChatConfig):
-        # Call PreTrainedModel.__init__ directly
         super().__init__(config)
         self.model = NanoChatModel(config)
-        self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+        self.model.set_input_embeddings(new_embeddings)
+
     @can_return_tuple
     @auto_docstring
     def forward(
@@ -477,7 +527,7 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
-        outputs: BaseModelOutputWithPast = self.model(
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -491,8 +541,6 @@ class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
         hidden_states = outputs.last_hidden_state
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-
-        # NanoChat-specific: Apply logits soft capping if configured
         if self.config.logits_soft_cap is not None:
             cap = self.config.logits_soft_cap
             logits = cap * torch.tanh(logits / cap)
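For intuition, a tiny standalone demo (ours) of the logits soft-capping kept by the hunk above: cap * tanh(logits / cap) is close to the identity for small logits and saturates smoothly at ±cap, so no logit can exceed the configured bound.

```python
import torch

cap = 15.0
logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
capped = cap * torch.tanh(logits / cap)
print(capped)                     # approximately [-15.0, -4.8, 0.0, 4.8, 15.0]
assert capped.abs().max() <= cap  # bounded to (-cap, cap)
```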
@@ -14,23 +14,24 @@
 # limitations under the License.
 
 import math
+from collections.abc import Callable
 from typing import Optional, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_utils import PreTrainedModel
+from ...modeling_rope_utils import dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.generic import check_model_inputs
 from ..llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaForCausalLM,
-    LlamaModel,
     LlamaPreTrainedModel,
     LlamaRotaryEmbedding,
     apply_rotary_pos_emb,
@@ -40,23 +41,6 @@ from ..llama4.modeling_llama4 import Llama4TextL2Norm
 from .configuration_nanochat import NanoChatConfig
 
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input.
-
-    NanoChat uses a different rotation convention than standard Llama.
-    Llama uses: [-x2, x1], NanoChat uses: [x2, -x1] to match the original nanochat implementation.
-    This results in: [q1 * cos + q2 * sin, -(q1 * sin) + q2 * cos] instead of [q1 * cos - q2 * sin, q1 * sin + q2 * cos]
-    """
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((x2, -x1), dim=-1)
-
-
-class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
-    """Inherits from LlamaRotaryEmbedding but uses NanoChat's rotate_half."""
-    pass
-
-
 class NanoChatRMSNorm(Llama4TextL2Norm):
     """
     NanoChatRMSNorm inherits from Llama4TextL2Norm (weight-less RMS normalization).
@@ -68,28 +52,64 @@ class NanoChatRMSNorm(Llama4TextL2Norm):
         self.hidden_size = hidden_size
 
 
-class NanoChatAttention(LlamaAttention):
+class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
     """
-    Multi-headed attention from NanoChat with custom QK normalization.
-    Inherits from LlamaAttention but adds RMSNorm to queries and keys after RoPE.
+    NanoChat's Rotary Position Embedding.
+    Inherits from LlamaRotaryEmbedding but produces cos/sin tensors with shape
+    [batch, seq_len, 1, head_dim//2] instead of duplicating to full head_dim.
+    """
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        """
+        Returns cos and sin tensors for NanoChat's RoPE.
+
+        Unlike LlamaRotaryEmbedding which duplicates freqs to full head_dim,
+        NanoChat keeps only head_dim//2 for memory efficiency.
+        """
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            # NanoChat-specific: Don't duplicate freqs - keep as head_dim//2
+            cos = freqs.cos() * self.attention_scaling
+            sin = freqs.sin() * self.attention_scaling
+
+        # Add extra dimension for NanoChat's broadcasting: [batch, seq_len] -> [batch, seq_len, 1, head_dim//2]
+        return cos.unsqueeze(2).to(dtype=x.dtype), sin.unsqueeze(2).to(dtype=x.dtype)
+
+
+class NanoChatAttention(nn.Module):
+    """
+    Multi-headed attention from NanoChat with custom RoPE and QK normalization.
+
+    Based on: https://github.com/karpathy/nanochat/blob/main/nanochat/gpt.py#L64
     """
 
     def __init__(self, config: NanoChatConfig, layer_idx: int):
-        super().__init__(config, layer_idx)
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
         self.is_causal = True
-        # Override bias settings for NanoChat
-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias
-        )
-        self.o_proj = nn.Linear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.qkv_bias
-        )
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_groups = self.num_heads // self.num_kv_heads
+
+        self.attention_dropout = config.attention_dropout
+        self.scaling = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.qkv_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.qkv_bias)
+        self.query_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.key_norm = NanoChatRMSNorm(self.head_dim, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -100,27 +120,46 @@ class NanoChatAttention(LlamaAttention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        batch, seq_len, _ = hidden_states.shape
         input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
 
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        # Project the input to get queries, keys, and values [batch, num_heads, seq_len, head_dim]
+        query_states = (
+            self.q_proj(hidden_states).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+        )
+        key_states = (
+            self.k_proj(hidden_states)
+            .view(batch, seq_len, self.num_kv_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
+        value_states = (
+            self.v_proj(hidden_states)
+            .view(batch, seq_len, self.num_kv_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
 
-        cos, sin = position_embeddings
+        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
+        cos, sin = position_embeddings  # [batch, seq_len, 1, head_dim//2]
+        cos = cos.squeeze(2)
+        sin = sin.squeeze(2)
+        cos = torch.cat([cos, cos], dim=-1)
+        sin = torch.cat([-sin, -sin], dim=-1)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        # NanoChat-specific: Apply QK normalization after RoPE
-        query_states = F.rms_norm(query_states, (query_states.size(-1),), eps=self.config.rms_norm_eps)
-        key_states = F.rms_norm(key_states, (key_states.size(-1),), eps=self.config.rms_norm_eps)
+        # Apply QK normalization (RMSNorm)
+        query_states = self.query_norm(query_states)
+        key_states = self.key_norm(key_states)
 
+        # Apply KV cache: insert current k,v into cache, get the full view so far
         if past_key_values is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            cache_kwargs = {"cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        attention_interface = eager_attention_forward
+        # Use attention interface pattern for vLLM compatibility
+        attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
@@ -134,6 +173,7 @@ class NanoChatAttention(LlamaAttention):
             **kwargs,
         )
 
+        # Reshape and project output
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
@@ -155,16 +195,14 @@ class NanoChatMLP(nn.Module):
         return hidden_states
 
 
-class NanoChatDecoderLayer(LlamaDecoderLayer):
-    """
-    NanoChat decoder layer with pre-norm architecture.
-    """
+class NanoChatDecoderLayer(GradientCheckpointingLayer):
+    """NanoChat decoder layer with pre-norm architecture."""
 
     def __init__(self, config: NanoChatConfig, layer_idx: int):
-        super().__init__(config, layer_idx)
+        super().__init__()
+        self.config = config
         self.self_attn = NanoChatAttention(config, layer_idx)
         self.mlp = NanoChatMLP(config)
-        # Replace Llama's norm layers with NanoChat's weight-less norm
         self.input_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -178,15 +216,13 @@ class NanoChatDecoderLayer(LlamaDecoderLayer):
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs: Unpack[TransformersKwargs],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states, _ = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_values=past_key_values,
-            use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
@@ -197,13 +233,14 @@ class NanoChatDecoderLayer(LlamaDecoderLayer):
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, self_attn_weights
 
 
 @auto_docstring
 class NanoChatPreTrainedModel(LlamaPreTrainedModel):
     config_class = NanoChatConfig
     base_model_prefix = "model"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["NanoChatDecoderLayer"]
     _supports_attention_backend = True
     _skip_keys_device_placement = ["past_key_values"]
@@ -212,7 +249,6 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel):
     _supports_flex_attn = True
 
     _can_compile_fullgraph = True
-    _supports_attention_backend = True
 
     _can_record_outputs = {
         "hidden_states": NanoChatDecoderLayer,
@@ -233,19 +269,14 @@ class NanoChatPreTrainedModel(LlamaPreTrainedModel):
 
 
 @auto_docstring
-class NanoChatModel(LlamaModel):
-    """
-    NanoChat model that inherits from LlamaModel but uses NanoChat-specific layers
-    and RMSNorm without learnable weights.
-    """
-
+class NanoChatModel(NanoChatPreTrainedModel):
     def __init__(self, config: NanoChatConfig):
-        # Call PreTrainedModel.__init__ directly to avoid LlamaModel's __init__
-        PreTrainedModel.__init__(self, config)
+        super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.initial_norm = NanoChatRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.layers = nn.ModuleList(
             [NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
@@ -255,6 +286,8 @@ class NanoChatModel(LlamaModel):
 
         self.post_init()
 
+    @check_model_inputs()
+    @auto_docstring
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -262,25 +295,25 @@ class NanoChatModel(LlamaModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        from ...cache_utils import DynamicCache
-        from ...masking_utils import create_causal_mask
+        output_attentions = kwargs.get("output_attentions", False)
+        output_hidden_states = kwargs.get("output_hidden_states", False)
 
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if inputs_embeds is None:
-            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
 
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position: torch.Tensor = torch.arange(
+            cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
 
@@ -298,41 +331,65 @@ class NanoChatModel(LlamaModel):
 
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        hidden_states = self.initial_norm(hidden_states)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            hidden_states = decoder_layer(
+        # Collect hidden states and attentions if requested
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            hidden_states, self_attn_weights = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
+                use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
 
+            if output_attentions:
+                all_self_attns = all_self_attns + (self_attn_weights,)
+
         hidden_states = self.norm(hidden_states)
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values if use_cache else None,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
         )
 
 
 @auto_docstring
-class NanoChatForCausalLM(LlamaForCausalLM):
+class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
     """
     The NanoChat Model transformer with a language modeling head on top.
-    Inherits from LlamaForCausalLM but uses NanoChatModel and supports logits soft capping.
     """
 
-    def __init__(self, config: NanoChatConfig):
-        # Call PreTrainedModel.__init__ directly
-        PreTrainedModel.__init__(self, config)
-        self.model = NanoChatModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 
+    def __init__(self, config: NanoChatConfig):
+        super().__init__(config)
+        self.model = NanoChatModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
+        self.model.set_input_embeddings(new_embeddings)
+
+    @can_return_tuple
+    @auto_docstring
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -346,7 +403,31 @@ class NanoChatForCausalLM(LlamaForCausalLM):
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
-        outputs: BaseModelOutputWithPast = self.model(
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        >>> model = AutoModelForCausalLM.from_pretrained(model_id"karpathy/nanochat-d32")
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")
+
+        >>> conversation = [
+            {"role": "user", "content": "What is the capital of France?"},
+        ]
+
+        >>> inputs = tokenizer.apply_chat_template(
+            conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+        ).to(device)
+
+        >>> with torch.no_grad():
+        >>> outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
+
+        >>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
+        >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        ```"""
+        outputs = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -360,8 +441,6 @@ class NanoChatForCausalLM(LlamaForCausalLM):
         hidden_states = outputs.last_hidden_state
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-
-        # NanoChat-specific: Apply logits soft capping if configured
         if self.config.logits_soft_cap is not None:
             cap = self.config.logits_soft_cap
             logits = cap * torch.tanh(logits / cap)
@@ -384,4 +463,3 @@ __all__ = [
     "NanoChatModel",
     "NanoChatForCausalLM",
 ]
-