Compare commits

..

8 Commits

SHA1 Message Date
a4a187e0e5 make script self contained 2025-11-01 19:28:54 +00:00
4ec11c20c8 changes from review 2025-11-01 19:14:55 +00:00
df8aaacbbb remove use_cache=False. 2025-10-25 14:25:52 +00:00
e18436c0d2 use DataCollatorForLanguageModeling 2025-10-25 14:20:22 +00:00
6d42d9a599 removed fsdp_transformer_layer_cls_to_wrap 2025-10-25 14:09:26 +00:00
0bcf34edad add test_cp_equivalence 2025-10-25 14:02:48 +00:00
1ff587b504 simplify tests 2025-10-25 12:54:58 +00:00
63efa18396 intial 2025-10-25 11:48:07 +00:00
32 changed files with 462 additions and 1683 deletions

View File

@ -14,7 +14,7 @@ This AGENTS.md file provides guidance for code agents working with this codebase
- PRs should be as brief as possible. Bugfix PRs in particular can often be only one or two lines long, and do not need large comments, docstrings or new functions in this case. Aim to minimize the size of the diff.
- When writing tests, they should be added to an existing file. The only exception is for PRs to add a new model, when a new test directory should be created for that model.
- Code style is enforced in the CI. You can install the style tools with `pip install -e ".[quality]"`. You can then run `make fixup` to apply style and consistency fixes to your code.
- Code style is enforced in the CI. You can install the style tools with `pip install -e .[quality]`. You can then run `make fixup` to apply style and consistency fixes to your code.
## Copying and inheritance
@ -36,4 +36,4 @@ After making changes, you should usually run `make fixup` to ensure any copies a
the model you made the changes in and any other models that were updated by `make fixup`. Tests can be run with `pytest tests/models/[name]/test_modeling_[name].py`
If your changes affect code in other classes like tokenizers or processors, you should run those tests instead, like `test_processing_[name].py` or `test_tokenization_[name].py`.
In order to run tests, you may need to install dependencies. You can do this with `pip install -e ".[testing]"`. You will probably also need to `pip install torch accelerate` if your environment does not already have them.
In order to run tests, you may need to install dependencies. You can do this with `pip install -e .[testing]`. You will probably also need to `pip install torch accelerate` if your environment does not already have them.

View File

@ -9,12 +9,6 @@ In this list, we showcase incredibly impactful and novel projects that have push
adding other projects to the list. If you believe a project should be here and it's not, then please, open a PR
to add it.
## [◉ Universal Intelligence](https://github.com/blueraai/universal-intelligence)
[Universal Intelligence](https://github.com/blueraai/universal-intelligence) aims to standardize models, tools, and agents — transforming them into simple, composable, portable, interoperable, framework-agnostic, hardware-agnostic interfaces (through auto-negotiation and resource sharing); for fast and accessible development of AI applications.
Keywords: Protocol, Open-source, LLMs, Large Language Models, Agents, Low-code
## [gpt4all](https://github.com/nomic-ai/gpt4all)
[gpt4all](https://github.com/nomic-ai/gpt4all) is an ecosystem of open-source chatbots trained on massive collections of clean assistant data including code, stories and dialogue. It offers open-source, large language models such as LLaMA and GPT-J trained in an assistant-style.

View File

@ -626,8 +626,6 @@
title: MVP
- local: model_doc/myt5
title: myt5
- local: model_doc/nanochat
title: NanoChat
- local: model_doc/nemotron
title: Nemotron
- local: model_doc/nezha

View File

@ -1,119 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# NanoChat
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>
[NanoChat](https://huggingface.co/karpathy/nanochat-d32) is a compact decoder-only transformer model designed for educational purposes and efficient training. It features several fundamental architectural components that are common in modern transformer models, which makes it a good starting point for understanding how such models work. NanoChat is a variant of the [Llama](https://huggingface.co/docs/transformers/en/model_doc/llama) architecture, with a simplified attention mechanism and normalization layers.
The architecture is based on [nanochat](https://github.com/karpathy/nanochat) by [Andrej Karpathy](https://huggingface.co/karpathy), adapted for the Hugging Face Transformers library by [Ben Burtenshaw](https://huggingface.co/burtenshaw).
> [!TIP]
> This model was contributed by the Hugging Face team.
The example below demonstrates how to use NanoChat for text generation with chat templates.
<hfoptions id="usage">
<hfoption id="Pipeline">
```py
import torch
from transformers import pipeline
chatbot = pipeline(
task="text-generation",
model="karpathy/nanochat-d32",
dtype=torch.bfloat16,
device=0
)
conversation = [
{"role": "user", "content": "What is the capital of France?"},
]
outputs = chatbot(conversation, max_new_tokens=64)
print(outputs[0]["generated_text"][-1]["content"])
```
</hfoption>
<hfoption id="AutoModel">
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "karpathy/nanochat-d32"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
dtype=torch.bfloat16,
device_map="auto",
)
model.eval()
conversation = [
{"role": "user", "content": "What is the capital of France?"},
]
inputs = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=64,
)
# Decode only the generated tokens (excluding the input prompt)
generated_tokens = outputs[0, inputs["input_ids"].shape[1]:]
print(tokenizer.decode(generated_tokens, skip_special_tokens=True))
```
</hfoption>
<hfoption id="transformers CLI">
```bash
echo -e '{"role": "user", "content": "What is the capital of France?"}' | transformers run --task text-generation --model karpathy/nanochat-d32 --device 0
```
</hfoption>
</hfoptions>
## NanoChatConfig
[[autodoc]] NanoChatConfig
## NanoChatModel
[[autodoc]] NanoChatModel
- forward
## NanoChatForCausalLM
[[autodoc]] NanoChatForCausalLM
- forward

View File

@ -38,7 +38,7 @@ pip install transformers[dev]
or for an editable install:
```bash
pip install -e ".[dev]"
pip install -e .[dev]
```
inside the Transformers repo. Since the number of optional dependencies of Transformers has grown a lot, it's possible you don't manage to get all of them. If the dev install fails, make sure to install PyTorch then do
@ -50,7 +50,7 @@ pip install transformers[quality]
or for an editable install:
```bash
pip install -e ".[quality]"
pip install -e .[quality]
```
## Tests

View File

@ -37,7 +37,7 @@ pip install transformers[dev]
o una instalación editable:
```bash
pip install -e ".[dev]"
pip install -e .[dev]
```
del repositorio de Transformers.

View File

@ -37,7 +37,7 @@ pip install transformers[dev]
o un'installazione modificabile:
```bash
pip install -e ".[dev]"
pip install -e .[dev]
```
all'interno del repo Transformers.

View File

@ -40,7 +40,7 @@ pip install transformers[dev]
```bash
pip install -e ".[dev]"
pip install -e .[dev]
```
トランスフォーマーズのリポジトリ内で作業しています。トランスフォーマーズのオプションの依存関係の数が増えたため、すべてを取得できない可能性があります。開発用インストールが失敗した場合、作業しているディープラーニングフレームワークPyTorch、TensorFlow、および/またはFlaxをインストールし、次の手順を実行してください。
@ -53,7 +53,7 @@ pip install transformers[quality]
または編集可能なインストールの場合:
```bash
pip install -e ".[quality]"
pip install -e .[quality]
```
## Tests

View File

@ -37,7 +37,7 @@ pip install transformers[dev]
또는 Transformers 저장소 내에 편집 가능한 설치가 필요합니다:
```bash
pip install -e ".[dev]"
pip install -e .[dev]
```
Transformers의 선택적 종속성 수가 많이 늘어났기 때문에 개발 설치를 실패할 수도 있습니다. 개발 설치가 실패하는 경우, 작업 중인 Deep Learning 프레임워크 (PyTorch, TensorFlow 및/또는 Flax)를 설치하고 다음 명령을 실행하세요.
@ -49,7 +49,7 @@ pip install transformers[quality]
편집 가능한 설치의 경우는 다음 명령을 실행하세요.
```bash
pip install -e ".[quality]"
pip install -e .[quality]
```

View File

@ -240,7 +240,6 @@ if TYPE_CHECKING:
from .musicgen_melody import *
from .mvp import *
from .myt5 import *
from .nanochat import *
from .nemotron import *
from .nllb import *
from .nllb_moe import *

View File

@ -281,7 +281,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("musicgen", "MusicgenConfig"),
("musicgen_melody", "MusicgenMelodyConfig"),
("mvp", "MvpConfig"),
("nanochat", "NanoChatConfig"),
("nat", "NatConfig"),
("nemotron", "NemotronConfig"),
("nezha", "NezhaConfig"),
@ -738,8 +737,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("musicgen_melody", "MusicGen Melody"),
("mvp", "MVP"),
("myt5", "myt5"),
("nanochat", "NanoChat"),
("nanochat", "NanoChatForCausalLM"),
("nat", "NAT"),
("nemotron", "Nemotron"),
("nezha", "Nezha"),

View File

@ -282,7 +282,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("musicgen", "MusicgenModel"),
("musicgen_melody", "MusicgenMelodyModel"),
("mvp", "MvpModel"),
("nanochat", "NanoChatModel"),
("nat", "NatModel"),
("nemotron", "NemotronModel"),
("nezha", "NezhaModel"),
@ -499,7 +498,6 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("mpt", "MptForCausalLM"),
("mra", "MraForMaskedLM"),
("mvp", "MvpForConditionalGeneration"),
("nanochat", "NanoChatForCausalLM"),
("nezha", "NezhaForPreTraining"),
("nllb-moe", "NllbMoeForConditionalGeneration"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
@ -722,7 +720,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("musicgen", "MusicgenForCausalLM"),
("musicgen_melody", "MusicgenMelodyForCausalLM"),
("mvp", "MvpForCausalLM"),
("nanochat", "NanoChatForCausalLM"),
("nemotron", "NemotronForCausalLM"),
("olmo", "OlmoForCausalLM"),
("olmo2", "Olmo2ForCausalLM"),

View File

@ -468,7 +468,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
("myt5", ("MyT5Tokenizer", None)),
("nanochat", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("nemotron", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
(

View File

@ -1392,7 +1392,7 @@ class Emu3Model(Emu3PreTrainedModel):
image_features = torch.split(image_features, split_sizes)
return image_features
@torch.no_grad()
@torch.no_grad
def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
"""
Decodes generated image tokens from language model to continuous pixel values
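Both decorator spellings in the hunk above behave the same in recent PyTorch (1.13+): `torch.no_grad` can be applied bare or instantiated first, and gradients are disabled inside the wrapped call either way. A minimal standalone check (not part of the Emu3 code):
```py
import torch

@torch.no_grad
def double_bare(x):
    # Gradient tracking is disabled inside the call even without parentheses.
    return x * 2

@torch.no_grad()
def double_called(x):
    return x * 2

x = torch.ones(2, requires_grad=True)
print(double_bare(x).requires_grad, double_called(x).requires_grad)  # False False
```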

View File

@ -946,7 +946,7 @@ class Emu3Model(Emu3PreTrainedModel):
image_features = torch.split(image_features, split_sizes)
return image_features
@torch.no_grad()
@torch.no_grad
def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
"""
Decodes generated image tokens from language model to continuous pixel values

View File

@ -39,11 +39,14 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_gemma2 import Gemma2Config
logger = logging.get_logger(__name__)
class Gemma2RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@ -322,6 +325,8 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
@ -330,12 +335,14 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@ -348,7 +355,12 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
@auto_docstring
@ -405,16 +417,30 @@ class Gemma2Model(Gemma2PreTrainedModel):
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if use_cache and past_key_values is None:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
@ -453,22 +479,41 @@ class Gemma2Model(Gemma2PreTrainedModel):
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@ -498,9 +543,11 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Example:
@ -519,6 +566,11 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
@ -527,6 +579,8 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
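The restored `output_attentions` / `output_hidden_states` plumbing above is what lets callers request per-layer tensors through the usual Transformers API. A minimal sketch of that usage; the checkpoint name `google/gemma-2-2b` is an assumption (the repo is gated), so substitute any Gemma2 checkpoint you have access to:
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # assumed checkpoint; any Gemma2 model should behave the same
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",  # eager attention so attention weights can be returned
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

# One attention tensor per decoder layer; hidden_states holds each layer's input
# plus the final normed output, so it has num_hidden_layers + 1 entries.
print(len(outputs.attentions), outputs.attentions[0].shape)
print(len(outputs.hidden_states))
```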

View File

@ -381,6 +381,8 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
@ -389,12 +391,14 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@ -407,7 +411,12 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class Gemma2PreTrainedModel(GemmaPreTrainedModel):
@ -430,16 +439,30 @@ class Gemma2Model(GemmaModel):
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if use_cache and past_key_values is None:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
@ -478,22 +501,41 @@ class Gemma2Model(GemmaModel):
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@ -512,9 +554,11 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Example:
@ -533,6 +577,11 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
@ -541,6 +590,8 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)

View File

@ -652,9 +652,11 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Example:
@ -673,6 +675,11 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
@ -681,6 +688,8 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)

View File

@ -1543,6 +1543,8 @@ class Gemma3nDecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
@ -1551,12 +1553,14 @@ class Gemma3nDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@ -1569,7 +1573,12 @@ class Gemma3nDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
@auto_docstring
@ -1950,9 +1959,11 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Example:
@ -1971,6 +1982,11 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
@ -1979,6 +1995,8 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)

View File

@ -1283,7 +1283,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
decoded_image = decoded_image.permute(0, 2, 3, 1)
return decoded_image
@torch.no_grad()
@torch.no_grad
def generate(
self,
inputs: Optional[torch.Tensor] = None,

View File

@ -1099,7 +1099,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
decoded_image = decoded_image.permute(0, 2, 3, 1)
return decoded_image
@torch.no_grad()
@torch.no_grad
def generate(
self,
inputs: Optional[torch.Tensor] = None,

View File

@ -1,14 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_nanochat import *
from .modeling_nanochat import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -1,164 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
class NanoChatConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`NanoChatModel`]. It is used to instantiate a
NanoChat model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the [karpathy/nanochat-d32](https://huggingface.co/karpathy/nanochat-d32).
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50304):
Vocabulary size of the NanoChat model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`NanoChatModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations. If `None`, it will be computed based on the model architecture.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 6):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
The non-linear activation function (function or string) in the decoder.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rope_parameters (`RopeParameters`, *optional*):
Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
with longer `max_position_embeddings`.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
final_logit_softcapping (`float`, *optional*, defaults to 15.0):
scaling factor when applying tanh softcapping on the logits.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, and value projection layers during self-attention.
bos_token_id (`int`, *optional*, defaults to 0):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
pad_token_id (`int`, *optional*, defaults to 1):
Padding token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
```python
>>> from transformers import NanoChatModel, NanoChatConfig
>>> # Initializing a NanoChat style configuration
>>> configuration = NanoChatConfig()
>>> # Initializing a model from the NanoChat style configuration
>>> model = NanoChatModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "nanochat"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.fc1": "colwise",
"layers.*.mlp.fc2": "rowwise",
}
def __init__(
self,
vocab_size: int = 50304,
hidden_size: int = 768,
intermediate_size: int | None = 8192,
num_hidden_layers: int = 12,
num_attention_heads: int = 6,
num_key_value_heads: int | None = None,
max_position_embeddings: int = 2048,
hidden_act: str = "relu2",
attention_dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
initializer_range: float = 0.02,
rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
use_cache: bool = True,
final_logit_softcapping: float | None = 15.0,
attention_bias: bool = False,
bos_token_id: int = 0,
eos_token_id: int = 1,
pad_token_id: int = 1,
tie_word_embeddings: bool = False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.max_position_embeddings = max_position_embeddings
self.hidden_act = hidden_act
self.attention_dropout = attention_dropout
self.rms_norm_eps = rms_norm_eps
self.initializer_range = initializer_range
self.use_cache = use_cache
self.final_logit_softcapping = final_logit_softcapping
self.attention_bias = attention_bias
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Validate the correctness of rotary position embeddings parameters
# Must be done after super().__init__() to avoid being overridden by kwargs
self.rope_parameters = rope_parameters
rope_theta = kwargs.get("rope_theta", 10000.0)
standardize_rope_params(self, rope_theta=rope_theta)
rope_config_validation(self)
__all__ = ["NanoChatConfig"]
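To make the GQA-related fields above concrete, here is a small sketch (assuming a Transformers build that ships `NanoChatConfig`) of how `num_key_value_heads` interacts with `num_attention_heads` and the head size derived by the attention layers:
```py
from transformers import NanoChatConfig  # assumes NanoChat is available in the installed build

# 6 query heads sharing 2 key/value heads -> grouped-query attention, 3 query heads per KV head.
config = NanoChatConfig(num_attention_heads=6, num_key_value_heads=2)

head_dim = config.hidden_size // config.num_attention_heads            # 768 // 6 = 128
kv_groups = config.num_attention_heads // config.num_key_value_heads   # 3
print(head_dim, kv_groups)

# rope_theta is standardized into rope_parameters by the constructor (10000.0 per the defaults above).
print(config.rope_parameters["rope_theta"])
```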

View File

@ -1,313 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
from pathlib import Path
import torch
from transformers import AutoTokenizer, NanoChatConfig, NanoChatForCausalLM
def infer_kv_heads(config: NanoChatConfig, state_dict: dict[str, torch.Tensor]) -> int:
key_weight = state_dict.get("transformer.h.0.attn.c_k.weight")
if key_weight is None:
return config.num_key_value_heads
rows = key_weight.shape[0]
head_dim = config.hidden_size // config.num_attention_heads
if rows % head_dim != 0:
return config.num_key_value_heads
inferred = rows // head_dim
print(f"Inferred {inferred} key_value heads from checkpoint")
return max(inferred, 1)
def convert_layer(old_prefix: str, new_prefix: str) -> dict[str, str]:
return {
f"{old_prefix}.attn.c_q.weight": f"{new_prefix}.self_attn.q_proj.weight",
f"{old_prefix}.attn.c_k.weight": f"{new_prefix}.self_attn.k_proj.weight",
f"{old_prefix}.attn.c_v.weight": f"{new_prefix}.self_attn.v_proj.weight",
f"{old_prefix}.attn.c_proj.weight": f"{new_prefix}.self_attn.o_proj.weight",
f"{old_prefix}.mlp.c_fc.weight": f"{new_prefix}.mlp.fc1.weight",
f"{old_prefix}.mlp.c_proj.weight": f"{new_prefix}.mlp.fc2.weight",
}
def load_config_from_checkpoint(input_path: Path) -> NanoChatConfig:
"""Load config from either meta_*.json or config.json in the checkpoint directory."""
# Try to find meta_*.json first
meta_files = list(input_path.glob("meta_*.json"))
if meta_files:
meta_file = meta_files[0]
print(f"Loading config from {meta_file.name}")
with open(meta_file, "r") as f:
meta_config = json.load(f)
# Extract model config from meta file
if "model_config" in meta_config:
model_config = meta_config["model_config"]
else:
model_config = meta_config
# Map to NanoChat config parameters
config_kwargs = {
"vocab_size": model_config.get("vocab_size", 50304),
"hidden_size": model_config.get("n_embd", 768),
"num_hidden_layers": model_config.get("n_layer", 12),
"num_attention_heads": model_config.get("n_head", 6),
"num_key_value_heads": model_config.get("n_kv_head"),
"max_position_embeddings": model_config.get("sequence_len", 2048),
}
# Try to load existing config.json for additional parameters
config_file = input_path / "config.json"
if config_file.exists():
print("Loading additional config from config.json")
with open(config_file, "r") as f:
extra_config = json.load(f)
# Add additional parameters from config.json
for key in [
"hidden_act",
"attention_dropout",
"rms_norm_eps",
"initializer_range",
"logits_soft_cap",
"attention_bias",
"intermediate_size",
"bos_token_id",
"eos_token_id",
"pad_token_id",
]:
if key in extra_config:
config_kwargs[key] = extra_config[key]
# Handle legacy qkv_bias -> attention_bias conversion
elif key == "attention_bias" and "qkv_bias" in extra_config:
config_kwargs[key] = extra_config["qkv_bias"]
# Handle rope_theta as a direct kwarg for the rope_parameters processing
if "rope_theta" in extra_config:
config_kwargs["rope_theta"] = extra_config["rope_theta"]
# Handle rope_parameters or rope_scaling if present
if "rope_parameters" in extra_config:
config_kwargs["rope_parameters"] = extra_config["rope_parameters"]
elif "rope_scaling" in extra_config and extra_config["rope_scaling"] is not None:
config_kwargs["rope_parameters"] = extra_config["rope_scaling"]
config = NanoChatConfig(**config_kwargs)
else:
# Fallback to loading from config.json if it exists
config_file = input_path / "config.json"
if config_file.exists():
print("Loading config from config.json")
config = NanoChatConfig.from_pretrained(input_path)
# Handle legacy qkv_bias -> attention_bias conversion
if hasattr(config, "qkv_bias") and not hasattr(config, "attention_bias"):
config.attention_bias = config.qkv_bias
else:
raise ValueError(f"No config file found in {input_path}. Expected meta_*.json or config.json")
return config
def write_model(input_dir, output_dir, safe_serialization=True):
"""Convert NanoChat model from original checkpoint format to HuggingFace format."""
print("Converting the model.")
os.makedirs(output_dir, exist_ok=True)
input_path = Path(input_dir)
# Load config
config = load_config_from_checkpoint(input_path)
print(f"Loaded config hidden_size={config.hidden_size} num_layers={config.num_hidden_layers}")
# Load checkpoint - try model_*.pt first, then pytorch_model.bin
checkpoint_files = list(input_path.glob("model_*.pt"))
if checkpoint_files:
checkpoint_path = checkpoint_files[0]
else:
checkpoint_path = input_path / "pytorch_model.bin"
print(f"Fetching all parameters from the checkpoint at {checkpoint_path}...")
old_state = torch.load(checkpoint_path, map_location="cpu")
# Original nanochat weights are in bfloat16
for key in old_state:
if old_state[key].dtype == torch.float32:
old_state[key] = old_state[key].to(torch.bfloat16)
# Infer key-value heads from checkpoint
inferred_kv = infer_kv_heads(config, old_state)
config.num_key_value_heads = inferred_kv
if config.num_attention_heads % config.num_key_value_heads != 0:
print(f"Adjusting num_attention_heads from {config.num_attention_heads} to {config.num_key_value_heads}")
config.num_attention_heads = config.num_key_value_heads
print("Converting model...")
state_dict = {}
rename_map = {}
def assign(
old_key: str,
new_key: str,
old_state: dict[str, torch.Tensor],
state_dict: dict[str, torch.Tensor],
rename_map: dict[str, str],
) -> None:
tensor = old_state.get(old_key)
if tensor is None:
return
state_dict[new_key] = tensor.clone()
rename_map[old_key] = new_key
# Convert embeddings and head
assign("transformer.wte.weight", "model.embed_tokens.weight", old_state, state_dict, rename_map)
assign("lm_head.weight", "lm_head.weight", old_state, state_dict, rename_map)
# Convert layers
for layer_idx in range(config.num_hidden_layers):
old_prefix = f"transformer.h.{layer_idx}"
new_prefix = f"model.layers.{layer_idx}"
mapping = convert_layer(old_prefix, new_prefix)
for old_key, new_key in mapping.items():
assign(old_key, new_key, old_state, state_dict, rename_map)
missing = [key for key in old_state.keys() if key not in rename_map]
if missing:
print(f"Skipped {len(missing)} legacy entries that have no equivalent in the shared implementation")
del old_state
gc.collect()
# Update config
config.torch_dtype = torch.bfloat16
config.tie_word_embeddings = False
# Load the checkpoint into the model
print("Loading the checkpoint in a NanoChat model.")
with torch.device("meta"):
model = NanoChatForCausalLM(config)
model.load_state_dict(state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
if hasattr(model.config, "_name_or_path"):
del model.config._name_or_path
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
NanoChatForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
def write_tokenizer(input_dir, output_dir):
"""Convert and save the tokenizer."""
input_path = Path(input_dir)
# Convert the pickle tokenizer to HF format
tokenizer_pkl = input_path / "tokenizer.pkl"
if tokenizer_pkl.exists():
try:
import pickle
from transformers.integrations.tiktoken import convert_tiktoken_to_fast
with open(tokenizer_pkl, "rb") as f:
tok_pkl = pickle.load(f)
convert_tiktoken_to_fast(tok_pkl, output_dir)
print("Converted tokenizer.pkl to HuggingFace format")
except Exception as e:
print(f"Warning: Failed to convert tokenizer.pkl: {e}")
# Fallback: copy tokenizer files if they exist
for filename in ("tokenizer.json", "tokenizer_config.json"):
src = input_path / filename
if src.exists():
(Path(output_dir) / filename).write_bytes(src.read_bytes())
else:
# No pickle tokenizer, copy JSON files
for filename in ("tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"):
src = input_path / filename
if src.exists():
(Path(output_dir) / filename).write_bytes(src.read_bytes())
print("Tokenizer saved successfully.")
def run_test(output_dir: str, prompt: str, max_new_tokens: int = 64) -> None:
"""Run a quick generation test to verify the converted model works correctly."""
print(f"Running quick generation test with prompt: {prompt}")
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = NanoChatForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
output = model.generate(**inputs, max_new_tokens=max_new_tokens)
generated = tokenizer.decode(output[0, inputs.input_ids.shape[1] :], skip_special_tokens=True)
print(f"Generated text: {generated}")
def main():
parser = argparse.ArgumentParser(description="Convert NanoChat checkpoints to HuggingFace format")
parser.add_argument(
"--input_dir",
type=str,
required=True,
help="Path to the original checkpoint directory",
)
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--safe_serialization",
action="store_true",
default=True,
help="Whether or not to save using `safetensors`.",
)
parser.add_argument(
"--test_prompt",
type=str,
default=None,
help="Optional prompt for a quick generation test",
)
args = parser.parse_args()
write_model(
args.input_dir,
args.output_dir,
safe_serialization=args.safe_serialization,
)
write_tokenizer(args.input_dir, args.output_dir)
if args.test_prompt:
run_test(args.output_dir, args.test_prompt)
if __name__ == "__main__":
main()
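As a worked illustration of the head-count inference in `infer_kv_heads` above: the number of KV heads is recovered by dividing the row count of the original `attn.c_k.weight` by the head dimension. A toy, self-contained version of that arithmetic (the tensor below is a stand-in, not a real checkpoint entry):
```py
import torch

hidden_size, num_attention_heads = 768, 6
head_dim = hidden_size // num_attention_heads            # 128, as in the default NanoChatConfig

# Stand-in for state_dict["transformer.h.0.attn.c_k.weight"] from a checkpoint with 2 KV heads.
key_weight = torch.zeros(2 * head_dim, hidden_size)

rows = key_weight.shape[0]
inferred_kv_heads = rows // head_dim if rows % head_dim == 0 else num_attention_heads
print(inferred_kv_heads)  # 2
```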

View File

@ -1,533 +0,0 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/nanochat/modular_nanochat.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_nanochat.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from typing import Optional, Union
import torch
import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from .configuration_nanochat import NanoChatConfig
class NanoChatRMSNorm(torch.nn.Module):
def __init__(self, eps: float = 1e-6):
super().__init__()
self.eps = eps
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x)
def extra_repr(self):
return f"eps={self.eps}"
class NanoChatRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: NanoChatConfig, device=None):
super().__init__()
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_type = self.config.rope_parameters["rope_type"]
rope_init_fn: Callable = self.compute_default_rope_parameters
if self.rope_type != "default":
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = inv_freq
@staticmethod
def compute_default_rope_parameters(
config: Optional[NanoChatConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PreTrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
base = config.rope_parameters["rope_theta"]
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
attention_factor = 1.0 # Unused in this type of RoPE
# Compute the inverse frequencies
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, attention_factor
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def rotate_half(x):
"""Rotates half the hidden dims of the input with flipped signs for NanoChat."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((x2, -x1), dim=-1)
class NanoChatAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# RoPE -> Norm (instead of usual Norm -> RoPE)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class NanoChatMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class NanoChatDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = NanoChatAttention(config=config, layer_idx=layer_idx)
self.mlp = NanoChatMLP(config)
self.input_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
self.post_attention_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@auto_docstring
class NanoChatPreTrainedModel(PreTrainedModel):
config: NanoChatConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["NanoChatDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": NanoChatDecoderLayer,
"attentions": NanoChatAttention,
}
def _init_weights(self, module: nn.Module) -> None:
super()._init_weights(module)
for name, param in module.named_parameters():
if name == "o_proj.weight":
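# GPT-2-style residual scaling: shrink the attention output projection std by 1/sqrt(2 * num_layers)
# so the residual-stream variance stays roughly constant with depth.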
nn.init.normal_(
param,
mean=0.0,
std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
)
@auto_docstring
class NanoChatModel(NanoChatPreTrainedModel):
def __init__(self, config: NanoChatConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
self.rotary_emb = NanoChatRotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.initial_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
# Initialize weights and apply final processing
self.post_init()
@check_model_inputs()
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
hidden_states = self.initial_norm(hidden_states) # Additional norm before the layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
@auto_docstring
class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = NanoChatModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
Example:
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("karpathy/nanochat-d32")
>>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")
>>> conversation = [
...     {"role": "user", "content": "What is the capital of France?"},
... ]
>>> inputs = tokenizer.apply_chat_template(
...     conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
... ).to(model.device)
>>> with torch.no_grad():
...     outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
>>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
>>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
```"""
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
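# logits_to_keep == 0 yields slice(0, None) (all positions); an int n keeps only the last n positions, a tensor selects explicit indices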
logits = self.lm_head(hidden_states[:, slice_indices, :])
if self.config.final_logit_softcapping is not None:
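# Soft-capping rescales logits to cap * tanh(logits / cap), smoothly bounding them to (-cap, cap)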
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = ["NanoChatPreTrainedModel", "NanoChatModel", "NanoChatForCausalLM"]

View File

@ -1,247 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from typing import Optional
import torch
import torch.nn as nn
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ..clip.modeling_clip import CLIPMLP
from ..gemma2.modeling_gemma2 import Gemma2ForCausalLM
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from ..llama4.modeling_llama4 import Llama4TextL2Norm
from ..qwen3.modeling_qwen3 import Qwen3Attention
from .configuration_nanochat import NanoChatConfig
class NanoChatRMSNorm(Llama4TextL2Norm):
pass
class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
pass
def rotate_half(x):
"""Rotates half the hidden dims of the input with flipped signs for NanoChat."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((x2, -x1), dim=-1)
class NanoChatAttention(Qwen3Attention):
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__(config, layer_idx)
del self.sliding_window
del self.layer_type
self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# RoPE -> Norm (instead of usual Norm -> RoPE)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class NanoChatMLP(CLIPMLP):
def __init__(self, config):
super().__init__(config)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
class NanoChatDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.input_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
self.post_attention_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
@auto_docstring
class NanoChatPreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module: nn.Module) -> None:
PreTrainedModel._init_weights(self, module)
for name, param in module.named_parameters():
if name == "o_proj.weight":
nn.init.normal_(
param,
mean=0.0,
std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
)
@auto_docstring
class NanoChatModel(LlamaModel):
def __init__(self, config: NanoChatConfig):
super().__init__(config)
self.initial_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
hidden_states = self.initial_norm(hidden_states) # Additional norm before the layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
@auto_docstring
class NanoChatForCausalLM(Gemma2ForCausalLM):
def forward(self, **super_kwargs) -> CausalLMOutputWithPast:
r"""
Example:
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("karpathy/nanochat-d32")
>>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")
>>> conversation = [
...     {"role": "user", "content": "What is the capital of France?"},
... ]
>>> inputs = tokenizer.apply_chat_template(
...     conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
... ).to(model.device)
>>> with torch.no_grad():
...     outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
>>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
>>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
```"""
return super().forward(**super_kwargs)
__all__ = [
"NanoChatPreTrainedModel",
"NanoChatModel",
"NanoChatForCausalLM",
]

View File

@ -35,11 +35,14 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_vaultgemma import VaultGemmaConfig
logger = logging.get_logger(__name__)
class VaultGemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@ -251,18 +254,22 @@ class VaultGemmaDecoderLayer(GradientCheckpointingLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@ -273,7 +280,11 @@ class VaultGemmaDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class VaultGemmaRotaryEmbedding(nn.Module):
@ -395,16 +406,30 @@ class VaultGemmaModel(VaultGemmaPreTrainedModel):
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if use_cache and past_key_values is None:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
@ -443,22 +468,41 @@ class VaultGemmaModel(VaultGemmaPreTrainedModel):
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
@ -488,9 +532,11 @@ class VaultGemmaForCausalLM(VaultGemmaPreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Example:
@ -509,6 +555,11 @@ class VaultGemmaForCausalLM(VaultGemmaPreTrainedModel, GenerationMixin):
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
@ -517,6 +568,8 @@ class VaultGemmaForCausalLM(VaultGemmaPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)

View File

@ -191,18 +191,22 @@ class VaultGemmaDecoderLayer(Gemma2DecoderLayer):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
@ -213,7 +217,11 @@ class VaultGemmaDecoderLayer(Gemma2DecoderLayer):
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class VaultGemmaForCausalLM(Gemma2ForCausalLM):

View File

@ -47,7 +47,8 @@ PACKAGE_DISTRIBUTION_MAPPING = importlib.metadata.packages_distributions()
def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Check if `pkg_name` exist, and optionally try to get its version"""
spec = importlib.util.find_spec(pkg_name)
package_exists = spec is not None
# the spec might be not None but not importable
package_exists = spec is not None and spec.loader is not None
package_version = "N/A"
if package_exists and return_version:
try:

View File

@ -1,233 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch NanoChat model."""
import unittest
from transformers import AutoTokenizer, NanoChatConfig, is_torch_available
from transformers.testing_utils import (
cleanup,
require_torch,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers import (
NanoChatForCausalLM,
NanoChatModel,
)
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
class NanoChatModelTester(CausalLMModelTester):
config_class = NanoChatConfig
if is_torch_available():
base_model_class = NanoChatModel
causal_lm_class = NanoChatForCausalLM
@require_torch
class NanoChatModelTest(CausalLMModelTest, unittest.TestCase):
model_tester_class = NanoChatModelTester
@require_torch
class NanoChatIntegrationTest(unittest.TestCase):
"""Integration tests for NanoChat models using real checkpoints."""
def setUp(self):
cleanup(torch_device, gc_collect=True)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@slow
def test_model_d20_logits(self):
"""Test that d20 model logits are computed correctly."""
model_id = "nanochat-students/nanochat-d20"
model = NanoChatForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Simple test input - "Hello world"
test_text = "Hello world"
input_ids = tokenizer.encode(test_text, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits.float().cpu()
# Basic shape checks
self.assertEqual(logits.shape[0], 1) # batch size
self.assertEqual(logits.shape[1], input_ids.shape[1]) # sequence length
self.assertEqual(logits.shape[2], model.config.vocab_size) # vocab size 65536
# Check logits are not NaN or Inf
self.assertFalse(torch.isnan(logits).any())
self.assertFalse(torch.isinf(logits).any())
# Check expected mean logits (with tolerance for numerical variation)
EXPECTED_MEAN = torch.tensor([[-6.6607, -7.8095]])
# Check first 10 logits at position [0,0,:10]
EXPECTED_SLICE = torch.tensor(
[-12.8750, -13.0625, -13.1875, -13.1875, -13.1875, -13.1875, -13.1875, -13.1875, -12.6250, -4.4062]
)
torch.testing.assert_close(logits.mean(-1), EXPECTED_MEAN, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(logits[0, 0, :10], EXPECTED_SLICE, rtol=1e-3, atol=1e-3)
@slow
def test_model_d20_generation(self):
"""Test that d20 model generates text correctly."""
model_id = "nanochat-students/nanochat-d20"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = NanoChatForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
# Test generation with chat template
conversation = [
[
{"role": "user", "content": "What is the capital of France?"},
],
[
{"role": "user", "content": "Tell me something."},
],
]
inputs = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
padding=True,
tokenizer_kwargs={"padding_side": "left"},
return_tensors="pt",
).to(model.device)
# Generate with greedy decoding for reproducibility
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=32,
do_sample=False,
)
# Decode only the generated tokens
generated_text = [
tokenizer.decode(generated_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True),
tokenizer.decode(generated_ids[1, inputs["input_ids"].shape[1] :], skip_special_tokens=True),
]
EXPECTED_TEXT_COMPLETION = [
"The capital of France is Paris.",
"I'm ready to help. What's the first thing you'd like to know or discuss?",
]
self.assertEqual(EXPECTED_TEXT_COMPLETION[0], generated_text[0])
self.assertEqual(EXPECTED_TEXT_COMPLETION[1], generated_text[1])
@slow
def test_model_d32_logits(self):
"""Test that d32 model logits are computed correctly."""
model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1" # TODO: update when merged to hub
model = NanoChatForCausalLM.from_pretrained(
model_id, device_map="auto", torch_dtype=torch.bfloat16, revision=revision
)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
# Simple test input - "Hello world"
test_text = "Hello world"
input_ids = tokenizer.encode(test_text, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits.float().cpu()
# Basic shape checks
self.assertEqual(logits.shape[0], 1) # batch size
self.assertEqual(logits.shape[1], input_ids.shape[1]) # sequence length
self.assertEqual(logits.shape[2], model.config.vocab_size) # vocab size 65536
# Check logits are not NaN or Inf
self.assertFalse(torch.isnan(logits).any())
self.assertFalse(torch.isinf(logits).any())
# Check expected mean logits (with tolerance for numerical variation)
EXPECTED_MEAN = torch.tensor([[-5.5791, -8.3456]])
# Check first 10 logits at position [0,0,:10]
EXPECTED_SLICE = torch.tensor(
[-12.3125, -13.1250, -12.8125, -13.1250, -13.1250, -13.1250, -13.1250, -13.1250, -11.8125, -1.4688]
)
torch.testing.assert_close(logits.mean(-1), EXPECTED_MEAN, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(logits[0, 0, :10], EXPECTED_SLICE, rtol=1e-3, atol=1e-3)
@slow
def test_model_d32_generation(self):
"""Test that d32 model generates text correctly."""
model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1" # TODO: update when merged to hub
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
model = NanoChatForCausalLM.from_pretrained(
model_id, device_map="auto", torch_dtype=torch.bfloat16, revision=revision
)
# Test generation with chat template
conversation = [
[
{"role": "user", "content": "What is the capital of France?"},
],
[
{"role": "user", "content": "Tell me something."},
],
]
inputs = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
padding=True,
tokenizer_kwargs={"padding_side": "left"},
return_tensors="pt",
).to(model.device)
# Generate with greedy decoding for reproducibility
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=32,
do_sample=False,
)
# Decode only the generated tokens
generated_text = [
tokenizer.decode(generated_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True),
tokenizer.decode(generated_ids[1, inputs["input_ids"].shape[1] :], skip_special_tokens=True),
]
EXPECTED_TEXT_COMPLETION = [
"The capital of France is Paris.",
"I'm here to help you explore your creative writing endeavors. What's been on your mind lately? Do you have a story idea you'd like to develop,",
]
self.assertEqual(EXPECTED_TEXT_COMPLETION[0], generated_text[0])
self.assertEqual(EXPECTED_TEXT_COMPLETION[1], generated_text[1])

View File

@ -0,0 +1,224 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from pathlib import Path
from transformers import is_torch_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
require_accelerate,
require_torch_multi_accelerator,
run_first,
slow,
)
if is_torch_available():
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
HfArgumentParser,
Trainer,
TrainingArguments,
)
class TestTrainerContextParallelTorch(TestCasePlus):
"""Test Trainer with Torch context parallelism enabled via accelerate's ParallelismConfig."""
@require_torch_multi_accelerator
@require_accelerate
@slow
@run_first
def test_cp_equivalence(self):
"""Test that CP produces the same losses as without CP."""
# Shared setup
world_size = 2
script_path = __file__
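# This test file doubles as the training script: `accelerate launch` re-executes it and runs the __main__ block below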
# Step 1: Run with CP enabled (cp_size=world_size)
cp_yes_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
cp_yes_config_path = cp_yes_output_dir / "context_parallel_config.yaml"
cp_yes_losses_path = cp_yes_output_dir / "cp_yes_losses.json"
# Write config file inline (self-contained test)
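# FSDP v2 plus a parallelism_config that assigns every rank to context parallelism (cp_size = world_size);
# dp-replicate, dp-shard and tp sizes all stay at 1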
with open(cp_yes_config_path, "w") as f:
f.write(
f"""distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_version: 2
mixed_precision: bf16
num_processes: {world_size}
parallelism_config:
parallelism_config_dp_replicate_size: 1
parallelism_config_dp_shard_size: 1
parallelism_config_tp_size: 1
parallelism_config_cp_size: {world_size}
parallelism_config_cp_comm_strategy: alltoall
"""
)
cmd_cp_yes = f"""
accelerate launch
--config_file {cp_yes_config_path}
{script_path}
--output_dir {cp_yes_output_dir}
--report_to none
--max_steps 10
--per_device_train_batch_size 1
--gradient_accumulation_steps 1
--logging_steps 1
--remove_unused_columns False
--seed 42
--loss_output_file {cp_yes_losses_path}
""".split()
execute_subprocess_async(cmd_cp_yes, env=self.get_env())
# Step 2: Run without CP (FSDP with num_processes=1, no parallelism_config)
cp_no_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
cp_no_config_path = cp_no_output_dir / "context_parallel_config.yaml"
cp_no_losses_path = cp_no_output_dir / "cp_no_losses.json"
# Write config file inline (self-contained test)
with open(cp_no_config_path, "w") as f:
f.write(
"""distributed_type: FSDP
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_version: 2
mixed_precision: bf16
num_processes: 1
"""
)
cmd_cp_no = f"""
accelerate launch
--config_file {cp_no_config_path}
{script_path}
--output_dir {cp_no_output_dir}
--report_to none
--max_steps 10
--per_device_train_batch_size 1
--gradient_accumulation_steps 1
--logging_steps 1
--remove_unused_columns False
--seed 42
--loss_output_file {cp_no_losses_path}
""".split()
execute_subprocess_async(cmd_cp_no, env=self.get_env())
# Compare losses - should be very close since CP just splits sequence computation
with open(cp_yes_losses_path) as f:
cp_yes_losses = json.load(f)
with open(cp_no_losses_path) as f:
cp_no_losses = json.load(f)
assert len(cp_yes_losses) == len(cp_no_losses), (
f"Different number of losses: CP has {len(cp_yes_losses)}, no-CP has {len(cp_no_losses)}"
)
# CP should produce very similar results (small numerical differences expected)
# The differences come from:
# - Different gradient reduction patterns in distributed training
# - BF16 mixed precision accumulated differences
# - Sequence splitting and gathering in CP mode
cp_yes_losses_tensor = torch.tensor(cp_yes_losses)
cp_no_losses_tensor = torch.tensor(cp_no_losses)
# Use torch.testing.assert_close with rtol=2% and atol=0.02
# Testing shows actual differences are typically <1.5%
torch.testing.assert_close(
cp_yes_losses_tensor,
cp_no_losses_tensor,
rtol=2e-2, # 2% relative tolerance
atol=2e-2, # 0.02 absolute tolerance
msg=f"CP losses {cp_yes_losses} do not match non-CP losses {cp_no_losses}",
)
if __name__ == "__main__":
# Parse custom arguments (not TrainingArguments parameters)
loss_output_file = None
if "--loss_output_file" in sys.argv:
idx = sys.argv.index("--loss_output_file")
loss_output_file = sys.argv[idx + 1]
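# Remove the custom flag and its value so HfArgumentParser only sees TrainingArguments options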
sys.argv.pop(idx)
sys.argv.pop(idx)
parser = HfArgumentParser((TrainingArguments,))
training_args = parser.parse_args_into_dataclasses()[0]
# Use SmolLM (small Llama-based model that works with CP)
model_name = "HuggingFaceTB/SmolLM-135M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="sdpa", # CP requires SDPA
)
# Create simple dataset: just tokenize some text
texts = [
"The quick brown fox jumps over the lazy dog. " * 10,
"Hello world, this is a test sentence for training. " * 10,
] * 4 # 8 samples total
def tokenize_function(examples):
return tokenizer(examples, max_length=128, truncation=True, padding="max_length")
train_dataset = [tokenize_function(text) for text in texts]
# Use standard DataCollatorForLanguageModeling for causal LM
# pad_to_multiple_of=4 ensures sequences are divisible by cp_size * 2 (for cp_size=2)
# Trainer will automatically generate position_ids and shift_labels as needed
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False, # Causal language modeling
pad_to_multiple_of=4,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
)
# Train for a few steps
trainer.train()
# Verify training completed
assert trainer.state.global_step > 0, "Training should have completed at least one step"
# Save losses to file if requested (for equivalence testing)
if loss_output_file and training_args.process_index == 0:
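# Only the main process writes the losses, avoiding concurrent writes across ranks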
losses = [log["loss"] for log in trainer.state.log_history if "loss" in log]
with open(loss_output_file, "w") as f:
json.dump(losses, f)