Mirror of https://github.com/huggingface/transformers.git, synced 2025-11-03 03:14:36 +08:00

Compare commits: `nanochat-i...` → `cp-ci-test` (8 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | a4a187e0e5 |  |
|  | 4ec11c20c8 |  |
|  | df8aaacbbb |  |
|  | e18436c0d2 |  |
|  | 6d42d9a599 |  |
|  | 0bcf34edad |  |
|  | 1ff587b504 |  |
|  | 63efa18396 |  |
@ -14,7 +14,7 @@ This AGENTS.md file provides guidance for code agents working with this codebase

- PRs should be as brief as possible. Bugfix PRs in particular can often be only one or two lines long, and do not need large comments, docstrings or new functions in this case. Aim to minimize the size of the diff.
- When writing tests, they should be added to an existing file. The only exception is for PRs to add a new model, when a new test directory should be created for that model.
- Code style is enforced in the CI. You can install the style tools with `pip install -e ".[quality]"`. You can then run `make fixup` to apply style and consistency fixes to your code.
- Code style is enforced in the CI. You can install the style tools with `pip install -e .[quality]`. You can then run `make fixup` to apply style and consistency fixes to your code.

## Copying and inheritance

@ -36,4 +36,4 @@ After making changes, you should usually run `make fixup` to ensure any copies a
the model you made the changes in and any other models that were updated by `make fixup`. Tests can be run with `pytest tests/models/[name]/test_modeling_[name].py`
If your changes affect code in other classes like tokenizers or processors, you should run those tests instead, like `test_processing_[name].py` or `test_tokenization_[name].py`.

In order to run tests, you may need to install dependencies. You can do this with `pip install -e ".[testing]"`. You will probably also need to `pip install torch accelerate` if your environment does not already have them.
In order to run tests, you may need to install dependencies. You can do this with `pip install -e .[testing]`. You will probably also need to `pip install torch accelerate` if your environment does not already have them.
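One of the consistency mechanisms that `make fixup` enforces in this repository is the `# Copied from` marker, which pins a block of code to its source definition and keeps the two in sync. Below is a minimal sketch of the convention, using a hypothetical `MyModel` class; it is illustrative only and not part of this changeset.

```py
import torch
import torch.nn as nn


# The marker tells the repo consistency check (run by `make fixup`) to keep this class
# in sync with LlamaRMSNorm, with the name substitution Llama -> MyModel applied.
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MyModel
class MyModelRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # RMS-normalize in float32 for stability, then cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```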
@ -9,12 +9,6 @@ In this list, we showcase incredibly impactful and novel projects that have push
adding other projects to the list. If you believe a project should be here and it's not, then please, open a PR
to add it.

## [◉ Universal Intelligence](https://github.com/blueraai/universal-intelligence)

[Universal Intelligence](https://github.com/blueraai/universal-intelligence) aims to standardize models, tools, and agents —transforming them into simple, composable, portable, interoperable, framework-agnostic, hardware-agnostic interfaces (through auto-negotiation and resource sharing); for fast and accessible development of AI applications.

Keywords: Protocol, Open-source, LLMs, Large Language Models, Agents, Low-code

## [gpt4all](https://github.com/nomic-ai/gpt4all)

[gpt4all](https://github.com/nomic-ai/gpt4all) is an ecosystem of open-source chatbots trained on massive collections of clean assistant data including code, stories and dialogue. It offers open-source, large language models such as LLaMA and GPT-J trained in an assistant-style.
@ -626,8 +626,6 @@
  title: MVP
- local: model_doc/myt5
  title: myt5
- local: model_doc/nanochat
  title: NanoChat
- local: model_doc/nemotron
  title: Nemotron
- local: model_doc/nezha
@ -1,119 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# NanoChat

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

[NanoChat](https://huggingface.co/karpathy/nanochat-d32) is a compact decoder-only transformer model designed for educational purposes and efficient training. The model features several fundamental architectural innovations which are common in modern transformer models. Therefore, it is a good model to use as a starting point to understand the principles of modern transformer models. NanoChat is a variant of the [Llama](https://huggingface.co/docs/transformers/en/model_doc/llama) architecture, with simplified attention mechanism and normalization layers.

The architecture is based on [nanochat](https://github.com/karpathy/nanochat) by [Andrej Karpathy](https://huggingface.co/karpathy), adapted for the Hugging Face Transformers library by [Ben Burtenshaw](https://huggingface.co/burtenshaw).

> [!TIP]
> This model was contributed by the Hugging Face team.

The example below demonstrates how to use NanoChat for text generation with chat templates.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

chatbot = pipeline(
    task="text-generation",
    model="karpathy/nanochat-d32",
    dtype=torch.bfloat16,
    device=0
)

conversation = [
    {"role": "user", "content": "What is the capital of France?"},
]

outputs = chatbot(conversation, max_new_tokens=64)
print(outputs[0]["generated_text"][-1]["content"])
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "karpathy/nanochat-d32"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

conversation = [
    {"role": "user", "content": "What is the capital of France?"},
]

inputs = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt"
).to(device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
    )

# Decode only the generated tokens (excluding the input prompt)
generated_tokens = outputs[0, inputs["input_ids"].shape[1]:]
print(tokenizer.decode(generated_tokens, skip_special_tokens=True))
```

</hfoption>
<hfoption id="transformers CLI">

```bash
echo -e '{"role": "user", "content": "What is the capital of France?"}' | transformers run --task text-generation --model karpathy/nanochat-d32 --device 0
```

</hfoption>
</hfoptions>

## NanoChatConfig

[[autodoc]] NanoChatConfig

## NanoChatModel

[[autodoc]] NanoChatModel
    - forward

## NanoChatForCausalLM

[[autodoc]] NanoChatForCausalLM
    - forward
@ -38,7 +38,7 @@ pip install transformers[dev]
or for an editable install:

```bash
pip install -e ".[dev]"
pip install -e .[dev]
```

inside the Transformers repo. Since the number of optional dependencies of Transformers has grown a lot, it's possible you don't manage to get all of them. If the dev install fails, make sure to install PyTorch then do
@ -50,7 +50,7 @@ pip install transformers[quality]
or for an editable install:

```bash
pip install -e ".[quality]"
pip install -e .[quality]
```

## Tests
@ -37,7 +37,7 @@ pip install transformers[dev]
o una instalación editable:

```bash
pip install -e ".[dev]"
pip install -e .[dev]
```

del repositorio de Transformers.
@ -37,7 +37,7 @@ pip install transformers[dev]
o un'installazione modificabile:

```bash
pip install -e ".[dev]"
pip install -e .[dev]
```

all'interno del repo Transformers.
@ -40,7 +40,7 @@ pip install transformers[dev]


```bash
pip install -e ".[dev]"
pip install -e .[dev]
```

トランスフォーマーズのリポジトリ内で作業しています。トランスフォーマーズのオプションの依存関係の数が増えたため、すべてを取得できない可能性があります。開発用インストールが失敗した場合、作業しているディープラーニングフレームワーク(PyTorch、TensorFlow、および/またはFlax)をインストールし、次の手順を実行してください。
@ -53,7 +53,7 @@ pip install transformers[quality]
または編集可能なインストールの場合:

```bash
pip install -e ".[quality]"
pip install -e .[quality]
```

## Tests
@ -37,7 +37,7 @@ pip install transformers[dev]
또는 Transformers 저장소 내에 편집 가능한 설치가 필요합니다:

```bash
pip install -e ".[dev]"
pip install -e .[dev]
```

Transformers의 선택적 종속성 수가 많이 늘어났기 때문에 개발 설치를 실패할 수도 있습니다. 개발 설치가 실패하는 경우, 작업 중인 Deep Learning 프레임워크 (PyTorch, TensorFlow 및/또는 Flax)를 설치하고 다음 명령을 실행하세요.
@ -49,7 +49,7 @@ pip install transformers[quality]
편집 가능한 설치의 경우는 다음 명령을 실행하세요.

```bash
pip install -e ".[quality]"
pip install -e .[quality]
```

@ -240,7 +240,6 @@ if TYPE_CHECKING:
    from .musicgen_melody import *
    from .mvp import *
    from .myt5 import *
    from .nanochat import *
    from .nemotron import *
    from .nllb import *
    from .nllb_moe import *
@ -281,7 +281,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
        ("musicgen", "MusicgenConfig"),
        ("musicgen_melody", "MusicgenMelodyConfig"),
        ("mvp", "MvpConfig"),
        ("nanochat", "NanoChatConfig"),
        ("nat", "NatConfig"),
        ("nemotron", "NemotronConfig"),
        ("nezha", "NezhaConfig"),
@ -738,8 +737,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
        ("musicgen_melody", "MusicGen Melody"),
        ("mvp", "MVP"),
        ("myt5", "myt5"),
        ("nanochat", "NanoChat"),
        ("nanochat", "NanoChatForCausalLM"),
        ("nat", "NAT"),
        ("nemotron", "Nemotron"),
        ("nezha", "Nezha"),
@ -282,7 +282,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("musicgen", "MusicgenModel"),
        ("musicgen_melody", "MusicgenMelodyModel"),
        ("mvp", "MvpModel"),
        ("nanochat", "NanoChatModel"),
        ("nat", "NatModel"),
        ("nemotron", "NemotronModel"),
        ("nezha", "NezhaModel"),
@ -499,7 +498,6 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
        ("mpt", "MptForCausalLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nanochat", "NanoChatForCausalLM"),
        ("nezha", "NezhaForPreTraining"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
@ -722,7 +720,6 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("musicgen", "MusicgenForCausalLM"),
        ("musicgen_melody", "MusicgenMelodyForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("nanochat", "NanoChatForCausalLM"),
        ("nemotron", "NemotronForCausalLM"),
        ("olmo", "OlmoForCausalLM"),
        ("olmo2", "Olmo2ForCausalLM"),
@ -468,7 +468,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
        ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
        ("myt5", ("MyT5Tokenizer", None)),
        ("nanochat", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("nemotron", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        (
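For readers unfamiliar with these tables: the `("nanochat", ...)` entries above are the registrations that let the `Auto*` factory classes resolve the `nanochat` model type to its concrete config, model, and tokenizer classes. A minimal sketch of that resolution path, assuming the entries are registered (the checkpoint id is the one used elsewhere in this changeset):

```py
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# CONFIG_MAPPING_NAMES: model_type "nanochat" -> NanoChatConfig
config = AutoConfig.for_model("nanochat")

# MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: NanoChatConfig -> NanoChatForCausalLM
# TOKENIZER_MAPPING_NAMES: "nanochat" -> PreTrainedTokenizerFast
model = AutoModelForCausalLM.from_pretrained("karpathy/nanochat-d32")
tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")
```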
@ -1392,7 +1392,7 @@ class Emu3Model(Emu3PreTrainedModel):
        image_features = torch.split(image_features, split_sizes)
        return image_features

    @torch.no_grad()
    @torch.no_grad
    def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
        """
        Decodes generated image tokens from language model to continuous pixel values
@ -946,7 +946,7 @@ class Emu3Model(Emu3PreTrainedModel):
        image_features = torch.split(image_features, split_sizes)
        return image_features

    @torch.no_grad()
    @torch.no_grad
    def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
        """
        Decodes generated image tokens from language model to continuous pixel values
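The one-line change in both Emu3 hunks above swaps `@torch.no_grad()` for `@torch.no_grad` (the same swap appears in the Janus hunks further down). In recent PyTorch releases the two spellings are equivalent as decorators, since `torch.no_grad` can be applied to a function directly as well as instantiated first; a small sketch of the equivalence, with hypothetical function names:

```py
import torch


@torch.no_grad()  # instantiate the context manager, then use it to wrap the function
def double_called(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@torch.no_grad  # apply the decorator class directly, without parentheses
def double_uncalled(x: torch.Tensor) -> torch.Tensor:
    return x * 2


x = torch.ones(3, requires_grad=True)
assert not double_called(x).requires_grad
assert not double_uncalled(x).requires_grad
```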
@ -39,11 +39,14 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.generic import check_model_inputs
|
||||
from .configuration_gemma2 import Gemma2Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma2RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
@ -322,6 +325,8 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
@ -330,12 +335,14 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
@ -348,7 +355,12 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -405,16 +417,30 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
past_key_values = DynamicCache(config=self.config)
|
||||
|
||||
if cache_position is None:
|
||||
@ -453,22 +479,41 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
hidden_states = decoder_layer(
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_embeddings=position_embeddings,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
@ -498,9 +543,11 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
Example:
|
||||
@ -519,6 +566,11 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
@ -527,6 +579,8 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -381,6 +381,8 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
@ -389,12 +391,14 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
@ -407,7 +411,12 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Gemma2PreTrainedModel(GemmaPreTrainedModel):
|
||||
@ -430,16 +439,30 @@ class Gemma2Model(GemmaModel):
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None and not self.training:
|
||||
past_key_values = DynamicCache(config=self.config)
|
||||
|
||||
if cache_position is None:
|
||||
@ -478,22 +501,41 @@ class Gemma2Model(GemmaModel):
|
||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
hidden_states = decoder_layer(
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_embeddings=position_embeddings,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
@ -512,9 +554,11 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
Example:
|
||||
@ -533,6 +577,11 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
@ -541,6 +590,8 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -652,9 +652,11 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
Example:
|
||||
@ -673,6 +675,11 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
@ -681,6 +688,8 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1543,6 +1543,8 @@ class Gemma3nDecoderLayer(GradientCheckpointingLayer):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
@ -1551,12 +1553,14 @@ class Gemma3nDecoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1569,7 +1573,12 @@ class Gemma3nDecoderLayer(GradientCheckpointingLayer):
|
||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -1950,9 +1959,11 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
Example:
|
||||
@ -1971,6 +1982,11 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
@ -1979,6 +1995,8 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1283,7 +1283,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
        decoded_image = decoded_image.permute(0, 2, 3, 1)
        return decoded_image

    @torch.no_grad()
    @torch.no_grad
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
@ -1099,7 +1099,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
        decoded_image = decoded_image.permute(0, 2, 3, 1)
        return decoded_image

    @torch.no_grad()
    @torch.no_grad
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
@ -1,14 +0,0 @@
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_nanochat import *
    from .modeling_nanochat import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
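The deleted `__init__.py` above follows the library's lazy-import pattern: `_LazyModule` replaces the package module in `sys.modules` so that submodules are only imported when one of their attributes is first accessed. A stripped-down illustration of the idea in plain Python (this is not the actual `_LazyModule` implementation):

```py
import importlib
import types


class LazyModule(types.ModuleType):
    """Resolve attributes from submodules on first access instead of at import time."""

    def __init__(self, name: str, attr_to_submodule: dict):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule  # e.g. {"NanoChatConfig": "...configuration_nanochat"}

    def __getattr__(self, item: str):
        submodule = importlib.import_module(self._attr_to_submodule[item])
        value = getattr(submodule, item)
        setattr(self, item, value)  # cache so later lookups skip __getattr__
        return value


# Hypothetical usage, mirroring the pattern above:
# sys.modules[__name__] = LazyModule(__name__, {"NanoChatConfig": "transformers.models.nanochat.configuration_nanochat"})
```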
@ -1,164 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params


class NanoChatConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NanoChatModel`]. It is used to instantiate a
    NanoChat model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the [karpathy/nanochat-d32](https://huggingface.co/karpathy/nanochat-d32).

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the NanoChat model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`NanoChatModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimension of the MLP representations. If `None`, it will be computed based on the model architecture.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 6):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function (function or string) in the decoder.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rope_parameters (`RopeParameters`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain
            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        final_logit_softcapping (`float`, *optional*, defaults to 15.0):
            scaling factor when applying tanh softcapping on the logits.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, and value projection layers during self-attention.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings

    ```python
    >>> from transformers import NanoChatModel, NanoChatConfig

    >>> # Initializing a NanoChat style configuration
    >>> configuration = NanoChatConfig()

    >>> # Initializing a model from the NanoChat style configuration
    >>> model = NanoChatModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nanochat"
    keys_to_ignore_at_inference = ["past_key_values"]

    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.fc1": "colwise",
        "layers.*.mlp.fc2": "rowwise",
    }

    def __init__(
        self,
        vocab_size: int = 50304,
        hidden_size: int = 768,
        intermediate_size: int | None = 8192,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 6,
        num_key_value_heads: int | None = None,
        max_position_embeddings: int = 2048,
        hidden_act: str = "relu2",
        attention_dropout: float = 0.0,
        rms_norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        rope_parameters: RopeParameters | dict[RopeParameters] | None = None,
        use_cache: bool = True,
        final_logit_softcapping: float | None = 15.0,
        attention_bias: bool = False,
        bos_token_id: int = 0,
        eos_token_id: int = 1,
        pad_token_id: int = 1,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.hidden_act = hidden_act
        self.attention_dropout = attention_dropout
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.final_logit_softcapping = final_logit_softcapping
        self.attention_bias = attention_bias

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Validate the correctness of rotary position embeddings parameters
        # Must be done after super().__init__() to avoid being overridden by kwargs
        self.rope_parameters = rope_parameters
        rope_theta = kwargs.get("rope_theta", 10000.0)
        standardize_rope_params(self, rope_theta=rope_theta)
        rope_config_validation(self)


__all__ = ["NanoChatConfig"]
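The docstring above only describes `final_logit_softcapping` as a scaling factor, so as a point of reference, this is how tanh softcapping of logits is conventionally applied in the library (e.g. in the Gemma family); a short illustration rather than code from the deleted modeling file:

```py
import torch


def soft_cap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # Squash logits smoothly into (-cap, cap); near zero the mapping is approximately identity.
    return cap * torch.tanh(logits / cap)


logits = torch.tensor([-40.0, -5.0, 0.0, 5.0, 40.0])
print(soft_cap(logits))  # bounded by +/-15.0, ordering preserved
```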
@ -1,313 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer, NanoChatConfig, NanoChatForCausalLM
|
||||
|
||||
|
||||
def infer_kv_heads(config: NanoChatConfig, state_dict: dict[str, torch.Tensor]) -> int:
|
||||
key_weight = state_dict.get("transformer.h.0.attn.c_k.weight")
|
||||
if key_weight is None:
|
||||
return config.num_key_value_heads
|
||||
rows = key_weight.shape[0]
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
if rows % head_dim != 0:
|
||||
return config.num_key_value_heads
|
||||
inferred = rows // head_dim
|
||||
print(f"Inferred {inferred} key_value heads from checkpoint")
|
||||
return max(inferred, 1)
|
||||
|
||||
|
||||
def convert_layer(old_prefix: str, new_prefix: str) -> dict[str, str]:
|
||||
return {
|
||||
f"{old_prefix}.attn.c_q.weight": f"{new_prefix}.self_attn.q_proj.weight",
|
||||
f"{old_prefix}.attn.c_k.weight": f"{new_prefix}.self_attn.k_proj.weight",
|
||||
f"{old_prefix}.attn.c_v.weight": f"{new_prefix}.self_attn.v_proj.weight",
|
||||
f"{old_prefix}.attn.c_proj.weight": f"{new_prefix}.self_attn.o_proj.weight",
|
||||
f"{old_prefix}.mlp.c_fc.weight": f"{new_prefix}.mlp.fc1.weight",
|
||||
f"{old_prefix}.mlp.c_proj.weight": f"{new_prefix}.mlp.fc2.weight",
|
||||
}
|
||||
|
||||
|
||||
def load_config_from_checkpoint(input_path: Path) -> NanoChatConfig:
|
||||
"""Load config from either meta_*.json or config.json in the checkpoint directory."""
|
||||
# Try to find meta_*.json first
|
||||
meta_files = list(input_path.glob("meta_*.json"))
|
||||
|
||||
if meta_files:
|
||||
meta_file = meta_files[0]
|
||||
print(f"Loading config from {meta_file.name}")
|
||||
with open(meta_file, "r") as f:
|
||||
meta_config = json.load(f)
|
||||
|
||||
# Extract model config from meta file
|
||||
if "model_config" in meta_config:
|
||||
model_config = meta_config["model_config"]
|
||||
else:
|
||||
model_config = meta_config
|
||||
|
||||
# Map to NanoChat config parameters
|
||||
config_kwargs = {
|
||||
"vocab_size": model_config.get("vocab_size", 50304),
|
||||
"hidden_size": model_config.get("n_embd", 768),
|
||||
"num_hidden_layers": model_config.get("n_layer", 12),
|
||||
"num_attention_heads": model_config.get("n_head", 6),
|
||||
"num_key_value_heads": model_config.get("n_kv_head"),
|
||||
"max_position_embeddings": model_config.get("sequence_len", 2048),
|
||||
}
|
||||
|
||||
# Try to load existing config.json for additional parameters
|
||||
config_file = input_path / "config.json"
|
||||
if config_file.exists():
|
||||
print("Loading additional config from config.json")
|
||||
with open(config_file, "r") as f:
|
||||
extra_config = json.load(f)
|
||||
|
||||
# Add additional parameters from config.json
|
||||
for key in [
|
||||
"hidden_act",
|
||||
"attention_dropout",
|
||||
"rms_norm_eps",
|
||||
"initializer_range",
|
||||
"logits_soft_cap",
|
||||
"attention_bias",
|
||||
"intermediate_size",
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
]:
|
||||
if key in extra_config:
|
||||
config_kwargs[key] = extra_config[key]
|
||||
# Handle legacy qkv_bias -> attention_bias conversion
|
||||
elif key == "attention_bias" and "qkv_bias" in extra_config:
|
||||
config_kwargs[key] = extra_config["qkv_bias"]
|
||||
|
||||
# Handle rope_theta as a direct kwarg for the rope_parameters processing
|
||||
if "rope_theta" in extra_config:
|
||||
config_kwargs["rope_theta"] = extra_config["rope_theta"]
|
||||
|
||||
# Handle rope_parameters or rope_scaling if present
|
||||
if "rope_parameters" in extra_config:
|
||||
config_kwargs["rope_parameters"] = extra_config["rope_parameters"]
|
||||
elif "rope_scaling" in extra_config and extra_config["rope_scaling"] is not None:
|
||||
config_kwargs["rope_parameters"] = extra_config["rope_scaling"]
|
||||
|
||||
config = NanoChatConfig(**config_kwargs)
|
||||
else:
|
||||
# Fallback to loading from config.json if it exists
|
||||
config_file = input_path / "config.json"
|
||||
if config_file.exists():
|
||||
print("Loading config from config.json")
|
||||
config = NanoChatConfig.from_pretrained(input_path)
|
||||
# Handle legacy qkv_bias -> attention_bias conversion
|
||||
if hasattr(config, "qkv_bias") and not hasattr(config, "attention_bias"):
|
||||
config.attention_bias = config.qkv_bias
|
||||
else:
|
||||
raise ValueError(f"No config file found in {input_path}. Expected meta_*.json or config.json")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def write_model(input_dir, output_dir, safe_serialization=True):
|
||||
"""Convert NanoChat model from original checkpoint format to HuggingFace format."""
|
||||
print("Converting the model.")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
|
||||
# Load config
|
||||
config = load_config_from_checkpoint(input_path)
|
||||
print(f"Loaded config hidden_size={config.hidden_size} num_layers={config.num_hidden_layers}")
|
||||
|
||||
# Load checkpoint - try model_*.pt first, then pytorch_model.bin
|
||||
checkpoint_files = list(input_path.glob("model_*.pt"))
|
||||
if checkpoint_files:
|
||||
checkpoint_path = checkpoint_files[0]
|
||||
else:
|
||||
checkpoint_path = input_path / "pytorch_model.bin"
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {checkpoint_path}...")
|
||||
old_state = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
# Original nanochat weights are in bfloat16
|
||||
for key in old_state:
|
||||
if old_state[key].dtype == torch.float32:
|
||||
old_state[key] = old_state[key].to(torch.bfloat16)
|
||||
|
||||
# Infer key-value heads from checkpoint
|
||||
inferred_kv = infer_kv_heads(config, old_state)
|
||||
config.num_key_value_heads = inferred_kv
|
||||
if config.num_attention_heads % config.num_key_value_heads != 0:
|
||||
print(f"Adjusting num_attention_heads from {config.num_attention_heads} to {config.num_key_value_heads}")
|
||||
config.num_attention_heads = config.num_key_value_heads
|
||||
|
||||
print("Converting model...")
|
||||
state_dict = {}
|
||||
rename_map = {}
|
||||
|
||||
def assign(
|
||||
old_key: str,
|
||||
new_key: str,
|
||||
old_state: dict[str, torch.Tensor],
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
rename_map: dict[str, str],
|
||||
) -> None:
|
||||
tensor = old_state.get(old_key)
|
||||
if tensor is None:
|
||||
return
|
||||
state_dict[new_key] = tensor.clone()
|
||||
rename_map[old_key] = new_key
|
||||
|
||||
# Convert embeddings and head
|
||||
assign("transformer.wte.weight", "model.embed_tokens.weight", old_state, state_dict, rename_map)
|
||||
assign("lm_head.weight", "lm_head.weight", old_state, state_dict, rename_map)
|
||||
|
||||
# Convert layers
|
||||
for layer_idx in range(config.num_hidden_layers):
|
||||
old_prefix = f"transformer.h.{layer_idx}"
|
||||
new_prefix = f"model.layers.{layer_idx}"
|
||||
mapping = convert_layer(old_prefix, new_prefix)
|
||||
for old_key, new_key in mapping.items():
|
||||
assign(old_key, new_key, old_state, state_dict, rename_map)
|
||||
|
||||
missing = [key for key in old_state.keys() if key not in rename_map]
|
||||
if missing:
|
||||
print(f"Skipped {len(missing)} legacy entries that have no equivalent in the shared implementation")
|
||||
|
||||
del old_state
|
||||
gc.collect()
|
||||
|
||||
# Update config
|
||||
config.torch_dtype = torch.bfloat16
|
||||
config.tie_word_embeddings = False
|
||||
|
||||
# Load the checkpoint into the model
|
||||
print("Loading the checkpoint in a NanoChat model.")
|
||||
with torch.device("meta"):
|
||||
model = NanoChatForCausalLM(config)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
print("Checkpoint loaded successfully.")
|
||||
|
||||
if hasattr(model.config, "_name_or_path"):
|
||||
del model.config._name_or_path
|
||||
|
||||
print("Saving the model.")
|
||||
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
|
||||
del state_dict, model
|
||||
|
||||
# Safety check: reload the converted model
|
||||
gc.collect()
|
||||
print("Reloading the model to check if it's saved correctly.")
|
||||
NanoChatForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
|
||||
print("Model reloaded successfully.")
|
||||
|
||||
|
||||
def write_tokenizer(input_dir, output_dir):
|
||||
"""Convert and save the tokenizer."""
|
||||
input_path = Path(input_dir)
|
||||
|
||||
# Convert the pickle tokenizer to HF format
|
||||
tokenizer_pkl = input_path / "tokenizer.pkl"
|
||||
if tokenizer_pkl.exists():
|
||||
try:
|
||||
import pickle
|
||||
|
||||
from transformers.integrations.tiktoken import convert_tiktoken_to_fast
|
||||
|
||||
with open(tokenizer_pkl, "rb") as f:
|
||||
tok_pkl = pickle.load(f)
|
||||
convert_tiktoken_to_fast(tok_pkl, output_dir)
|
||||
print("Converted tokenizer.pkl to HuggingFace format")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to convert tokenizer.pkl: {e}")
|
||||
# Fallback: copy tokenizer files if they exist
|
||||
for filename in ("tokenizer.json", "tokenizer_config.json"):
|
||||
src = input_path / filename
|
||||
if src.exists():
|
||||
(Path(output_dir) / filename).write_bytes(src.read_bytes())
|
||||
else:
|
||||
# No pickle tokenizer, copy JSON files
|
||||
for filename in ("tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"):
|
||||
src = input_path / filename
|
||||
if src.exists():
|
||||
(Path(output_dir) / filename).write_bytes(src.read_bytes())
|
||||
|
||||
print("Tokenizer saved successfully.")
|
||||
|
||||
|
||||
def run_test(output_dir: str, prompt: str, max_new_tokens: int = 64) -> None:
|
||||
"""Run a quick generation test to verify the converted model works correctly."""
|
||||
print(f"Running quick generation test with prompt: {prompt}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(output_dir)
|
||||
model = NanoChatForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
|
||||
model.eval()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
output = model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
generated = tokenizer.decode(output[0, inputs.input_ids.shape[1] :], skip_special_tokens=True)
|
||||
print(f"Generated text: {generated}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert NanoChat checkpoints to HuggingFace format")
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the original checkpoint directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Whether or not to save using `safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional prompt for a quick generation test",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
write_model(
|
||||
args.input_dir,
|
||||
args.output_dir,
|
||||
safe_serialization=args.safe_serialization,
|
||||
)
|
||||
|
||||
write_tokenizer(args.input_dir, args.output_dir)
|
||||
|
||||
if args.test_prompt:
|
||||
run_test(args.output_dir, args.test_prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,533 +0,0 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/nanochat/modular_nanochat.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_nanochat.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
||||
from ...utils.generic import check_model_inputs
|
||||
from .configuration_nanochat import NanoChatConfig
|
||||
|
||||
|
||||
class NanoChatRMSNorm(torch.nn.Module):
|
||||
def __init__(self, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self._norm(x.float()).type_as(x)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"eps={self.eps}"
|
||||
|
||||
|
||||
class NanoChatRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: NanoChatConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = inv_freq

    @staticmethod
    def compute_default_rope_parameters(
        config: Optional[NanoChatConfig] = None,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

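# A minimal sketch, with toy dimensions, of the equivalence stated in the repeat_kv docstring:
# expanding then reshaping matches torch.repeat_interleave over the key/value head dimension.
sketch_kv = torch.randn(1, 2, 5, 4)  # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(sketch_kv, 3), sketch_kv.repeat_interleave(3, dim=1))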
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def rotate_half(x):
    """Rotates half the hidden dims of the input with flipped signs for NanoChat."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((x2, -x1), dim=-1)

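# A small shape sketch, assuming toy sizes, for the RoPE helpers above. With the default
# unsqueeze_dim=1, cos/sin of shape (batch, seq_len, head_dim) broadcast against q/k of shape
# (batch, num_heads, seq_len, head_dim). Note that this rotate_half returns (x2, -x1), a sign
# convention specific to NanoChat, rather than the more common (-x2, x1).
sketch_q = torch.randn(1, 3, 5, 8)  # (batch, heads, seq_len, head_dim)
sketch_k = torch.randn(1, 3, 5, 8)
sketch_cos = torch.ones(1, 5, 8)    # zero rotation angle: cos=1, sin=0
sketch_sin = torch.zeros(1, 5, 8)
rot_q, rot_k = apply_rotary_pos_emb(sketch_q, sketch_k, sketch_cos, sketch_sin)
assert rot_q.shape == sketch_q.shape
assert torch.equal(rot_q, sketch_q)  # a zero angle leaves the queries unchanged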
class NanoChatAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: NanoChatConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # RoPE -> Norm (instead of usual Norm -> RoPE)
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class NanoChatMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class NanoChatDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: NanoChatConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = NanoChatAttention(config=config, layer_idx=layer_idx)

        self.mlp = NanoChatMLP(config)

        self.input_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        self.post_attention_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


@auto_docstring
class NanoChatPreTrainedModel(PreTrainedModel):
    config: NanoChatConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["NanoChatDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": NanoChatDecoderLayer,
        "attentions": NanoChatAttention,
    }

    def _init_weights(self, module: nn.Module) -> None:
        super()._init_weights(module)

        for name, param in module.named_parameters():
            if name == "o_proj.weight":
                nn.init.normal_(
                    param,
                    mean=0.0,
                    std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
                )


@auto_docstring
class NanoChatModel(NanoChatPreTrainedModel):
    def __init__(self, config: NanoChatConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        self.rotary_emb = NanoChatRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        self.initial_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        hidden_states = self.initial_norm(hidden_states)  # Additional norm before the layers
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


@auto_docstring
class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = NanoChatModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained("karpathy/nanochat-d32")
        >>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")

        >>> conversation = [
        ...     {"role": "user", "content": "What is the capital of France?"},
        ... ]

        >>> inputs = tokenizer.apply_chat_template(
        ...     conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
        ... ).to(model.device)

        >>> with torch.no_grad():
        ...     outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)

        >>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
        >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        ```"""
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["NanoChatPreTrainedModel", "NanoChatModel", "NanoChatForCausalLM"]
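# A small numeric sketch, with made-up values, of the final_logit_softcapping step in
# NanoChatForCausalLM.forward above: cap * tanh(logits / cap) squashes logits smoothly into
# (-cap, cap) while leaving small values almost unchanged.
cap = 15.0
raw = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = cap * torch.tanh(raw / cap)
# capped ≈ [-15.0, -1.0, 0.0, 1.0, 15.0], strictly bounded by ±cap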
@ -1,247 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections.abc import Callable
from typing import Optional

import torch
import torch.nn as nn

from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ..clip.modeling_clip import CLIPMLP
from ..gemma2.modeling_gemma2 import Gemma2ForCausalLM
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..llama4.modeling_llama4 import Llama4TextL2Norm
from ..qwen3.modeling_qwen3 import Qwen3Attention
from .configuration_nanochat import NanoChatConfig


class NanoChatRMSNorm(Llama4TextL2Norm):
    pass


class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
    pass


def rotate_half(x):
    """Rotates half the hidden dims of the input with flipped signs for NanoChat."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((x2, -x1), dim=-1)


class NanoChatAttention(Qwen3Attention):
    def __init__(self, config: NanoChatConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        del self.sliding_window
        del self.layer_type

        self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # RoPE -> Norm (instead of usual Norm -> RoPE)
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class NanoChatMLP(CLIPMLP):
    def __init__(self, config):
        super().__init__(config)
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)


class NanoChatDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: NanoChatConfig, layer_idx: int):
        super().__init__()

        self.input_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        self.post_attention_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)


@auto_docstring
class NanoChatPreTrainedModel(LlamaPreTrainedModel):
    def _init_weights(self, module: nn.Module) -> None:
        PreTrainedModel._init_weights(module)

        for name, param in module.named_parameters():
            if name == "o_proj.weight":
                nn.init.normal_(
                    param,
                    mean=0.0,
                    std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
                )


@auto_docstring
class NanoChatModel(LlamaModel):
    def __init__(self, config: NanoChatConfig):
        super().__init__(config)

        self.initial_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        hidden_states = self.initial_norm(hidden_states)  # Additional norm before the layers
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


@auto_docstring
class NanoChatForCausalLM(Gemma2ForCausalLM):
    def forward(self, **super_kwargs) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained("karpathy/nanochat-d32")
        >>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")

        >>> conversation = [
        ...     {"role": "user", "content": "What is the capital of France?"},
        ... ]

        >>> inputs = tokenizer.apply_chat_template(
        ...     conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
        ... ).to(model.device)

        >>> with torch.no_grad():
        ...     outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)

        >>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
        >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        ```"""
        super().forward(**super_kwargs)


__all__ = [
    "NanoChatPreTrainedModel",
    "NanoChatModel",
    "NanoChatForCausalLM",
]
@ -35,11 +35,14 @ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import check_model_inputs
from .configuration_vaultgemma import VaultGemmaConfig


logger = logging.get_logger(__name__)


class VaultGemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
@ -251,18 +254,22 @ class VaultGemmaDecoderLayer(GradientCheckpointingLayer):
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
@ -273,7 +280,11 @ class VaultGemmaDecoderLayer(GradientCheckpointingLayer):
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class VaultGemmaRotaryEmbedding(nn.Module):
@ -395,16 +406,30 @ class VaultGemmaModel(VaultGemmaPreTrainedModel):
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if use_cache and past_key_values is None:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
@ -443,22 +468,41 @ class VaultGemmaModel(VaultGemmaPreTrainedModel):
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


@ -488,9 +532,11 @ class VaultGemmaForCausalLM(VaultGemmaPreTrainedModel, GenerationMixin):
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        Example:
@ -509,6 +555,11 @ class VaultGemmaForCausalLM(VaultGemmaPreTrainedModel, GenerationMixin):
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
@ -517,6 +568,8 @ class VaultGemmaForCausalLM(VaultGemmaPreTrainedModel, GenerationMixin):
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

@ -191,18 +191,22 @ class VaultGemmaDecoderLayer(Gemma2DecoderLayer):
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
@ -213,7 +217,11 @ class VaultGemmaDecoderLayer(Gemma2DecoderLayer):
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class VaultGemmaForCausalLM(Gemma2ForCausalLM):

@ -47,7 +47,8 @ PACKAGE_DISTRIBUTION_MAPPING = importlib.metadata.packages_distributions()
def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
    """Check if `pkg_name` exist, and optionally try to get its version"""
    spec = importlib.util.find_spec(pkg_name)
    package_exists = spec is not None
    # the spec might be not None but not importable
    package_exists = spec is not None and spec.loader is not None
    package_version = "N/A"
    if package_exists and return_version:
        try:

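# A minimal sketch, using a hypothetical helper name, of the stricter check introduced above:
# a package is only reported as available when find_spec returns a spec that also carries a loader.
import importlib.util

def sketch_is_importable(pkg_name: str) -> bool:
    spec = importlib.util.find_spec(pkg_name)
    return spec is not None and spec.loader is not None

print(sketch_is_importable("json"))         # True: a normal module has both a spec and a loader
print(sketch_is_importable("no_such_pkg"))  # False: find_spec returns None for a missing top-level name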
@ -1,233 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch NanoChat model."""

import unittest

from transformers import AutoTokenizer, NanoChatConfig, is_torch_available
from transformers.testing_utils import (
    cleanup,
    require_torch,
    slow,
    torch_device,
)


if is_torch_available():
    import torch

    from transformers import (
        NanoChatForCausalLM,
        NanoChatModel,
    )


from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


class NanoChatModelTester(CausalLMModelTester):
    config_class = NanoChatConfig
    if is_torch_available():
        base_model_class = NanoChatModel
        causal_lm_class = NanoChatForCausalLM


@require_torch
class NanoChatModelTest(CausalLMModelTest, unittest.TestCase):
    model_tester_class = NanoChatModelTester


@require_torch
class NanoChatIntegrationTest(unittest.TestCase):
    """Integration tests for NanoChat models using real checkpoints."""

    def setUp(self):
        cleanup(torch_device, gc_collect=True)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @slow
    def test_model_d20_logits(self):
        """Test that d20 model logits are computed correctly."""
        model_id = "nanochat-students/nanochat-d20"
        model = NanoChatForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Simple test input - "Hello world"
        test_text = "Hello world"
        input_ids = tokenizer.encode(test_text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits.float().cpu()

        # Basic shape checks
        self.assertEqual(logits.shape[0], 1)  # batch size
        self.assertEqual(logits.shape[1], input_ids.shape[1])  # sequence length
        self.assertEqual(logits.shape[2], model.config.vocab_size)  # vocab size 65536

        # Check logits are not NaN or Inf
        self.assertFalse(torch.isnan(logits).any())
        self.assertFalse(torch.isinf(logits).any())

        # Check expected mean logits (with tolerance for numerical variation)
        EXPECTED_MEAN = torch.tensor([[-6.6607, -7.8095]])

        # Check first 10 logits at position [0,0,:10]
        EXPECTED_SLICE = torch.tensor(
            [-12.8750, -13.0625, -13.1875, -13.1875, -13.1875, -13.1875, -13.1875, -13.1875, -12.6250, -4.4062]
        )

        torch.testing.assert_close(logits.mean(-1), EXPECTED_MEAN, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(logits[0, 0, :10], EXPECTED_SLICE, rtol=1e-3, atol=1e-3)

    @slow
    def test_model_d20_generation(self):
        """Test that d20 model generates text correctly."""
        model_id = "nanochat-students/nanochat-d20"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = NanoChatForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

        # Test generation with chat template
        conversation = [
            [
                {"role": "user", "content": "What is the capital of France?"},
            ],
            [
                {"role": "user", "content": "Tell me something."},
            ],
        ]

        inputs = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            padding=True,
            tokenizer_kwargs={"padding_side": "left"},
            return_tensors="pt",
        ).to(model.device)

        # Generate with greedy decoding for reproducibility
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=32,
                do_sample=False,
            )

        # Decode only the generated tokens
        generated_text = [
            tokenizer.decode(generated_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True),
            tokenizer.decode(generated_ids[1, inputs["input_ids"].shape[1] :], skip_special_tokens=True),
        ]

        EXPECTED_TEXT_COMPLETION = [
            "The capital of France is Paris.",
            "I'm ready to help. What's the first thing you'd like to know or discuss?",
        ]

        self.assertEqual(EXPECTED_TEXT_COMPLETION[0], generated_text[0])
        self.assertEqual(EXPECTED_TEXT_COMPLETION[1], generated_text[1])

    @slow
    def test_model_d32_logits(self):
        """Test that d32 model logits are computed correctly."""
        model_id = "karpathy/nanochat-d32"
        revision = "refs/pr/1"  # TODO: update when merged to hub
        model = NanoChatForCausalLM.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.bfloat16, revision=revision
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

        # Simple test input - "Hello world"
        test_text = "Hello world"
        input_ids = tokenizer.encode(test_text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits.float().cpu()

        # Basic shape checks
        self.assertEqual(logits.shape[0], 1)  # batch size
        self.assertEqual(logits.shape[1], input_ids.shape[1])  # sequence length
        self.assertEqual(logits.shape[2], model.config.vocab_size)  # vocab size 65536

        # Check logits are not NaN or Inf
        self.assertFalse(torch.isnan(logits).any())
        self.assertFalse(torch.isinf(logits).any())

        # Check expected mean logits (with tolerance for numerical variation)
        EXPECTED_MEAN = torch.tensor([[-5.5791, -8.3456]])

        # Check first 10 logits at position [0,0,:10]
        EXPECTED_SLICE = torch.tensor(
            [-12.3125, -13.1250, -12.8125, -13.1250, -13.1250, -13.1250, -13.1250, -13.1250, -11.8125, -1.4688]
        )

        torch.testing.assert_close(logits.mean(-1), EXPECTED_MEAN, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(logits[0, 0, :10], EXPECTED_SLICE, rtol=1e-3, atol=1e-3)

    @slow
    def test_model_d32_generation(self):
        """Test that d32 model generates text correctly."""
        model_id = "karpathy/nanochat-d32"
        revision = "refs/pr/1"  # TODO: update when merged to hub
        tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
        model = NanoChatForCausalLM.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.bfloat16, revision=revision
        )

        # Test generation with chat template
        conversation = [
            [
                {"role": "user", "content": "What is the capital of France?"},
            ],
            [
                {"role": "user", "content": "Tell me something."},
            ],
        ]

        inputs = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            padding=True,
            tokenizer_kwargs={"padding_side": "left"},
            return_tensors="pt",
        ).to(model.device)

        # Generate with greedy decoding for reproducibility
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=32,
                do_sample=False,
            )

        # Decode only the generated tokens
        generated_text = [
            tokenizer.decode(generated_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True),
            tokenizer.decode(generated_ids[1, inputs["input_ids"].shape[1] :], skip_special_tokens=True),
        ]

        EXPECTED_TEXT_COMPLETION = [
            "The capital of France is Paris.",
            "I'm here to help you explore your creative writing endeavors. What's been on your mind lately? Do you have a story idea you'd like to develop,",
        ]

        self.assertEqual(EXPECTED_TEXT_COMPLETION[0], generated_text[0])
        self.assertEqual(EXPECTED_TEXT_COMPLETION[1], generated_text[1])
tests/trainer/test_trainer_context_parallel_torch.py (new file, 224 lines)
@ -0,0 +1,224 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import sys
from pathlib import Path

from transformers import is_torch_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    require_accelerate,
    require_torch_multi_accelerator,
    run_first,
    slow,
)


if is_torch_available():
    import torch

    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        HfArgumentParser,
        Trainer,
        TrainingArguments,
    )


class TestTrainerContextParallelTorch(TestCasePlus):
    """Test Trainer with Torch context parallelism enabled via accelerate's ParallelismConfig."""

    @require_torch_multi_accelerator
    @require_accelerate
    @slow
    @run_first
    def test_cp_equivalence(self):
        """Test that CP produces the same losses as without CP."""

        # Shared setup
        world_size = 2
        script_path = __file__

        # Step 1: Run with CP enabled (cp_size=world_size)
        cp_yes_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
        cp_yes_config_path = cp_yes_output_dir / "context_parallel_config.yaml"
        cp_yes_losses_path = cp_yes_output_dir / "cp_yes_losses.json"

        # Write config file inline (self-contained test)
        with open(cp_yes_config_path, "w") as f:
            f.write(
                f"""distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
mixed_precision: bf16
num_processes: {world_size}
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 1
  parallelism_config_tp_size: 1
  parallelism_config_cp_size: {world_size}
  parallelism_config_cp_comm_strategy: alltoall
"""
            )

        cmd_cp_yes = f"""
            accelerate launch
            --config_file {cp_yes_config_path}
            {script_path}
            --output_dir {cp_yes_output_dir}
            --report_to none
            --max_steps 10
            --per_device_train_batch_size 1
            --gradient_accumulation_steps 1
            --logging_steps 1
            --remove_unused_columns False
            --seed 42
            --loss_output_file {cp_yes_losses_path}
        """.split()

        execute_subprocess_async(cmd_cp_yes, env=self.get_env())

        # Step 2: Run without CP (FSDP with num_processes=1, no parallelism_config)
        cp_no_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
        cp_no_config_path = cp_no_output_dir / "context_parallel_config.yaml"
        cp_no_losses_path = cp_no_output_dir / "cp_no_losses.json"

        # Write config file inline (self-contained test)
        with open(cp_no_config_path, "w") as f:
            f.write(
                """distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_version: 2
mixed_precision: bf16
num_processes: 1
"""
            )

        cmd_cp_no = f"""
            accelerate launch
            --config_file {cp_no_config_path}
            {script_path}
            --output_dir {cp_no_output_dir}
            --report_to none
            --max_steps 10
            --per_device_train_batch_size 1
            --gradient_accumulation_steps 1
            --logging_steps 1
            --remove_unused_columns False
            --seed 42
            --loss_output_file {cp_no_losses_path}
        """.split()

        execute_subprocess_async(cmd_cp_no, env=self.get_env())

        # Compare losses - should be very close since CP just splits sequence computation
        with open(cp_yes_losses_path) as f:
            cp_yes_losses = json.load(f)
        with open(cp_no_losses_path) as f:
            cp_no_losses = json.load(f)

        assert len(cp_yes_losses) == len(cp_no_losses), (
            f"Different number of losses: CP has {len(cp_yes_losses)}, no-CP has {len(cp_no_losses)}"
        )

        # CP should produce very similar results (small numerical differences expected)
        # The differences come from:
        # - Different gradient reduction patterns in distributed training
        # - BF16 mixed precision accumulated differences
        # - Sequence splitting and gathering in CP mode
        cp_yes_losses_tensor = torch.tensor(cp_yes_losses)
        cp_no_losses_tensor = torch.tensor(cp_no_losses)

        # Use torch.testing.assert_close with rtol=2% and atol=0.02
        # Testing shows actual differences are typically <1.5%
        torch.testing.assert_close(
            cp_yes_losses_tensor,
            cp_no_losses_tensor,
            rtol=2e-2,  # 2% relative tolerance
            atol=2e-2,  # 0.02 absolute tolerance
            msg=f"CP losses {cp_yes_losses} do not match non-CP losses {cp_no_losses}",
        )


if __name__ == "__main__":
    # Parse custom arguments (not TrainingArguments parameters)
    loss_output_file = None

    if "--loss_output_file" in sys.argv:
        idx = sys.argv.index("--loss_output_file")
        loss_output_file = sys.argv[idx + 1]
        sys.argv.pop(idx)
        sys.argv.pop(idx)

    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]

    # Use SmolLM (small Llama-based model that works with CP)
    model_name = "HuggingFaceTB/SmolLM-135M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="sdpa",  # CP requires SDPA
    )

    # Create simple dataset: just tokenize some text
    texts = [
        "The quick brown fox jumps over the lazy dog. " * 10,
        "Hello world, this is a test sentence for training. " * 10,
    ] * 4  # 8 samples total

    def tokenize_function(examples):
        return tokenizer(examples, max_length=128, truncation=True, padding="max_length")

    train_dataset = [tokenize_function(text) for text in texts]

    # Use standard DataCollatorForLanguageModeling for causal LM
    # pad_to_multiple_of=4 ensures sequences are divisible by cp_size * 2 (for cp_size=2)
    # Trainer will automatically generate position_ids and shift_labels as needed
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # Causal language modeling
        pad_to_multiple_of=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    # Train for a few steps
    trainer.train()

    # Verify training completed
    assert trainer.state.global_step > 0, "Training should have completed at least one step"

    # Save losses to file if requested (for equivalence testing)
    if loss_output_file and training_args.process_index == 0:
        losses = [log["loss"] for log in trainer.state.log_history if "loss" in log]
        with open(loss_output_file, "w") as f:
            json.dump(losses, f)
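# A small arithmetic sketch, with toy numbers, of the pad_to_multiple_of=4 choice in the test file
# above: with cp_size=2, context parallelism needs each padded sequence length to be divisible by
# cp_size * 2, which padding up to a multiple of 4 guarantees.
cp_size = 2
for seq_len in (97, 126, 128):
    padded = -(-seq_len // 4) * 4  # round up to the next multiple of 4, as the collator does
    assert padded % (cp_size * 2) == 0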