Add CSM model (#36719)

* draft structure

* depth decoder with forward pre hook

* full model forward draft

* draft update

* depth decoder update

* ConversationalSpeechModelForCausalLM updates

* add generate

* max length criteria small fix

* update

* updates

* generation update

* update in loss compute

* conversion script

* update for correct input embeddings

* handle interleaved rope

* update

* update

* update

* support compile

* update training

* add doc

* update doc

* correct inits

* ConversationalSpeechModel -> Csm

* conf update

* name update

* tests CsmForCausalLMTest

* convert use cached_file

* conf + modeling updates

* generate utils handle third dim shape

* integration test

* modeling + conf updates

* common test handle more than 2 dims

* add nested audio list utils

* processing handle nested audio list

* csm processing draft

* mimi util

* init updates

* modular update

* convert modular

* processing update

* csm tests update

* generate tests handle third dim

* generate utils handle third dim

* propagate _get_initial_cache_position update

* tied_weight_keys update + convert correctly

* fix inputs_embeds

* revert audio nested list

* batch inference update + return audio

* audio_utils update

* processor update

* some more integration tests

* remove old test

* processing output labels

* improve

* fix

* update rope values with equivalent ones

* conversion update

* update tests

* handle depth decoder generation config

* remove default eos_token_id

* make style

* revert modeling_mimi

* add default generation_config

* remove sdpa since handled by default

* make

* fix conflict

* fix conflicts

* correct naming

* correct imports

* make

* causal -> conditional naming

* causal -> conditional naming

* auto update

* make

* make

* add doc

* test update

* fix weight init

* audio tokens offsets as buffer

* 4d mask in conditional class

* make

* doc update

* fix causal mask

* fix causal mask

* doc update

* doc update

* add processor doc

* update doc

* fix 4d causal mask

* update make_list_of_audio

* do not default to mutable

* remove duplicates

* remove useless reset_parameters

* use GradientCheckpointingLayer

* use can_return_tuple

* formatting

* prepend placeholder in _sample

* torch compile fix

* some more fixes

* convert modular

* fix

* default max_length in convert

* handle depth decoder generation config correctly

* clearer formulation

* handle output_loading_info

* handle softmax warning

* add doc

* propagate _get_initial_cache_position changes

* generation in its own module

* add processor tests

* fix compile with cuda graphs

* fix compile with cuda graphs

* add csm.md

* include CSM loss

* doc nit

* doc nit

* doc nit

* Update docs/source/en/model_doc/csm.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add save_audio to processor

* Update src/transformers/models/csm/modular_csm.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* doc update

* simplify audio_codes_mask computation

* doc update

* simplify loss computation

* fix static cache test

* fix

* remove comment

* simplify encoded length computation

* use hf-internal-testing

* doc update

* cast to float before numpy

* nit

* mem efficient codebook head

* nit

* cat input values with cutoffs

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Author: eustlb
Date: 2025-05-07 10:20:13 -04:00
Committed by: GitHub
Parent: c8607a17cb
Commit: 798f948e88
29 changed files with 5827 additions and 86 deletions


@ -825,6 +825,8 @@
title: Bark
- local: model_doc/clap
title: CLAP
- local: model_doc/csm
title: CSM
- local: model_doc/dac
title: dac
- local: model_doc/encodec


@ -0,0 +1,377 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# CSM
## Overview
The Conversational Speech Model (CSM) is the first open-source contextual text-to-speech model [released by Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). It is designed to generate natural-sounding speech with or without conversational context. This context typically consists of multi-turn dialogue between speakers, represented as sequences of text and corresponding spoken audio.
**Model Architecture:**
CSM is composed of two LLaMA-style auto-regressive transformer decoders: a backbone decoder that predicts the first codebook token and a depth decoder that generates the remaining tokens. It uses the pretrained codec model [Mimi](./mimi.md), introduced by Kyutai, to encode speech into discrete codebook tokens and decode them back into audio.
The original csm-1b checkpoint is available under the [Sesame](https://huggingface.co/sesame/csm-1b) organization on Hugging Face.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/eustlb/documentation-images/resolve/main/csm_architecture.png"/>
</div>
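The two decoders and the codec are exposed as submodules of `CsmForConditionalGeneration`. A minimal sketch for inspecting them (the attribute names `backbone_model`, `depth_decoder` and `codec_model` are inferred from the conversion mapping in this PR, so treat them as assumptions rather than guaranteed API):

```python
from transformers import CsmForConditionalGeneration

model = CsmForConditionalGeneration.from_pretrained("eustlb/csm-1b")

# backbone decoder: predicts the first codebook token from the text/audio context
print(type(model.backbone_model).__name__)
# depth decoder: predicts the remaining codebook tokens of each audio frame
print(type(model.depth_decoder).__name__)
# Mimi codec: encodes audio into codebook tokens and decodes generated tokens back to audio
print(type(model.codec_model).__name__)
```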
## Usage Tips
### Without Conversational Context
CSM can be used to generate speech directly from a text prompt:
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
# prepare the inputs
text = "[0]The past is just a story we tell ourselves." # `[0]` for speaker id 0
inputs = processor(text, add_special_tokens=True).to(device)
# another equivalent way to prepare the inputs
conversation = [
{"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
# infer the model
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_without_context.wav")
```
### With Conversational Context
CSM can generate speech from a conversation, which keeps voices consistent across turns and makes generation aware of the dialogue content:
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio
model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []
# 1. context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
conversation.append(
{
"role": f"{speaker_id}",
"content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
}
)
# 2. text prompt
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
# infer the model
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_with_context.wav")
```
### Batched Inference
CSM supports batched inference!
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio
model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
# here a batch with two prompts
conversation = [
[
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[0]["text"]},
{"type": "audio", "path": ds[0]["audio"]["array"]},
],
},
{
"role": f"{ds[1]['speaker_id']}",
"content": [
{"type": "text", "text": ds[1]["text"]},
],
},
],
[
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[0]["text"]},
],
}
],
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, [f"speech_batch_idx_{i}.wav" for i in range(len(audio))])
```
### Making The Model Go Brrr
CSM supports full-graph compilation with CUDA graphs!
```python
import torch
import copy
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset
model_id = "eustlb/csm-1b"
device = "cuda"
# enable logging to verify there are no recompilations or graph breaks
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
# use a static cache, which automatically enables torch.compile with fullgraph and reduce-overhead
model.generation_config.max_length = 250 # big enough to avoid recompilation
model.generation_config.max_new_tokens = None # would take precedence over max_length
model.generation_config.cache_implementation = "static"
model.depth_decoder.generation_config.cache_implementation = "static"
# generation kwargs
gen_kwargs = {
"do_sample": False,
"depth_decoder_do_sample": False,
"temperature": 1.0,
"depth_decoder_temperature": 1.0,
}
# Define a timing context manager
class TimerContext:
def __init__(self, name="Execution"):
self.name = name
self.start_event = None
self.end_event = None
def __enter__(self):
# Use CUDA events for more accurate GPU timing
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
self.start_event.record()
return self
def __exit__(self, *args):
self.end_event.record()
torch.cuda.synchronize()
elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
print(f"{self.name} time: {elapsed_time:.4f} seconds")
# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
conversation = [
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[0]["text"]},
{"type": "audio", "path": ds[0]["audio"]["array"]},
],
},
{
"role": f"{ds[1]['speaker_id']}",
"content": [
{"type": "text", "text": ds[1]["text"]},
{"type": "audio", "path": ds[1]["audio"]["array"]},
],
},
{
"role": f"{ds[2]['speaker_id']}",
"content": [
{"type": "text", "text": ds[2]["text"]},
],
},
]
padded_inputs_1 = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
print("\n" + "="*50)
print("First generation - compiling and recording CUDA graphs...")
with TimerContext("First generation"):
_ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)
print("\n" + "="*50)
print("Second generation - fast !!!")
with TimerContext("Second generation"):
_ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)
# now with different inputs
conversation = [
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[2]["text"]},
{"type": "audio", "path": ds[2]["audio"]["array"]},
],
},
{
"role": f"{ds[1]['speaker_id']}",
"content": [
{"type": "text", "text": ds[3]["text"]},
{"type": "audio", "path": ds[3]["audio"]["array"]},
],
},
{
"role": f"{ds[2]['speaker_id']}",
"content": [
{"type": "text", "text": ds[4]["text"]},
],
},
]
padded_inputs_2 = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
print("\n" + "="*50)
print("Generation with other inputs!")
with TimerContext("Generation with different inputs"):
_ = model.generate(**padded_inputs_2, **gen_kwargs)
print("="*50)
```
### Training
The CSM Transformers integration supports training!
```python
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio
model_id = "eustlb/csm-1b"
device = "cuda"
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
model.train()
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []
# context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
conversation.append(
{
"role": f"{speaker_id}",
"content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
}
)
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
output_labels=True,
).to(device)
out = model(**inputs)
out.loss.backward()
```
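A full fine-tuning step simply wraps the forward and backward passes above in a standard PyTorch optimizer loop. A minimal sketch (the optimizer choice and learning rate are illustrative assumptions, not values from this PR):

```python
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=1e-5)

# `inputs` is the batch prepared above with output_labels=True
optimizer.zero_grad()
out = model(**inputs)
out.loss.backward()
optimizer.step()
```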
This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
The original code can be found [here](https://github.com/SesameAILabs/csm).
## CsmConfig
[[autodoc]] CsmConfig
## CsmDepthDecoderConfig
[[autodoc]] CsmDepthDecoderConfig
## CsmProcessor
[[autodoc]] CsmProcessor
- __call__
## CsmForConditionalGeneration
[[autodoc]] CsmForConditionalGeneration
- forward
- generate
## CsmDepthDecoderForCausalLM
[[autodoc]] CsmDepthDecoderForCausalLM
## CsmDepthDecoderModel
[[autodoc]] CsmDepthDecoderModel
## CsmBackboneModel
[[autodoc]] CsmBackboneModel


@ -24,7 +24,12 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import requests
from .utils import is_librosa_available, requires_backends
from .utils import (
is_librosa_available,
is_numpy_array,
is_torch_tensor,
requires_backends,
)
if is_librosa_available():
@ -69,6 +74,36 @@ AudioInput = Union[
]
def is_valid_audio(audio):
return is_numpy_array(audio) or is_torch_tensor(audio)
def is_valid_list_of_audio(audio):
return audio and all(is_valid_audio(audio_i) for audio_i in audio)
def make_list_of_audio(
audio: Union[list[AudioInput], AudioInput],
) -> AudioInput:
"""
Ensure that the output is a list of audio.
Args:
audio (`Union[List[AudioInput], AudioInput]`):
The input audio.
Returns:
list: A list of audio.
"""
# If it's a list of audios, it's already in the right format
if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio):
return audio
# If it's a single audio, convert it to a list
if is_valid_audio(audio):
return [audio]
raise ValueError("Invalid input type. Must be a single audio or a list of audio")
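# Illustrative usage of make_list_of_audio (a sketch, with numpy arrays as example inputs):
#   import numpy as np
#   make_list_of_audio(np.zeros(16000))        # single audio -> wrapped into a one-element list
#   make_list_of_audio([np.zeros(16000)] * 2)  # list of audio -> returned unchanged
#   make_list_of_audio("not audio")            # -> raises ValueError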
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
"""
Convert frequency from hertz to mels.


@ -73,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
cur_len = input_ids.shape[-1]
cur_len = input_ids.shape[1]
is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
logger.warning_once(


@ -563,7 +563,7 @@ class GenerationMixin:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
else:
batch_size, sequence_length = model_inputs[input_ids_key].shape
batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
# the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
@ -1708,7 +1708,7 @@ class GenerationMixin:
return generation_config, model_kwargs
def _get_initial_cache_position(self, input_ids, model_kwargs):
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
@ -1718,7 +1718,7 @@ class GenerationMixin:
torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
)
else:
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1
past_length = 0
if model_kwargs.get("past_key_values") is not None:
@ -2332,7 +2332,7 @@ class GenerationMixin:
streamer.put(input_ids.cpu())
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
@ -2805,9 +2805,9 @@ class GenerationMixin:
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
batch_size, cur_length = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
this_peer_finished = False
@ -3016,9 +3016,9 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
batch_size, cur_len = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# Create cosine_matrix_mask based on the attention_mask
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
@ -3428,10 +3428,10 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape
batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
@ -3834,7 +3834,7 @@ class GenerationMixin:
num_beams = generation_config.num_beams
num_return_sequences = generation_config.num_return_sequences
batch_size_unflattened, cur_len = input_ids.shape
batch_size_unflattened, cur_len = input_ids.shape[:2]
batch_size = batch_size_unflattened // num_beams
# TODO (joao): standardize special cases
if self.__class__.__name__ == "MoshiDepthDecoder":
@ -3857,7 +3857,7 @@ class GenerationMixin:
dim=0,
).to(input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
# are newer low-memory alternatives like the offloaded cache)
@ -4156,7 +4156,7 @@ class GenerationMixin:
device = input_ids.device
batch_beam_size, cur_len = input_ids.shape
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@ -4190,7 +4190,7 @@ class GenerationMixin:
this_peer_finished = False
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
@ -4444,8 +4444,8 @@ class GenerationMixin:
batch_size = len(constrained_beam_scorer._beam_hyps)
num_beams = constrained_beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
batch_beam_size, cur_len = input_ids.shape[:2]
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
if num_beams * batch_size != batch_beam_size:
raise ValueError(
@ -4477,7 +4477,7 @@ class GenerationMixin:
this_peer_finished = False
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@ -4698,14 +4698,14 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
batch_size, cur_len = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
this_peer_finished = False
is_first_iteration = True # to preserve the same API in the output as other generation methods
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
cur_len = input_ids.shape[-1]
cur_len = input_ids.shape[1]
# 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
@ -4795,7 +4795,7 @@ class GenerationMixin:
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
streamer.put(valid_tokens.cpu())
new_cur_len = input_ids.shape[-1]
new_cur_len = input_ids.shape[1]
# 4.2. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
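The signature change above (passing a plain sequence length and device instead of `input_ids`) lets callers whose token tensors carry an extra codebook dimension, such as CSM, compute the cache position without relying on 2D shape unpacking. The `torch.ones(...).cumsum(0) - 1` pattern is the `torch.compile`-friendly stand-in for `torch.arange`; a minimal sketch of that equivalence (illustrative only):

```python
import torch

seq_length = 5
cache_position = torch.ones(seq_length, dtype=torch.int64).cumsum(0) - 1
assert torch.equal(cache_position, torch.arange(seq_length))
```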


@ -158,4 +158,5 @@ LOSS_MAPPING = {
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
"RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
"DFineForObjectDetection": DFineForObjectDetectionLoss,
"CsmForConditionalGeneration": ForCausalLMLoss,
}


@ -68,6 +68,7 @@ if TYPE_CHECKING:
from .convnextv2 import *
from .cpm import *
from .cpmant import *
from .csm import *
from .ctrl import *
from .cvt import *
from .d_fine import *


@ -80,6 +80,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("convnext", "ConvNextConfig"),
("convnextv2", "ConvNextV2Config"),
("cpmant", "CpmAntConfig"),
("csm", "CsmConfig"),
("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"),
("d_fine", "DFineConfig"),
@ -437,6 +438,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("convnextv2", "ConvNeXTV2"),
("cpm", "CPM"),
("cpmant", "CPM-Ant"),
("csm", "CSM"),
("ctrl", "CTRL"),
("cvt", "CvT"),
("d_fine", "D-FINE"),


@ -78,6 +78,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("convnext", "ConvNextModel"),
("convnextv2", "ConvNextV2Model"),
("cpmant", "CpmAntModel"),
("csm", "CsmForConditionalGeneration"),
("ctrl", "CTRLModel"),
("cvt", "CvtModel"),
("d_fine", "DFineModel"),
@ -1446,6 +1447,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
[
# Model for Text-To-Waveform mapping
("bark", "BarkModel"),
("csm", "CsmForConditionalGeneration"),
("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
("musicgen", "MusicgenForConditionalGeneration"),
("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),


@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_csm import *
from .modeling_csm import *
from .processing_csm import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@ -0,0 +1,440 @@
# coding=utf-8
# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
logger = logging.get_logger(__name__)
class CsmDepthDecoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`CsmDepthDecoderModel`]. It is used to instantiate a CSM depth decoder
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
a similar configuration to that of the csm-1b.
e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_codebooks (`int`, *optional*, defaults to 32):
Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
backbone_hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations of the backbone model used with this depth decoder.
vocab_size (`int`, *optional*, defaults to 2051):
Vocabulary size of the CsmDepthDecoder model. Defines the number of different audio tokens that can be represented by each codebook.
hidden_size (`int`, *optional*, defaults to 1024):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 4):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 33):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 2050):
Padding token id.
bos_token_id (`int`, *optional*):
Beginning of stream token id.
eos_token_id (`int`, *optional*):
End of stream token id.
rope_theta (`float`, *optional*, defaults to 500000):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
```python
>>> from transformers import CsmDepthDecoder, CsmDepthDecoderConfig
>>> # Initializing a CsmDepthDecoder
>>> configuration = CsmDepthDecoderConfig()
>>> model = CsmDepthDecoderModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "csm_depth_decoder_model"
base_config_key = "depth_decoder_config"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
num_codebooks=32,
backbone_hidden_size=2048,
vocab_size=2051,
hidden_size=1024,
intermediate_size=8192,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=2,
hidden_act="silu",
max_position_embeddings=33,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
rope_theta=500000,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
**kwargs,
):
if kwargs.pop("tie_word_embeddings", False):
raise ValueError("`tie_word_embeddings=True` is not supported for CsmDepthDecoderConfig")
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=False,
**kwargs,
)
self.num_codebooks = num_codebooks
self.vocab_size = vocab_size
self.backbone_hidden_size = backbone_hidden_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
class CsmConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`CsmForConditionalGeneration`]. It is used to instantiate a CSM
model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the csm-1b.
e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_codebooks (`int`, *optional*, defaults to 32):
Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
vocab_size (`int`, *optional*, defaults to 2051):
Vocabulary size of the Csm model. Defines the number of different audio tokens that can be represented by each codebook.
text_vocab_size (`int`, *optional*, defaults to 128256):
Vocabulary size of the text input for the Csm model. Defines the number of different text tokens that can be represented.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations of the backbone model.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations of the backbone model.
num_hidden_layers (`int`, *optional*, defaults to 16):
Number of hidden layers in the backbone model Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the backbone model Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf).
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the backbone model Transformer decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 128002):
Padding token id.
codebook_pad_token_id (`int`, *optional*, defaults to 2050):
Padding token id for codebook tokens.
codebook_eos_token_id (`int`, *optional*, defaults to 0):
End of stream token id for codebook tokens.
bos_token_id (`int`, *optional*, defaults to 128000):
Beginning of stream token id.
eos_token_id (`int`, *optional*):
End of stream token id.
audio_token_id (`int`, *optional*, defaults to 128002):
Audio token id in the text input.
audio_eos_token_id (`int`, *optional*, defaults to 128003):
End of stream token id for audio in the text input.
rope_theta (`float`, *optional*, defaults to 500000):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*, defaults to `{'factor': 32.0, 'high_freq_factor': 0.5, 'low_freq_factor': 0.125, 'original_max_position_embeddings': 1024, 'rope_type': 'llama3'}`):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
tie_codebooks_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie the codebook tokens embeddings of the backbone model to the codebook tokens embeddings of the depth decoder.
depth_decoder_config (`CsmDepthDecoderConfig`, *optional*):
Configuration for the depth decoder.
codec_config (`PretrainedConfig`, *optional*):
Configuration for the codec.
```python
>>> from transformers import CsmForConditionalGeneration, CsmConfig
>>> # Initializing a CsmConfig
>>> configuration = CsmConfig()
>>> # Initializing a model
>>> model = CsmForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "csm"
base_config_key = "csm_config"
keys_to_ignore_at_inference = ["past_key_values"]
sub_configs = {
"codec_config": AutoConfig,
"depth_decoder_config": CsmDepthDecoderConfig,
}
def __init__(
self,
num_codebooks=32,
vocab_size=2051,
text_vocab_size=128256,
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=16,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=128002,
codebook_pad_token_id=2050,
codebook_eos_token_id=0,
bos_token_id=128000,
eos_token_id=None,
audio_token_id=128002,
audio_eos_token_id=128003,
rope_theta=500000,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
tie_codebooks_embeddings=True,
depth_decoder_config=None,
codec_config=None,
**kwargs,
):
if kwargs.pop("tie_word_embeddings", False):
raise ValueError("`tie_word_embeddings=True` is not supported for CsmConfig")
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=False,
**kwargs,
)
if depth_decoder_config is None:
self.depth_decoder_config = CsmDepthDecoderConfig()
logger.info("depth_decoder_config is None, using default depth decoder config.")
elif isinstance(depth_decoder_config, dict):
self.depth_decoder_config = CsmDepthDecoderConfig(**depth_decoder_config)
elif isinstance(depth_decoder_config, CsmDepthDecoderConfig):
self.depth_decoder_config = depth_decoder_config
if codec_config is None:
self.codec_config = AutoConfig.for_model("mimi")
logger.info("codec_config is None, using default audio encoder config.")
elif isinstance(codec_config, dict):
self.codec_config = AutoConfig.for_model(**codec_config)
elif isinstance(codec_config, PretrainedConfig):
self.codec_config = codec_config
self.text_vocab_size = text_vocab_size
self.num_codebooks = num_codebooks
self.audio_token_id = audio_token_id
self.audio_eos_token_id = audio_eos_token_id
self.codebook_pad_token_id = codebook_pad_token_id
self.codebook_eos_token_id = codebook_eos_token_id
self.tie_codebooks_embeddings = tie_codebooks_embeddings
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
__all__ = [
"CsmDepthDecoderConfig",
"CsmConfig",
]


@ -0,0 +1,339 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
import torch
from tokenizers.processors import TemplateProcessing
from transformers import (
AutoFeatureExtractor,
AutoTokenizer,
CsmConfig,
CsmDepthDecoderConfig,
CsmForConditionalGeneration,
CsmProcessor,
MimiModel,
)
from transformers.utils.hub import cached_file
# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"backbone\.layers\.(\d+)": r"backbone_model.layers.\1",
r"decoder\.layers\.(\d+)": r"depth_decoder.model.layers.\1",
r"attn": r"self_attn",
r"output_proj": r"o_proj",
r"w1": r"gate_proj",
r"w2": r"down_proj",
r"w3": r"up_proj",
r"text_embeddings": r"embed_text_tokens",
r"audio_embeddings": r"backbone_model.embed_tokens.embed_audio_tokens",
r"codebook0_head": r"lm_head",
r"audio_head": r"depth_decoder.codebooks_head.weight",
r"projection": r"depth_decoder.model.inputs_embeds_projector",
r"sa_norm.scale": r"input_layernorm.weight",
r"mlp_norm.scale": r"post_attention_layernorm.weight",
r"decoder.norm.scale": r"depth_decoder.model.norm.weight",
r"backbone.norm.scale": r"backbone_model.norm.weight",
}
# fmt: on
def permute_for_rope(input_tensor, n_heads, dim1, dim2):
"""
When you go from the complex ROPE formulation to sin and cos one, you need
to permute the query and key weights (to avoid doing it on the fly)
"""
input_tensor = input_tensor.reshape(dim1, dim2)
input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2)
return input_tensor
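# Descriptive note (an assumption based on the standard Llama weight conversion): within each
# attention head this reorders rows from the interleaved (real, imag, real, imag, ...) layout of
# the complex RoPE formulation to the half-split (all "real" rows, then all "imag" rows) layout
# expected by the rotate_half-based RoPE used in Transformers.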
def convert_key(key, mapping):
for pattern, replacement in mapping.items():
key = re.sub(pattern, replacement, key)
return key
def write_model(
input_path_or_repo,
model_name,
codec_model_path_or_repo,
output_dir,
safe_serialization=True,
):
print("Converting the model.")
os.makedirs(output_dir, exist_ok=True)
codec_model = MimiModel.from_pretrained(codec_model_path_or_repo)
codec_model.config._attn_implementation_autoset = False
# prepare rope scaling args: the original model uses
# 1 - for the depth decoder
# rope_theta=500000,
# rope_scaling={
# "factor": 32.0,
# "high_freq_factor": 4.0,
# "low_freq_factor": 1.0,
# "original_max_position_embeddings": 8192,
# "rope_type": "llama3",
# },
# 2 - for the backbone
# rope_theta=500000,
# rope_scaling={
# "factor": 32.0,
# "high_freq_factor": 4.0,
# "low_freq_factor": 1.0,
# "original_max_position_embeddings": 8192,
# "rope_type": "llama3",
# },
#
# Yet we want to use max_position_embeddings=32, resp. 2048
# This would throw a warning since original_max_position_embeddings >= max_position_embeddings
# Therefore, we convert values to equivalent ones
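# Sketch of why the values are equivalent (assuming the llama3 rope rule
# wavelen_cutoff = original_max_position_embeddings / freq_factor):
#   depth decoder: 16 / 0.001953125 = 8192 and 16 / 0.0078125 = 2048
#   backbone:      1024 / 0.125     = 8192 and 1024 / 0.5     = 2048
# i.e. the same cutoffs as 8192 / 1.0 and 8192 / 4.0 in the original config, so inv_freq is
# unchanged while original_max_position_embeddings now stays below max_position_embeddings
# and no warning is raised.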
depth_decoder_config = CsmDepthDecoderConfig(
rope_scaling={
"factor": 32.0,
"high_freq_factor": 0.0078125,
"low_freq_factor": 0.001953125,
"original_max_position_embeddings": 16,
"rope_type": "llama3",
},
)
config = CsmConfig(
codec_config=codec_model.config,
depth_decoder_config=depth_decoder_config,
rope_scaling={
"factor": 32.0,
"high_freq_factor": 0.5,
"low_freq_factor": 0.125,
"original_max_position_embeddings": 1024,
"rope_type": "llama3",
},
)
params = {
"backbone": {
"num_attention_heads": config.num_attention_heads,
"num_key_value_heads": config.num_key_value_heads,
"dim_per_head": config.head_dim,
"key_value_dim": config.head_dim * config.num_key_value_heads,
"dim": config.hidden_size,
},
"depth_decoder": {
"num_attention_heads": config.depth_decoder_config.num_attention_heads,
"num_key_value_heads": config.depth_decoder_config.num_key_value_heads,
"dim_per_head": config.depth_decoder_config.head_dim,
"key_value_dim": config.depth_decoder_config.head_dim * config.depth_decoder_config.num_key_value_heads,
"dim": config.depth_decoder_config.hidden_size,
},
}
model_path = cached_file(
input_path_or_repo,
model_name,
)
print(f"Fetching all parameters from the checkpoint at {model_path}...")
loaded = torch.load(model_path, map_location="cpu")
print("Converting model...")
state_dict = {}
# -----------------------
# convert parameter names
# -----------------------
# Add codec_model. prefix to every key in the codec model state dict
codec_state_dict = {f"codec_model.{k}": v for k, v in codec_model.state_dict().items()}
state_dict.update(codec_state_dict)
for key, value in loaded.items():
new_key = convert_key(key, ORIGINAL_TO_CONVERTED_KEY_MAPPING)
current_parameter = value
# Post-process the current_parameter.
if re.search("(k|q)_proj.weight", new_key):
params_keys = "backbone" if "backbone" in new_key else "depth_decoder"
if "q_proj" in new_key:
num_heads = params[params_keys]["num_attention_heads"]
dim_per_head = params[params_keys]["dim_per_head"]
param_dim = params[params_keys]["dim"]
dim = params[params_keys]["dim"]
else:
num_heads = params[params_keys]["num_key_value_heads"]
dim_per_head = params[params_keys]["dim_per_head"]
param_dim = params[params_keys]["key_value_dim"]
dim = params[params_keys]["dim"]
current_parameter = permute_for_rope(value, num_heads, param_dim, dim)
state_dict[new_key] = current_parameter.reshape(num_heads * dim_per_head, dim)
state_dict[new_key] = current_parameter
# add the depth decoder embed audio tokens weights, later tied to the backbone embed audio tokens weights
state_dict["depth_decoder.model.embed_tokens.weight"] = state_dict[
"backbone_model.embed_tokens.embed_audio_tokens.weight"
].clone()
del loaded
gc.collect()
# -------------------------
# load the weights and save
# -------------------------
print("Loading the checkpoint in a Csm model.")
with torch.device("meta"):
model = CsmForConditionalGeneration(config)
model.load_state_dict(state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
del model.config._name_or_path
# default generation config
model.generation_config._from_model_config = False
model.generation_config.max_new_tokens = 125
model.generation_config.do_sample = True
model.generation_config.top_k = 50
model.generation_config.temperature = 0.9
model.generation_config.depth_decoder_do_sample = True
model.generation_config.depth_decoder_top_k = 50
model.generation_config.depth_decoder_temperature = 0.9
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
CsmForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
def write_tokenizer(output_dir):
# from https://github.com/SesameAILabs/csm/blob/2d720827843b653c4d67bb4445b1c0a4f59e646f/generator.py#L22-L36
def load_llama3_tokenizer():
"""
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
"""
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token
eos = tokenizer.eos_token
tokenizer._tokenizer.post_processor = TemplateProcessing(
single=f"{bos}:0 $A:0 {eos}:0",
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
)
return tokenizer
tokenizer = load_llama3_tokenizer()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.save_pretrained(output_dir)
# manually modify in tokenizer_config.json
# "128002": {
# "content": "<|AUDIO|>",
# ...
# }
# "128003": {
# "content": "<|audio_eos|>",
# ...
# }
print(
"Tokenizer saved successfully. Please manually modify in tokenizer_config.json AND tokenizer.json as follows: "
)
print("""
# "128002": {
# "content": "<|AUDIO|>",
# ...
# }
# "128003": {
# "content": "<|audio_eos|>",
# ...
# }
""")
def write_processor(output_dir, codec_model_path_or_repo):
chat_template = "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"
tokenizer = AutoTokenizer.from_pretrained(output_dir)
feature_extractor = AutoFeatureExtractor.from_pretrained(codec_model_path_or_repo)
processor = CsmProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
chat_template=chat_template,
)
processor.save_pretrained(output_dir)
print("Processor saved successfully.")
def main():
parser = argparse.ArgumentParser(description="Convert Csm weights to HuggingFace format")
parser.add_argument(
"--input_path_or_repo",
type=str,
required=True,
help="Path or repo containing Csm weights",
)
parser.add_argument(
"--model_name",
type=str,
required=True,
help="Name of the model in input_path_or_repo",
)
parser.add_argument(
"--codec_model_path_or_repo",
type=str,
required=True,
help="Path or repo containing the codec model",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
)
args = parser.parse_args()
write_model(
args.input_path_or_repo,
args.model_name,
args.codec_model_path_or_repo,
output_dir=args.output_dir,
safe_serialization=args.safe_serialization,
)
write_tokenizer(args.output_dir)
write_processor(args.output_dir, args.codec_model_path_or_repo)
if __name__ == "__main__":
main()


@ -0,0 +1,491 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...generation import (
GenerateDecoderOnlyOutput,
GenerationConfig,
GenerationMixin,
GenerationMode,
)
from ...generation.logits_process import LogitsProcessorList
from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
from ...generation.utils import GenerateNonBeamOutput
from ...utils import logging
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
logger = logging.get_logger(__name__)
@dataclass
class CsmGenerateOutput(GenerateDecoderOnlyOutput):
"""
Outputs of CsmForConditionalGeneration.generate.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
Returns the model cache, used to speed up decoding. Different models have a different cache format, check the model's documentation.
audio (`list(torch.FloatTensor)` of length `batch_size`):
The generated audio.
"""
audio: Optional[List[torch.Tensor]] = None
class CsmGenerationMixin(GenerationMixin):
def _get_stopping_criteria(
self,
*args,
**kwargs,
) -> StoppingCriteriaList:
criteria = super()._get_stopping_criteria(*args, **kwargs)
kept_criteria = StoppingCriteriaList()
for criterion in criteria:
if not isinstance(criterion, MaxLengthCriteria):
logger.warning(
f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored."
)
else:
kept_criteria.append(criterion)
return kept_criteria
def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
"""
This method overrides [`~generation.utils.GenerationMixin._prepare_generation_config`].
It ensures that the depth decoder generation config is initialized and that arguments passed with the `depth_decoder_` prefix are properly handled.
"""
# extract depth decoder kwargs and remove them from the main kwargs
depth_decoder_kwargs = {
k[len("depth_decoder_") :]: v for k, v in kwargs.items() if k.startswith("depth_decoder_")
}
# remove the depth decoder keys from the original kwargs
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}
# initialize the generation config
generation_config, model_kwargs = super()._prepare_generation_config(
generation_config, use_model_defaults, **kwargs
)
self.depth_decoder.generation_config.update(**depth_decoder_kwargs)
# ensure the depth decoder generation config is valid
depth_decoder_min_new_tokens = getattr(self.depth_decoder.generation_config, "min_new_tokens") or (
self.config.num_codebooks - 1
)
depth_decoder_max_new_tokens = getattr(self.depth_decoder.generation_config, "max_new_tokens") or (
self.config.num_codebooks - 1
)
if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {self.config.num_codebooks - 1}:
raise ValueError(
f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})"
)
elif self.depth_decoder.generation_config.return_dict_in_generate:
logger.warning(
"depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate"
)
self.depth_decoder.generation_config.return_dict_in_generate = False
self.depth_decoder.generation_config.min_new_tokens = depth_decoder_min_new_tokens
self.depth_decoder.generation_config.max_new_tokens = depth_decoder_max_new_tokens
# Monkey patch the get_generation_mode method to support CSM model
original_get_generation_mode = generation_config.get_generation_mode
def patched_get_generation_mode(assistant_model=None):
generation_mode = original_get_generation_mode(assistant_model)
if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
raise ValueError(
f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation."
)
return generation_mode
generation_config.get_generation_mode = patched_get_generation_mode
return generation_config, model_kwargs
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
"""
This method overrides [`~generation.utils.GenerationMixin._sample`].
To ease maintenance, modifications are marked with the comment "Csm specific".
Indeed, the Csm model requires a custom generation sampling step:
1. Infer the backbone model to sample the first codebook token
2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens
3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model
4. Repeat until a stopping criterion is met
Csm supports two stopping criteria:
- stop when the generated sequence reaches max_length
- stop when all the generated codebook tokens are the codebook_eos_token_id
"""
# init values
# *************** Csm specific ***************
pad_token_id = self.config.codebook_pad_token_id
has_eos_stopping_criteria = generation_config._eos_token_tensor is not None
# ============================================
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
do_sample = generation_config.do_sample
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# *************** Csm specific ***************
if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None:
# in the case where the passed input_ids correspond to text tokens, i.e. don't have a third dimension for codebook ids,
# we need to subtract the input length from the MaxLengthCriteria since such inputs are not part of the returned sequences
for criterion in stopping_criteria:
if isinstance(criterion, MaxLengthCriteria):
criterion.max_length -= cur_len
# ============================================
model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
is_prefill = True
while self._has_unfinished_sequences(
this_peer_finished,
synced_gpus,
device=input_ids.device,
):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
# *************** Csm specific ***************
model_inputs.update({"output_hidden_states": True})
# ============================================
if is_prefill:
outputs = self(**model_inputs, return_dict=True)
is_prefill = False
else:
outputs = model_forward(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
)
if synced_gpus and this_peer_finished:
continue
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_logits = next_token_logits.to(input_ids.device)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (outputs.attentions,)
if output_hidden_states:
decoder_hidden_states += (outputs.hidden_states,)
# token selection
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# *************** Csm specific ***************
# infer the depth decoder
first_codebook_ids = next_tokens[:, None]
# adds a placeholder at position 0 that will be replaced by the backbone_last_hidden_state
depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0)
backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :]
depth_decoder_outputs = self.depth_decoder.generate(
input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone()
)
codebook_ids = (
depth_decoder_outputs
if isinstance(depth_decoder_outputs, torch.Tensor)
else depth_decoder_outputs.sequences
)
# remove the placeholder at position 0
codebook_ids = codebook_ids[:, 1:]
next_tokens = codebook_ids
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
1 - unfinished_sequences.unsqueeze(-1)
)
# update generated ids, model inputs, and length for next step
if input_ids.ndim == 2:
input_ids = next_tokens[:, None, :]
else:
input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
# ============================================
if streamer is not None:
streamer.put(next_tokens.cpu())
# *************** Csm specific ***************
# for the eos stopping criterion, it is expected that the eos token is the same for each codebook
unfinished_sequences = unfinished_sequences & ~(
input_ids[:, -1, :-1] == self.config.codebook_eos_token_id
).all(-1)
# ============================================
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
del outputs
# *************** Csm specific ***************
del depth_decoder_outputs
# ============================================
if streamer is not None:
streamer.end()
if return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
def generate(
self,
input_ids: Optional[torch.Tensor] = None,
input_values: Optional[torch.Tensor] = None,
input_values_cutoffs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
synced_gpus: Optional[bool] = None,
streamer: Optional["BaseStreamer"] = None,
output_audio: Optional[bool] = False,
**kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
Indeed, the Csm model requires a custom generation sampling step:
1. Infer the backbone model to sample the first codebook token
2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
4. Repeat until a stopping criterion is met
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
</Tip>
Parameters:
input_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*):
The sequence used as a prompt for the backbone model.
input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*):
The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*):
Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
generation_config ([`~generation.GenerationConfig`], *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complements the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
intended for advanced users.
synced_gpus (`bool`, *optional*):
Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
output_audio (`bool`, *optional*):
Whether to return the generated audio.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*.
Return:
[`CsmGenerateOutput`] or `torch.LongTensor` or `List[torch.FloatTensor]`: A [`CsmGenerateOutput`]
(if `return_dict_in_generate=True` or `config.return_dict_in_generate=True`), otherwise a `torch.LongTensor`
of generated codebook token ids when `output_audio=False`, or a `List[torch.FloatTensor]` of generated waveforms when `output_audio=True`.
Example:
```python
>>> import torch
>>> from transformers import AutoProcessor, CsmForConditionalGeneration
>>> from datasets import load_dataset, Audio
>>> model_id = "eustlb/csm-1b"
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
>>> # ensure the audio is 24kHz
>>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
>>> conversation = []
>>> # prepare a conversation with text and corresponding audio
>>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
... conversation.append(
... {
... "role": f"{speaker_id}",
... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
... }
... )
>>> # text prompt
>>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
>>> inputs = processor.apply_chat_template(
... conversation,
... tokenize=True,
... return_dict=True,
... ).to(torch_device)
>>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
>>> audio = model.generate(**inputs, output_audio=True)
>>> processor.save_audio(audio, "output.wav")
```
"""
generate_output = super().generate(
input_ids=input_ids,
input_values=input_values,
input_values_cutoffs=input_values_cutoffs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
synced_gpus=synced_gpus,
streamer=streamer,
**kwargs,
)
generate_returned_dict = not isinstance(generate_output, torch.Tensor)
audio = None
if output_audio:
generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output
# infer the codec model
audio = []
with torch.no_grad():
# =======================================
# TODO: @eustlb, this should be batched !!!
# but requires making sure batched inference of the codec model works as intended
for audio_codes_batch in generated_audio_codes:
eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
if eos_idxs.numel() != 0:
cutoff_idx = eos_idxs.min()
else:
cutoff_idx = audio_codes_batch.shape[1]
audio_codes_batch = audio_codes_batch[:cutoff_idx]
codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
audio.append(codec_decode_output.audio_values[0, 0])
# =======================================
if generate_returned_dict:
return CsmGenerateOutput(audio=audio, **generate_output)
elif output_audio:
return audio
else:
return generate_output
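
A note on usage: as handled in _prepare_generation_config above, sampling parameters for the depth decoder are passed to generate with the depth_decoder_ prefix, while backbone parameters are passed as usual. A minimal sketch, using the checkpoint id from the docstring example above (the prompt string and sampling values are illustrative only):

import torch
from transformers import AutoProcessor, CsmForConditionalGeneration

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

inputs = processor("<|begin_of_text|>[0]Hello from Csm.<|end_of_text|>").to(device)

# backbone sampling args are passed as usual, depth decoder args take the prefix
audio = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.8,
    depth_decoder_do_sample=True,
    depth_decoder_temperature=0.6,
    output_audio=True,
)
processor.save_audio(audio, "example.wav")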

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -0,0 +1,364 @@
# coding=utf-8
# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
from ...utils import is_soundfile_available, is_torch_available
if is_torch_available():
import torch
if is_soundfile_available():
import soundfile as sf
from ...audio_utils import AudioInput, make_list_of_audio
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import (
PreTokenizedInput,
TextInput,
)
class CsmAudioKwargs(AudioKwargs, total=False):
encoded_length_kwargs: Optional[Dict[str, Any]]
class CsmProcessorKwargs(ProcessingKwargs, total=False):
audio_kwargs: CsmAudioKwargs
_defaults = {
"text_kwargs": {
"padding": True,
"padding_side": "left",
"add_special_tokens": False,
},
"audio_kwargs": {
"encoded_length_kwargs": {
"kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
"strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
"dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
"use_causal_conv": True,
},
"sampling_rate": 24000,
},
"common_kwargs": {"return_tensors": "pt"},
}
class CsmProcessor(ProcessorMixin):
r"""
Constructs a Csm processor which wraps [`EncodecFeatureExtractor`] and
[`PretrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
tokenizer functionalities. See the [`~CsmProcessor.__call__`] for more
information.
The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
```python
from transformers import CsmProcessor
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
audio = ds[0]["audio"]["array"]
processor = CsmProcessor.from_pretrained("eustlb/csm-1b")
processor(
text=["<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"],
audio=audio,
text_kwargs = {"padding": False},
audio_kwargs = {"sampling_rate": 16000},
common_kwargs = {"return_tensors": "pt"},
)
# this should error out because EncodecFeatureExtractor expects a 24kHz audio :)
```
Args:
feature_extractor ([`EncodecFeatureExtractor`]):
The feature extractor is a required input.
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["feature_extractor", "tokenizer"]
valid_kwargs = ["chat_template"]
feature_extractor_class = "EncodecFeatureExtractor"
tokenizer_class = "PreTrainedTokenizerFast"
def __init__(
self,
feature_extractor,
tokenizer,
chat_template=None,
):
if not hasattr(tokenizer, "audio_token"):
self.audio_token = "<|AUDIO|>"
self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
else:
self.audio_token = tokenizer.audio_token
self.audio_token_id = tokenizer.audio_token_id
if not hasattr(tokenizer, "audio_eos_token"):
self.audio_eos_token = "<|audio_eos|>"
self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)
else:
self.audio_eos_token = tokenizer.audio_eos_token
self.audio_eos_token_id = tokenizer.audio_eos_token_id
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
@staticmethod
def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
"""
Compute the length of the encoded audio sequence.
Args:
audio_length (int): The length of the audio sequence.
kernel_sizes (List[int]): The kernel sizes for the convolutional layers.
strides (List[int]): The strides for the convolutional layers.
dilations (List[int]): The dilations for the convolutional layers.
use_causal_conv (bool): Whether to use causal convolutions.
"""
cur_length = audio_length
if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
return cur_length
for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
effective_kernel_size = (kernel_size - 1) * dilation + 1
padding_total = kernel_size - stride
padding_right = padding_total // 2
padding_left = padding_total - padding_right
n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
n_frames = math.ceil(n_frames) - 1
ideal_length = n_frames * stride + kernel_size - padding_total
extra_padding = ideal_length - cur_length
if use_causal_conv:
padding_left = padding_total
padding_right = extra_padding
else:
padding_right = padding_right + extra_padding
cur_length = cur_length + padding_left + padding_right
cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1
return cur_length
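# Editor's sketch (not in the original file): _get_encoded_length mirrors the codec
# encoder's convolution stack, so the defaults in CsmProcessorKwargs give the number
# of <|AUDIO|> placeholder tokens a raw waveform expands to. For instance:
#
#   defaults = CsmProcessorKwargs._defaults["audio_kwargs"]["encoded_length_kwargs"]
#   n_frames = CsmProcessor._get_encoded_length(24000, **defaults)  # one second at 24 kHz
#
# No specific frame count is assumed here; the result follows from the kernel/stride/dilation lists above.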
def save_audio(
self,
audio: AudioInput,
saving_path: Union[str, Path, List[Union[str, Path]]],
**kwargs: Unpack[CsmProcessorKwargs],
):
# TODO: @eustlb, this should be in AudioProcessor
if not is_soundfile_available():
raise ImportError("Please install `soundfile` to save audio files.")
# ensure correct audio input
audio = make_list_of_audio(audio)
# ensure correct saving path
if isinstance(saving_path, (str, Path)):
saving_path = [saving_path]
elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
raise ValueError("Invalid input path. Please provide a string, or a list of strings")
if len(audio) != len(saving_path):
raise ValueError("The number of audio and saving paths must be the same")
output_kwargs = self._merge_kwargs(
CsmProcessorKwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
sampling_rate = audio_kwargs["sampling_rate"]
for audio_value, p in zip(audio, saving_path):
if isinstance(audio_value, torch.Tensor):
audio_value = audio_value.cpu().float().numpy()
sf.write(p, audio_value, sampling_rate)
def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]],
audio: Optional[AudioInput] = None,
output_labels: Optional[bool] = False,
depth_decoder_labels_ratio: Optional[float] = 1.0,
**kwargs: Unpack[CsmProcessorKwargs],
):
r"""
Main method to prepare text(s) and audio to be fed as input to the model. This method forwards the `text`
arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode
the text. To prepare the audio, this method forwards the `audio` arguments to
EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`]. Please refer
to the docstring of the above two methods for more information.
Args:
audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
tensor.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
output_labels (bool, *optional*, defaults to `False`):
Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`.
- `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
- `-100` will be ignored in the loss computation
- `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
depth_decoder_labels_ratio (float, *optional*, defaults to `1.0`):
The ratio of audio frames to keep for the depth decoder labels.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
"""
output_kwargs = self._merge_kwargs(
CsmProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
text_kwargs = output_kwargs["text_kwargs"]
audio_kwargs = output_kwargs["audio_kwargs"]
common_kwargs = output_kwargs["common_kwargs"]
return_tensors = common_kwargs.pop("return_tensors", None)
if return_tensors != "pt":
raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
if isinstance(text, str):
text = [text]
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
n_audio_in_text = [t.count(self.audio_token) for t in text]
n_audio = 0
if audio is not None:
audio = make_list_of_audio(audio)
n_audio = len(audio)
if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text):
if audio is None:
raise ValueError("No audio were provided, but there are audio tokens in the prompt")
else:
raise ValueError(
f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the "
f"number of provided audios ({n_audio})."
)
if audio is not None:
encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {})
num_audio_tokens_list = [
self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
]
num_audio_tokens_list_copy = num_audio_tokens_list.copy()
# expand the text to repeat the audio token for the corresponding number of frames
expanded_text = []
for sample in text:
replace_str = []
while self.audio_token in sample:
num_audio_tokens = num_audio_tokens_list_copy.pop(0)
expanded_audio_token = self.audio_token * num_audio_tokens
replace_str.append(expanded_audio_token)
sample = sample.replace(self.audio_token, "<placeholder>", 1)
while "<placeholder>" in sample:
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
expanded_text.append(sample)
text = expanded_text
encoding = self.tokenizer(text, **text_kwargs)
data = {}
data.update(encoding)
if audio is not None:
audio_kwargs.pop("return_attention_mask", None) # not supported by the feature extractor
concatenated_audio, input_values_cutoffs = [], []
offset = 0
for n_audio in n_audio_in_text:
if n_audio == 0:
concatenated_audio.append(np.zeros(0))
input_values_cutoffs.append(torch.tensor([-1]))
else:
concatenated_audio.append(
np.concatenate(
[
el.cpu().numpy() if isinstance(el, torch.Tensor) else el
for el in audio[offset : offset + n_audio]
],
axis=-1,
)
)
input_values_cutoffs.append(
torch.tensor([el.shape[-1] for el in audio[offset : offset + n_audio]]).cumsum(dim=-1)
)
offset += n_audio
audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
audio_inputs.pop("padding_mask", None) # not applicable here
data.update(audio_inputs)
# pad and stack the audio cut idxs
max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
input_values_cutoffs = [
torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
for cut_idxs in input_values_cutoffs
]
data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)
if output_labels:
audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
n_audio_frames = audio_frame_idxs.shape[0]
if depth_decoder_labels_ratio <= 1.0:
rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
skip_frames_idxs = audio_frame_idxs[rand_idxs]
else:
skip_frames_idxs = audio_frame_idxs
labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
data["labels"] = labels
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["CsmProcessor"]


@ -1111,7 +1111,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
)
return model_inputs
def _get_initial_cache_position(self, input_ids, model_kwargs):
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
"""
Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.
Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`.
@ -1125,8 +1125,8 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
cur_len = input_ids.shape[-1]
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
cur_len = seq_length
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=device)
return model_kwargs
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)


@ -1563,7 +1563,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
inputs_embeds = self.get_input_embeddings()(input_tokens)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
if model_kwargs.get("past_key_values", None) is None:
# Prepare cache if not provided.


@ -1378,7 +1378,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
inputs_embeds = self.get_input_embeddings()(input_tokens)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
if model_kwargs.get("past_key_values", None) is None:
# Prepare cache if not provided.


@ -216,6 +216,32 @@ class MimiConv1d(nn.Module):
end = padded.shape[-1] - extra_pad
return padded[..., :end]
def _get_output_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
"""
Return the length of the output of the MimiConv1d.
"""
# padding size
n_frames = (input_length - self.kernel_size + self.padding_total) / self.stride + 1
n_frames = torch.ceil(n_frames).to(torch.int64) - 1
ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
extra_padding = ideal_length - input_length
if self.causal:
padding_left = self.padding_total
padding_right = extra_padding
else:
padding_left = self.padding_left
padding_right = self.padding_right + extra_padding
# padding
input_length = input_length + padding_left + padding_right
# conv
output_length = (
input_length + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1
) // self.conv.stride[0] + 1
return output_length
def forward(self, hidden_states):
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
@ -331,21 +357,28 @@ class MimiEncoder(nn.Module):
model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
scaling = 1
# keep track of MimiConv1d submodule layer names for easy encoded length computation
mimiconv1d_layer_names = ["layers.0"]
# Downsample to raw audio scale
for ratio in reversed(config.upsampling_ratios):
current_scale = scaling * config.num_filters
# Add residual layers
for j in range(config.num_residual_layers):
mimiconv1d_layer_names.extend([f"layers.{len(model)}.block.1", f"layers.{len(model)}.block.3"])
model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
# Add downsampling layers
model += [nn.ELU()]
mimiconv1d_layer_names.append(f"layers.{len(model)}")
model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
scaling *= 2
model += [nn.ELU()]
mimiconv1d_layer_names.append(f"layers.{len(model)}")
model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]
self.layers = nn.ModuleList(model)
self._mimiconv1d_layer_names = mimiconv1d_layer_names
# Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward
def forward(self, hidden_states):
@ -1567,6 +1600,38 @@ class MimiModel(MimiPreTrainedModel):
codes = codes.transpose(0, 1)
return codes, past_key_values
def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
"""
Return the number of frames of the encoded audio waveform.
"""
output_length = input_length
# encoder
for layer_name in self.encoder._mimiconv1d_layer_names:
output_length = self.encoder.get_submodule(layer_name)._get_output_length(output_length)
# downsample
output_length = self.downsample._get_output_length(output_length)
return output_length
def get_audio_codes_mask(self, padding_mask: torch.Tensor, padding_side: str = "right"):
"""
Get the mask for the audio codes from the original padding mask.
"""
encoded_lengths = self.get_encoded_length(padding_mask.sum(dim=-1))
audio_codes_mask = torch.arange(encoded_lengths.max(), device=encoded_lengths.device).expand(
len(encoded_lengths), -1
)
audio_codes_mask = audio_codes_mask < encoded_lengths.unsqueeze(1)
audio_codes_mask = audio_codes_mask.to(padding_mask.device)
if padding_side == "right":
return audio_codes_mask
else:
return audio_codes_mask.flip(dims=[-1])
def encode(
self,
input_values: torch.Tensor,


@ -3084,10 +3084,10 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCon
thinker_reply_part=thinker_reply_part,
)
def _get_initial_cache_position(self, input_ids, model_kwargs):
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
# Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
inputs_embeds = model_kwargs.pop("inputs_embeds")
model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs)
model_kwargs["inputs_embeds"] = inputs_embeds
return model_kwargs


@ -2771,10 +2771,10 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCon
thinker_reply_part=thinker_reply_part,
)
def _get_initial_cache_position(self, input_ids, model_kwargs):
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
# Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
inputs_embeds = model_kwargs.pop("inputs_embeds")
model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs)
model_kwargs["inputs_embeds"] = inputs_embeds
return model_kwargs


@ -1058,7 +1058,7 @@ class ProcessorMixin(PushToHubMixin):
# update defaults with arguments from tokenizer init
for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
# init with tokenizer init kwargs if necessary
if modality_key in tokenizer_init_kwargs:
if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs:
value = (
getattr(self.tokenizer, modality_key)
if hasattr(self.tokenizer, modality_key)


@ -501,9 +501,9 @@ class GenerationTesterMixin:
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_greedy_generate_dict_outputs(self):
@ -525,13 +525,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
@ -565,10 +565,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(output_generate, model.config, use_cache=True)
@ -582,9 +582,9 @@ class GenerationTesterMixin:
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_sample_generate_dict_output(self):
@ -607,13 +607,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
@ -632,9 +632,9 @@ class GenerationTesterMixin:
output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_beam_search_generate_dict_output(self):
@ -657,13 +657,13 @@ class GenerationTesterMixin:
use_cache=False,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@ -706,10 +706,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(
@ -759,9 +759,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_beam_sample_generate_dict_output(self):
@ -786,13 +786,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@ -840,9 +840,9 @@ class GenerationTesterMixin:
beam_kwargs=beam_kwargs,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
# check `group_beam_search` for higher than 1 `num_return_sequences`
num_return_sequences = 2
@ -853,9 +853,9 @@ class GenerationTesterMixin:
beam_kwargs=beam_kwargs,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_group_beam_search_generate_dict_output(self):
@ -878,13 +878,13 @@ class GenerationTesterMixin:
use_cache=False,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@ -923,9 +923,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@ -947,9 +947,9 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
for generation_output in output_generate:
self._check_sequence_inside_sequence(force_tokens, generation_output)
@ -987,13 +987,13 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
@ -1031,9 +1031,9 @@ class GenerationTesterMixin:
use_cache=True, # Enable cache
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@pytest.mark.generate
def test_contrastive_generate_dict_outputs_use_cache(self):
@ -1067,10 +1067,10 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self._check_generate_outputs(output_generate, model.config, use_cache=True)
@ -1499,7 +1499,7 @@ class GenerationTesterMixin:
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
cache_position = torch.arange(input_ids.shape[1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
@ -1525,10 +1525,12 @@ class GenerationTesterMixin:
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32)
pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
padded_attention_mask = torch.cat(
(torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
@ -1587,7 +1589,7 @@ class GenerationTesterMixin:
else text_config.num_attention_heads
)
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
batch_size, seq_length = inputs["decoder_input_ids"].shape
batch_size, seq_length = inputs["decoder_input_ids"].shape[:2]
# The sequence length for the encoder K V depends on the model. Since it is not manipulated in
# autoregressive generation, we're keeping the test general and not checking the 3rd dim
default_cross_attention_shape = (
@ -1606,7 +1608,7 @@ class GenerationTesterMixin:
for _ in range(num_decoder_layers)
]
else:
batch_size, seq_length = inputs["input_ids"].shape
batch_size, seq_length = inputs["input_ids"].shape[:2]
default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
all_cache_shapes = [
[default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
@ -1727,7 +1729,7 @@ class GenerationTesterMixin:
"min_new_tokens": 5, # generate exactly 5 tokens
}
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
# The output of the two calls should be the same.
@ -2262,11 +2264,11 @@ class GenerationTesterMixin:
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
@ -2408,7 +2410,7 @@ class GenerationTesterMixin:
config = config.text_config if hasattr(config, "text_config") else config
generated_length = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
output.sequences.shape[1] - 1 if config.is_encoder_decoder else output.sequences.shape[1] - prompt_length
)
decoder_past_key_values = getattr(output, "past_key_values", None)
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache):
@ -2441,7 +2443,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
attentions=output.decoder_attentions,
prompt_length=1, # the BOS token
output_length=output.sequences.shape[-1],
output_length=output.sequences.shape[1],
config=config,
decoder_past_key_values=decoder_past_key_values,
)
@ -2450,7 +2452,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
attentions=output.attentions,
prompt_length=prompt_length,
output_length=output.sequences.shape[-1],
output_length=output.sequences.shape[1],
config=config,
decoder_past_key_values=decoder_past_key_values,
)
@ -2469,7 +2471,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
hidden_states=output.decoder_hidden_states,
prompt_length=1, # the BOS token
output_length=output.sequences.shape[-1],
output_length=output.sequences.shape[1],
config=config,
use_cache=use_cache,
)
@ -2478,7 +2480,7 @@ class GenerationTesterMixin:
batch_size=internal_batch_size,
hidden_states=output.hidden_states,
prompt_length=prompt_length,
output_length=output.sequences.shape[-1],
output_length=output.sequences.shape[1],
config=config,
use_cache=use_cache,
)
@ -2506,7 +2508,7 @@ class GenerationTesterMixin:
)
if has_standard_cache:
if use_cache:
cache_length = output.sequences.shape[-1] - 1
cache_length = output.sequences.shape[1] - 1
self._check_past_key_values_for_generate(
batch_size=internal_batch_size,
decoder_past_key_values=decoder_past_key_values,

@ -0,0 +1,693 @@
# coding=utf-8
# Copyright 2024, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ConversationalSpeechModel model."""
import collections
import copy
import re
import unittest
import pytest
from parameterized import parameterized
from transformers import (
AutoProcessor,
CsmConfig,
CsmForConditionalGeneration,
is_torch_available,
)
from transformers.testing_utils import (
cleanup,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
ids_tensor,
)
if is_datasets_available():
from datasets import load_dataset
if is_torch_available():
import torch
from transformers.pytorch_utils import id_tensor_storage
class CsmModelTester:
def __init__(
self,
parent,
ignore_index=-100,
batch_size=3,
seq_length=7,
is_training=True,
depth_decoder_config={
"num_codebooks": 10,
"backbone_hidden_size": 64,
"vocab_size": 6,
"hidden_size": 64,
"intermediate_size": 128,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"hidden_act": "silu",
"max_position_embeddings": 10,
},
codec_config={
"model_type": "mimi",
"audio_channels": 1,
"chunk_in_sec": None,
"hidden_size": 32,
"num_filters": 8,
"num_residual_layers": 1,
"upsampling_ratios": [8, 4],
"codebook_size": 64,
"vector_quantization_hidden_dimension": 64,
"upsample_groups": 32,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"sliding_window": 4,
"codebook_dim": 64,
"use_cache": False,
},
config={
"num_codebooks": 10,
"vocab_size": 6,
"text_vocab_size": 99,
"hidden_size": 64,
"intermediate_size": 64,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"hidden_act": "silu",
"max_position_embeddings": 10,
"bos_token_id": 1,
"pad_token_id": 2,
"eos_token_id": 3,
"codebook_pad_token_id": 2,
"codebook_eos_token_id": 3,
},
):
self.parent = parent
self.is_training = is_training
self.ignore_index = ignore_index
self.depth_decoder_config = depth_decoder_config
self.codec_config = codec_config
self.config = config
self.seq_length = seq_length
self.batch_size = batch_size
self.num_hidden_layers = config["num_hidden_layers"]
self.vocab_size = config["vocab_size"]
self.hidden_size = config["hidden_size"]
self.num_attention_heads = config["num_attention_heads"]
self.pad_token_id = config["pad_token_id"]
def get_config(self):
return CsmConfig(
depth_decoder_config=self.depth_decoder_config,
codec_config=self.codec_config,
**self.config,
)
def prepare_config_and_inputs(self):
config = self.get_config()
input_ids = ids_tensor([self.batch_size, self.seq_length, config.num_codebooks], config.vocab_size - 1) + 1
attention_mask = input_ids[..., -1].ne(1).to(torch_device)
return config, input_ids, attention_mask
def prepare_config_and_inputs_for_common(self):
config, input_ids, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
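# Editorial note, not part of this diff: with the dummy config above (batch_size=3,
# seq_length=7, num_codebooks=10), prepare_config_and_inputs() yields
#     input_ids.shape       == (3, 7, 10)  # one token id per codebook at every position
#     attention_mask.shape  == (3, 7)      # derived from the last codebook channel only
# This extra codebook dimension is why several common-test hooks are overridden below.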
class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (CsmForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_resize_embeddings = False
test_resize_embeddings_untied = False
test_torch_exportable = True
def setUp(self):
self.model_tester = CsmModelTester(self)
self.config_tester = ConfigTester(self, config_class=CsmConfig)
def test_config(self):
self.config_tester.run_common_tests()
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
"""
Overrides [ModelTesterMixin._prepare_for_class] to handle the third `input_ids` dimension (codebooks).
"""
inputs_dict = copy.deepcopy(inputs_dict)
if return_labels:
inputs_dict["labels"] = torch.zeros(
(
self.model_tester.batch_size,
self.model_tester.seq_length,
self.model_tester.config["num_codebooks"],
),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
"""
Overrides [GenerationTesterMixin._get_logits_processor_kwargs] to restrict to top_k, top_p, and temperature sampling.
"""
logits_processor_kwargs = {}
if do_sample:
logits_processor_kwargs.update(
{
"top_k": 10,
"top_p": 0.7,
"temperature": 0.7,
}
)
return logits_processor_kwargs
def test_initialization(self):
"""
Overrides [ModelTesterMixin.test_initialization] because of the specificities of the Mimi codec model.
See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_parms = ["conv", "input_proj", "output_proj"]
if param.requires_grad:
if any(x in name for x in uniform_init_parms):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
"""
Overrides [GenerationTesterMixin._check_similar_generate_outputs] to handle the third `input_ids` dimension.
Here we only look at the first codebook (index 0 on the last dimension of the generated sequences) since the
returned scores correspond to this token.
"""
# scores doesn't include data regarding decoder input tokens
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
output_matches = output_1.sequences[..., 0] == output_2.sequences[..., 0]
has_matching_outputs = output_matches.all()
has_matching_scores = None
if not has_matching_outputs:
for batch_idx in range(output_1.sequences.shape[0]):
batch_matches = output_matches[batch_idx]
if batch_matches.all():
continue
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
first_mismatch_idx -= decoder_input_length
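# e.g. with a 2-token prompt and the first mismatch at sequence position 5, the relevant
# entry is output_1.scores[3], since `scores` only covers the generated tokens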
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
has_matching_scores = torch.allclose(
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=rtol, atol=atol
)
if not has_matching_scores:
break
self.assertTrue(has_matching_outputs or has_matching_scores)
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip(reason="CSM does not support assisted decoding.")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support assisted decoding.")
def test_assisted_decoding_sample(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support Dola decoding.")
def test_dola_decoding_sample(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support beam search.")
def test_beam_sample_generate(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support beam search.")
def test_beam_search_generate(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support beam search.")
def test_beam_search_generate_dict_output(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support beam search.")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support beam search.")
def test_beam_sample_generate_dict_output(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support group beam search.")
def test_group_beam_search_generate(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support group beam search.")
def test_group_beam_search_generate_dict_output(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support constrained beam search.")
def test_constrained_beam_search_generate(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support constrained beam search.")
def test_constrained_beam_search_generate_dict_output(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support contrastive search.")
def test_contrastive_generate(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support contrastive search.")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support contrastive search.")
def test_contrastive_generate_low_memory(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support prompt lookup decoding.")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support prompt lookup decoding.")
def test_prompt_lookup_decoding_stops_at_eos(self):
pass
@pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).")
def test_model_get_set_embeddings(self):
pass
@pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).")
def test_tie_model_weights(self):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support beam search.")
def test_generate_from_inputs_embeds_1_beam_search(self, _, num_beams):
pass
@pytest.mark.generate
@unittest.skip(reason="CSM does not support beam search.")
def test_model_parallel_beam_search(self):
pass
def test_tied_weights_keys(self):
"""
Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM).
"""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Remove tied weights found in tied params -> there should only be one left afterwards
for key in tied_weight_keys:
for i in range(len(tied_params)):
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(
tied_params,
[],
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)
def _get_custom_4d_mask_test_data(self):
"""
Overrides [ModelTesterMixin._get_custom_4d_mask_test_data] to handle the third `input_ids` dimension.
"""
# Sequence in which all but the last token is the same
input_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 4], [0, 1, 2, 5]], device=torch_device, dtype=torch.int64)
input_ids = input_ids.unsqueeze(-1).expand(-1, -1, self.model_tester.config["num_codebooks"])
position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
# Combining common prefix with the unique ending tokens:
input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
# Creating a 4D mask where each of the last 3 tokens do not attend to each other.
mask_shared_prefix = torch.tensor(
[
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 0, 1, 0],
[1, 1, 1, 0, 0, 1],
]
]
],
)
# inverting the attention mask
mask_dtype = torch.float32
min_dtype = torch.finfo(mask_dtype).min
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
# Creating a position_ids tensor. note the repeating figures in the end.
position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
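# Editorial note, not part of this diff: the shared-prefix batch above packs the three
# sequences [0, 1, 2, 3], [0, 1, 2, 4], [0, 1, 2, 5] into a single row [0, 1, 2, 3, 4, 5]
# (each id repeated across num_codebooks), and the 6x6 mask lets tokens 3, 4 and 5 attend
# to the shared prefix and to themselves but not to each other, mimicking three independent
# continuations evaluated in one forward pass.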
class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
# TODO: @eustlb, update with the correct Sesame repo
self.model_checkpoint = "eustlb/csm-1b"
def tearDown(self):
cleanup(torch_device, gc_collect=True)
def _load_conversation(self):
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.filter(lambda x: x["conversation_id"] == 0)
ds = ds.sort("turn_id")
return ds[0]
@slow
@require_torch_gpu
def test_1b_model_integration_generate(self):
"""
Tests that the generated tokens match the ones from the original model implementation.
Such tokens are to be retrieved using https://gist.github.com/eustlb/d25577a357ddcf8f4a8cd0d00baca551, a script that runs inference with the original model.
"""
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
prompt = "<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
audio = ds[0]["audio"]["array"]
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(torch_device)
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
[1140, 10, 37, 1180, 1100, 1319, 601, 1482, 1918, 1739, 372, 856, 674, 1, 854, 459, 1843, 1191, 347, 349, 1087, 846, 759, 1690, 947, 1280, 580, 1909, 1192, 487, 1302, 1601],
[1494, 1412, 1824, 1852, 150, 928, 91, 326, 623, 1632, 1163, 1221, 1949, 999, 1779, 248, 693, 1149, 1423, 1503, 598, 80, 223, 1798, 251, 385, 1391, 1692, 1228, 1631, 1101, 866],
[778, 645, 830, 1812, 524, 1704, 1805, 1289, 74, 1069, 243, 1622, 1755, 1281, 1397, 620, 1962, 1995, 253, 1124, 1007, 518, 89, 559, 1304, 1482, 523, 1747, 1979, 1003, 1707, 1578],
[1356, 481, 642, 989, 287, 1819, 171, 1115, 824, 1253, 1488, 1074, 1019, 342, 279, 513, 1275, 1364, 893, 2007, 553, 407, 882, 1170, 1586, 485, 762, 559, 100, 542, 911, 1460],
[1860, 593, 1944, 404, 575, 545, 862, 830, 1002, 125, 2010, 268, 1779, 804, 811, 809, 255, 373, 387, 1756, 259, 822, 1191, 700, 1686, 390, 1676, 844, 2006, 286, 1376, 719],
[1165, 1047, 848, 212, 1018, 1470, 93, 1709, 1487, 1691, 1190, 275, 1278, 2018, 121, 1023, 485, 463, 39, 1825, 1936, 1817, 569, 209, 1553, 1599, 1137, 769, 968, 558, 1957, 265],
[902, 1608, 719, 850, 371, 1920, 75, 1917, 2005, 1238, 562, 1743, 713, 95, 1107, 1463, 696, 840, 8, 487, 1950, 1171, 1004, 1516, 1130, 303, 1866, 1728, 2046, 238, 265, 153],
[1932, 839, 334, 1167, 134, 2025, 40, 505, 1244, 1238, 1840, 800, 697, 72, 216, 486, 940, 1312, 510, 361, 549, 583, 1364, 844, 397, 1181, 1779, 962, 457, 1782, 1316, 465],
[31, 1558, 1048, 404, 354, 7, 827, 414, 1082, 807, 243, 1517, 801, 1364, 99, 1276, 1655, 1488, 1313, 464, 828, 1612, 774, 1558, 745, 1496, 960, 1874, 995, 1943, 255, 213],
[355, 1270, 413, 1519, 1659, 1904, 690, 552, 1279, 1821, 2022, 458, 1779, 2003, 604, 832, 661, 1295, 305, 1701, 173, 869, 230, 539, 1188, 669, 117, 692, 250, 388, 1995, 294],
[629, 199, 1899, 1123, 1070, 344, 578, 1795, 1451, 1257, 168, 1410, 1120, 1270, 316, 983, 1245, 1870, 165, 471, 966, 1337, 308, 1118, 746, 67, 1767, 1480, 1517, 1585, 871, 1110],
[1281, 1173, 784, 404, 368, 403, 580, 526, 853, 1692, 792, 895, 1286, 573, 1368, 896, 931, 1958, 1912, 644, 583, 1706, 1176, 1262, 1637, 315, 524, 1629, 795, 1211, 915, 533],
[9, 1783, 621, 1954, 1212, 993, 197, 977, 1662, 1340, 618, 1997, 1689, 1001, 74, 1765, 1865, 797, 1219, 1609, 671, 1491, 950, 1849, 1301, 2031, 875, 323, 203, 1063, 1490, 1538],
[1944, 1578, 1256, 1169, 790, 1444, 1382, 1616, 1100, 1264, 214, 1646, 488, 573, 1333, 285, 1954, 74, 1333, 674, 1303, 266, 622, 1290, 402, 109, 1331, 1666, 1347, 780, 106, 605],
[221, 161, 1322, 1, 565, 1507, 1403, 1091, 1557, 932, 1664, 1165, 1828, 1647, 2008, 1616, 648, 1113, 1870, 22, 734, 1458, 1940, 1756, 1689, 925, 1318, 1095, 985, 473, 604, 1974],
[1178, 597, 1804, 747, 1383, 360, 1497, 406, 1053, 1023, 1901, 56, 1221, 628, 75, 1729, 575, 1681, 840, 410, 650, 794, 1171, 1889, 187, 54, 1364, 1390, 505, 1285, 1814, 90],
[1432, 1221, 1800, 1873, 1255, 627, 41, 9, 630, 896, 1469, 1195, 1098, 145, 442, 1460, 13, 57, 2039, 1015, 149, 461, 1084, 1288, 1099, 910, 63, 157, 906, 111, 1394, 460],
[1352, 593, 307, 780, 1614, 1675, 1491, 1253, 723, 1793, 1032, 1486, 1805, 1904, 777, 398, 1791, 951, 770, 499, 1858, 244, 1372, 1514, 1858, 1200, 69, 181, 673, 1144, 1938, 1191],
[905, 403, 1626, 1529, 581, 1443, 976, 754, 1561, 1370, 1048, 253, 194, 1271, 853, 959, 1532, 30, 286, 1594, 1255, 1135, 1410, 1699, 1423, 2002, 260, 69, 941, 1640, 895, 722],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]])
# fmt: on
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
@slow
@require_torch_gpu
def test_1b_model_integration_generate_no_audio(self):
"""
Tests that the generated tokens match the ones from the original model implementation.
Such tokens are to be retrieved using https://gist.github.com/eustlb/aed822f765e928b9612e01b0d8836d69, a script that runs inference with the original model.
"""
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
conversation = [
{"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
]
inputs = processor.apply_chat_template(conversation, tokenize=True, return_dict=True).to(torch_device)
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
[1656, 629, 723, 1785, 206, 1873, 1059, 1190, 1833, 240, 618, 350, 156, 109, 2010, 452, 435, 1764, 77, 654, 1133, 908, 1095, 74, 804, 494, 1760, 1343, 1312, 1464, 1657, 324],
[366, 1532, 1945, 21, 145, 1428, 1417, 1987, 1793, 1444, 356, 1491, 849, 333, 788, 426, 1423, 1004, 414, 1823, 1169, 257, 1892, 696, 1572, 998, 1098, 523, 390, 1977, 546, 1692],
[1343, 1382, 1288, 1744, 1685, 1154, 1837, 1156, 1680, 1641, 1479, 1548, 632, 824, 694, 2010, 671, 1251, 1822, 343, 638, 1372, 696, 1272, 144, 125, 1332, 579, 936, 77, 159, 357],
[456, 1534, 349, 274, 1956, 1502, 1268, 1038, 1911, 523, 1360, 1159, 761, 293, 718, 1143, 63, 705, 168, 550, 413, 1372, 1771, 787, 631, 693, 784, 1789, 2039, 1131, 1601, 918],
[456, 829, 2026, 1108, 1649, 207, 1308, 1440, 1192, 1394, 426, 546, 590, 36, 1682, 1827, 1387, 1425, 1909, 1500, 1438, 1297, 5, 888, 948, 1745, 1304, 1364, 1692, 131, 300, 1908],
[2027, 1431, 1037, 1789, 1296, 1264, 1331, 1787, 1235, 1902, 1161, 1591, 590, 561, 1633, 1218, 510, 148, 1962, 118, 212, 608, 565, 1869, 583, 598, 532, 658, 1416, 9, 1172, 493],
[1215, 460, 1722, 317, 1423, 716, 1589, 1177, 1927, 1860, 1756, 1552, 1674, 643, 74, 1256, 587, 1742, 771, 2028, 469, 1070, 1683, 1614, 699, 494, 2020, 139, 1365, 1171, 171, 904],
[1615, 339, 323, 317, 469, 714, 104, 2015, 1407, 278, 468, 77, 2007, 650, 1630, 269, 168, 934, 1544, 58, 1487, 1373, 705, 874, 1252, 2031, 1995, 254, 1334, 1171, 1911, 1607],
[1259, 693, 666, 1700, 1115, 607, 982, 769, 1106, 1500, 101, 88, 1698, 1864, 1358, 1594, 192, 153, 1868, 1654, 604, 1948, 526, 778, 172, 1664, 1966, 99, 1334, 1030, 1349, 1209],
[1211, 579, 1369, 492, 1725, 203, 1125, 778, 701, 1982, 1420, 155, 736, 1145, 2018, 609, 658, 561, 1147, 923, 1794, 1753, 116, 1374, 612, 956, 1587, 392, 1062, 2047, 901, 1931],
[460, 1093, 1346, 1917, 1223, 470, 271, 390, 547, 112, 143, 1633, 1030, 643, 96, 1759, 920, 1959, 75, 1280, 1630, 999, 333, 853, 1110, 1291, 1911, 57, 171, 1658, 1704, 1508],
[908, 500, 393, 184, 1437, 482, 2008, 1834, 356, 1435, 1550, 1407, 1236, 109, 1167, 452, 1141, 934, 207, 957, 660, 670, 28, 1066, 1252, 1932, 669, 906, 1904, 1820, 2043, 881],
[1599, 1031, 1474, 336, 1540, 571, 437, 1440, 1616, 1365, 1412, 1246, 400, 405, 1776, 96, 296, 38, 1597, 466, 1630, 1256, 1940, 887, 1769, 294, 285, 842, 1756, 1619, 451, 1529],
[1615, 339, 1722, 525, 942, 105, 1365, 670, 785, 1316, 465, 1860, 438, 968, 547, 1938, 1816, 1429, 1065, 1942, 660, 1446, 1093, 1066, 931, 121, 688, 1033, 1178, 754, 1783, 94],
[912, 1354, 598, 254, 341, 1980, 1166, 585, 1302, 473, 554, 242, 174, 2030, 2011, 325, 978, 1690, 258, 396, 1831, 1768, 1291, 1699, 2001, 433, 1414, 2012, 1045, 511, 533, 1104],
[80, 1791, 1062, 1136, 391, 568, 1651, 101, 959, 2043, 1683, 760, 794, 181, 570, 540, 1599, 20, 1017, 973, 1654, 396, 586, 778, 2044, 1664, 1911, 929, 66, 897, 510, 643],
[1161, 1093, 161, 1296, 589, 54, 906, 981, 1927, 605, 516, 1731, 1461, 1204, 1902, 920, 1488, 177, 805, 1402, 610, 1446, 1154, 1067, 2025, 645, 762, 1715, 415, 1658, 1713, 1607],
[374, 1444, 1577, 792, 1450, 628, 604, 1729, 322, 514, 1725, 540, 1070, 575, 653, 800, 250, 187, 569, 349, 354, 1573, 176, 793, 897, 359, 536, 276, 1224, 23, 145, 1287],
[1184, 415, 1644, 1737, 1788, 385, 784, 1861, 1172, 1118, 367, 1156, 234, 1946, 1742, 981, 828, 1798, 1821, 361, 1148, 670, 518, 1288, 761, 1050, 1642, 1006, 1747, 840, 1599, 720],
[1141, 1731, 1670, 1542, 1347, 1907, 683, 753, 1347, 68, 2031, 153, 556, 719, 736, 1759, 1131, 1073, 1747, 1730, 1487, 1137, 1869, 1624, 699, 1900, 748, 49, 1312, 735, 726, 1268],
[1141, 1383, 405, 1033, 490, 488, 1102, 471, 713, 1630, 447, 703, 1495, 1001, 1855, 354, 456, 411, 786, 853, 168, 407, 116, 699, 605, 128, 532, 1076, 208, 447, 1448, 1071],
[345, 1013, 948, 1728, 1837, 337, 930, 1226, 1643, 1729, 983, 1688, 2009, 435, 1358, 721, 42, 1779, 1332, 1077, 1873, 128, 1327, 125, 1226, 1704, 705, 1459, 1449, 862, 155, 1870],
[336, 904, 684, 184, 1542, 714, 1752, 1180, 1373, 1816, 504, 1716, 1066, 1086, 1212, 530, 1413, 1278, 75, 1347, 82, 1623, 1307, 1717, 1861, 494, 888, 1589, 670, 1999, 905, 1430],
[578, 554, 14, 523, 1016, 300, 1589, 1017, 356, 1583, 1654, 414, 449, 376, 1413, 58, 706, 963, 388, 1626, 131, 352, 1024, 1054, 2025, 1561, 77, 1589, 1486, 431, 1249, 1508],
[184, 2043, 169, 1673, 580, 162, 1752, 397, 1119, 2009, 697, 150, 1475, 157, 1523, 1402, 575, 86, 1373, 1230, 1564, 1308, 626, 1093, 1603, 1446, 1390, 1543, 1778, 1142, 1357, 1831],
[1484, 1987, 932, 1728, 1504, 1618, 291, 1865, 1151, 460, 1792, 141, 234, 2043, 829, 513, 435, 791, 1037, 1541, 65, 424, 1589, 1711, 312, 1306, 212, 686, 673, 984, 1914, 1549],
[513, 1536, 1844, 1319, 572, 1069, 121, 735, 1949, 1211, 1362, 1027, 105, 1379, 315, 1782, 706, 1658, 1510, 1989, 1443, 1690, 822, 1614, 1194, 1460, 992, 2040, 1178, 1474, 1110, 1326],
[1858, 194, 1594, 1935, 1622, 1892, 1577, 137, 1907, 2015, 757, 414, 1823, 836, 496, 530, 1385, 1503, 1065, 1554, 664, 525, 1031, 433, 69, 466, 1016, 1846, 1609, 1658, 911, 94],
[1134, 1744, 323, 691, 1837, 347, 1871, 172, 811, 91, 1883, 436, 1912, 23, 1336, 1684, 519, 1612, 1219, 1402, 728, 1953, 1658, 641, 27, 1340, 436, 139, 2008, 1030, 159, 324],
[1270, 1536, 1639, 414, 1387, 1170, 1067, 1701, 1414, 505, 1122, 36, 1731, 350, 1552, 1214, 1444, 30, 107, 172, 480, 1858, 655, 168, 1107, 691, 1272, 797, 1656, 548, 1407, 1375],
[1270, 286, 1371, 1552, 1622, 1739, 1348, 2018, 345, 1537, 1941, 2024, 1423, 740, 284, 513, 91, 1228, 2015, 385, 992, 39, 813, 803, 2025, 497, 663, 462, 1609, 334, 927, 1470],
[1718, 994, 265, 1421, 1622, 1098, 845, 1868, 832, 459, 447, 619, 1970, 929, 513, 63, 1448, 1509, 1219, 1942, 285, 1373, 1259, 1004, 11, 1040, 1984, 57, 188, 1687, 1475, 805],
[1157, 832, 480, 1225, 1019, 347, 326, 999, 125, 1542, 118, 1383, 1343, 1077, 1821, 1602, 1978, 1642, 618, 808, 692, 1953, 1353, 963, 619, 1291, 1016, 1458, 1995, 1688, 1872, 1718],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]])
# fmt: on
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
@slow
@require_torch_gpu
def test_1b_model_integration_generate_multiple_audio(self):
"""
Tests that the generated tokens match the ones from the original model implementation.
Such tokens are to be retrieved using https://gist.github.com/eustlb/0c94de002e1325abb61d32217f74c0f8, a script that runs inference with the original model.
"""
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
conversation = []
# context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
conversation.append(
{
"role": f"{speaker_id}",
"content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
}
)
# text prompt
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(torch_device)
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
[420, 1189, 1311, 318, 359, 694, 1550, 1044, 1614, 1437, 1978, 537, 554, 1681, 147, 1225, 422, 1357, 1681, 1619, 165, 641, 1132, 1975, 1568, 406, 756, 503, 1673, 1428, 762, 781],
[1848, 1412, 957, 1656, 871, 540, 1999, 175, 711, 1383, 1814, 104, 742, 1285, 733, 1251, 1165, 1915, 1392, 645, 1804, 913, 1772, 632, 376, 1507, 1132, 725, 716, 1121, 1769, 1509],
[429, 1138, 895, 1018, 1099, 257, 1395, 1015, 576, 1599, 497, 19, 1858, 1437, 282, 357, 1143, 828, 1481, 70, 985, 551, 935, 278, 1102, 1453, 1902, 755, 526, 498, 1441, 1733],
[546, 343, 1547, 879, 2039, 692, 1999, 1150, 1969, 1866, 1178, 199, 1913, 1738, 1530, 1728, 1193, 74, 695, 612, 1095, 1597, 1381, 683, 1385, 2045, 1069, 865, 438, 70, 1437, 318],
[1741, 1621, 733, 1580, 1006, 1790, 1031, 1563, 569, 1822, 1229, 854, 142, 1554, 792, 741, 147, 552, 731, 772, 908, 831, 1291, 1819, 296, 290, 1871, 100, 1904, 1420, 1903, 1653],
[1264, 1576, 963, 12, 1403, 453, 259, 1359, 1270, 466, 1744, 1579, 1081, 1691, 1495, 1293, 110, 1020, 2042, 189, 1358, 955, 784, 1317, 2, 1794, 388, 376, 327, 511, 866, 1308],
[1407, 1412, 1665, 1683, 284, 874, 1859, 326, 1491, 1343, 777, 695, 1424, 396, 274, 202, 178, 747, 470, 1805, 1414, 2000, 127, 1884, 531, 215, 1322, 1098, 1674, 1227, 1092, 204],
[584, 637, 1665, 1683, 1136, 1201, 212, 310, 1441, 1619, 190, 1611, 1629, 2011, 1754, 1587, 413, 1287, 1251, 1382, 1904, 444, 1665, 1047, 1982, 1169, 1200, 809, 117, 327, 958, 1877],
[471, 1469, 1679, 1184, 343, 974, 1442, 897, 1888, 1468, 1092, 1398, 1714, 963, 1577, 1797, 766, 565, 403, 920, 1806, 466, 1193, 446, 825, 775, 1886, 1095, 159, 1085, 858, 504],
[28, 1511, 1510, 1580, 447, 1934, 1031, 1439, 202, 1435, 474, 1731, 724, 1080, 1121, 421, 625, 1410, 95, 605, 815, 1825, 127, 785, 900, 1673, 178, 1242, 2033, 1230, 350, 139],
[20, 1215, 253, 955, 871, 1689, 1986, 24, 1648, 423, 562, 1937, 1146, 26, 1266, 346, 188, 318, 179, 1164, 1100, 1978, 478, 1192, 715, 392, 1837, 425, 1492, 766, 1651, 822],
[1879, 1401, 1444, 723, 1754, 732, 1307, 702, 1768, 2013, 1284, 577, 1287, 1532, 647, 189, 903, 587, 800, 152, 898, 182, 2016, 639, 1074, 1220, 1934, 264, 250, 745, 1652, 536],
[1874, 1526, 232, 1580, 1980, 988, 1623, 341, 1768, 956, 1430, 1667, 1687, 1289, 826, 1378, 173, 1466, 479, 835, 1786, 1671, 328, 131, 815, 871, 379, 1329, 440, 1117, 392, 272],
[1762, 426, 1350, 1590, 314, 190, 1514, 344, 1926, 822, 534, 523, 703, 36, 379, 494, 464, 1886, 1555, 1318, 1654, 1469, 1976, 304, 218, 655, 1826, 958, 502, 326, 1898, 861],
[1577, 386, 503, 1492, 698, 405, 1031, 349, 1804, 2012, 1450, 996, 1140, 26, 449, 33, 1917, 354, 702, 1255, 1942, 1184, 864, 2045, 514, 744, 466, 54, 37, 486, 362, 525],
[1109, 1920, 445, 1719, 1670, 1220, 745, 40, 171, 1921, 999, 104, 489, 1911, 883, 306, 649, 1751, 762, 1183, 1085, 1112, 1912, 2035, 1940, 1129, 1592, 1276, 1570, 1236, 738, 209],
[1837, 990, 1063, 318, 1398, 1838, 1678, 906, 754, 802, 562, 353, 1389, 207, 1319, 1188, 2013, 1079, 888, 1706, 1042, 657, 482, 953, 94, 2007, 871, 485, 1596, 275, 410, 1855],
[872, 974, 1344, 1798, 655, 805, 1604, 1913, 455, 615, 1827, 966, 1330, 1826, 1285, 359, 544, 221, 1538, 1658, 374, 1352, 1714, 1925, 235, 65, 350, 931, 1009, 1164, 218, 736],
[1547, 617, 1622, 740, 655, 265, 1324, 1265, 1449, 482, 1037, 105, 1128, 701, 1866, 1674, 1999, 1302, 985, 1942, 663, 449, 1881, 698, 805, 1446, 1742, 1192, 1623, 605, 948, 2],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
]])
# fmt: on
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
@slow
@require_torch_gpu
def test_1b_model_integration_generate_batched(self):
"""
Tests that the generated tokens match the ones from the original model implementation.
Such tokens are to be retrieved using https://gist.github.com/eustlb/bcc532b53161bc31da3d66cb07ae193f, a script that runs inference with the original model.
"""
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
conversation = [
[
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[0]["text"]},
{"type": "audio", "path": ds[0]["audio"]["array"]},
],
},
{
"role": f"{ds[1]['speaker_id']}",
"content": [
{"type": "text", "text": ds[1]["text"]},
],
},
],
[
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[0]["text"]},
],
}
],
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(torch_device)
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS = torch.tensor([
[
[1140, 10, 37, 1180, 1100, 1319, 601, 1482, 1918, 1739, 372, 856, 674, 1, 854, 459, 1843, 1191, 347, 349, 1087, 846, 759, 1690, 947, 1280, 580, 1909, 1192, 487, 1302, 1601],
[1494, 1412, 1824, 1852, 150, 928, 91, 326, 623, 1632, 1163, 1221, 1949, 999, 1779, 248, 693, 1149, 1423, 1503, 1656, 80, 1947, 1666, 933, 1950, 1544, 1577, 1612, 1791, 1883, 765],
[778, 645, 830, 1051, 524, 1704, 1805, 1438, 211, 906, 691, 814, 1798, 1642, 1042, 284, 1906, 1513, 520, 137, 1052, 1548, 423, 1564, 330, 873, 1381, 188, 317, 1503, 1707, 1744],
[1416, 864, 242, 1653, 604, 1577, 202, 1808, 926, 1867, 204, 134, 1096, 1765, 496, 1680, 268, 1796, 2024, 1989, 583, 183, 952, 105, 765, 1534, 669, 895, 2008, 11, 1199, 195],
[1356, 796, 25, 1580, 15, 344, 1730, 99, 1330, 315, 955, 1964, 1731, 543, 1159, 1860, 671, 732, 63, 382, 143, 395, 1749, 1421, 1640, 1340, 650, 100, 171, 1346, 41, 806],
[1860, 1835, 823, 388, 254, 1734, 1135, 324, 1508, 983, 937, 1703, 1541, 875, 1319, 799, 1259, 1175, 1295, 807, 261, 760, 1916, 1606, 1616, 1894, 1605, 441, 387, 167, 2016, 222],
[1165, 919, 1318, 54, 1727, 1766, 777, 1128, 623, 353, 1840, 241, 977, 424, 1055, 898, 395, 655, 1695, 1084, 1346, 616, 1028, 1927, 603, 858, 758, 1539, 0, 1655, 1853, 1661],
[902, 1746, 1318, 298, 1982, 1184, 775, 328, 1676, 871, 133, 1374, 1927, 1984, 698, 1037, 100, 1884, 1596, 429, 1794, 2046, 105, 2037, 1767, 178, 176, 1293, 1893, 1780, 1832, 1382],
[1932, 714, 1084, 1167, 624, 509, 1213, 651, 1000, 1686, 1537, 555, 461, 623, 1433, 1089, 1212, 1628, 834, 1111, 943, 1816, 1947, 1063, 354, 1843, 1741, 2015, 404, 928, 1488, 168],
[1437, 314, 1356, 404, 1274, 2016, 998, 1350, 155, 553, 368, 1501, 1431, 1563, 1105, 1353, 535, 908, 1305, 1214, 1656, 65, 1469, 1517, 480, 252, 1289, 696, 302, 632, 246, 72],
[724, 848, 1140, 927, 1669, 296, 447, 1708, 1898, 685, 1041, 1685, 708, 1510, 1623, 876, 11, 99, 43, 586, 1705, 1753, 1477, 1191, 583, 1249, 1613, 992, 1319, 677, 418, 668],
[925, 54, 1810, 674, 1306, 848, 573, 1772, 105, 301, 1753, 989, 440, 1057, 823, 1313, 1663, 750, 1477, 102, 1437, 1114, 399, 1440, 319, 118, 1827, 295, 1429, 139, 1594, 55],
[629, 149, 784, 838, 984, 604, 685, 1229, 1432, 859, 1526, 1336, 1949, 281, 988, 1260, 52, 6, 1216, 1542, 1426, 1938, 253, 280, 1319, 794, 901, 843, 615, 437, 814, 20],
[1281, 502, 1237, 404, 625, 1444, 397, 1999, 2016, 1686, 533, 1785, 1152, 1245, 579, 1906, 1204, 549, 1334, 536, 1351, 1979, 208, 111, 2011, 751, 677, 1948, 1772, 1525, 2038, 419],
[9, 490, 869, 2026, 1928, 1489, 587, 549, 1241, 460, 1458, 1636, 924, 222, 1246, 480, 706, 398, 75, 1717, 604, 1446, 333, 237, 805, 1446, 421, 1343, 78, 1260, 1872, 1116],
[1944, 755, 375, 332, 1464, 828, 1273, 579, 1457, 353, 1510, 1910, 1609, 705, 400, 1666, 227, 1544, 1270, 136, 1857, 1975, 1762, 2006, 1102, 221, 1965, 151, 2041, 198, 1830, 287],
[221, 502, 440, 247, 181, 1912, 42, 357, 1883, 596, 919, 953, 1774, 772, 915, 188, 438, 1226, 544, 1313, 726, 1298, 85, 677, 566, 1581, 30, 341, 878, 1732, 591, 1446],
[1178, 1690, 320, 1746, 1798, 685, 1941, 666, 832, 623, 1907, 128, 337, 1779, 824, 923, 1041, 287, 1165, 437, 1803, 1222, 870, 646, 358, 220, 2009, 735, 468, 1908, 1349, 1603],
[1432, 1286, 540, 1687, 1741, 951, 299, 1233, 1061, 1128, 985, 953, 1917, 198, 2031, 1559, 1096, 1455, 780, 437, 163, 1268, 649, 1029, 1081, 1518, 304, 1638, 814, 364, 140, 1385],
[905, 463, 1739, 1063, 351, 936, 1652, 101, 1323, 1731, 298, 1193, 266, 1554, 1837, 1659, 409, 1739, 1012, 725, 851, 1909, 213, 1918, 1759, 1561, 1250, 970, 1571, 352, 911, 195],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
],
[
[1375, 203, 265, 164, 200, 1867, 976, 924, 1972, 1637, 1048, 271, 1912, 1430, 853, 1942, 260, 1642, 400, 57, 1376, 1626, 1821, 1163, 619, 777, 1076, 951, 389, 1820, 84, 1417],
[914, 527, 286, 968, 305, 1314, 805, 1703, 87, 559, 1980, 1124, 1726, 36, 1139, 618, 1628, 519, 1943, 781, 400, 1265, 438, 113, 87, 856, 465, 162, 1099, 352, 1141, 274],
[1408, 6, 126, 2009, 90, 996, 934, 134, 1857, 126, 602, 876, 1092, 1962, 1205, 828, 707, 1063, 393, 1533, 123, 1086, 1749, 1324, 1, 1763, 1707, 1191, 34, 1323, 1017, 1787],
[1000, 683, 1630, 703, 1574, 587, 25, 1049, 213, 1270, 1641, 1072, 1892, 1634, 1603, 90, 867, 2037, 1021, 715, 206, 507, 1138, 959, 1822, 1785, 280, 1100, 1660, 251, 1903, 988],
[1657, 1981, 246, 1048, 1952, 451, 305, 423, 2000, 416, 756, 1748, 7, 748, 1866, 1795, 1682, 1832, 338, 212, 1685, 518, 154, 1407, 416, 765, 776, 25, 55, 458, 612, 262],
[1034, 564, 667, 1474, 1212, 350, 712, 941, 1151, 1182, 1280, 640, 924, 1722, 1816, 458, 226, 359, 1518, 102, 1203, 459, 676, 1788, 1110, 393, 1974, 1721, 795, 1459, 798, 1723],
[742, 1616, 119, 653, 441, 679, 246, 1432, 486, 1615, 1191, 500, 650, 223, 687, 1765, 1875, 963, 1385, 863, 151, 1771, 458, 1170, 737, 1932, 785, 1954, 1067, 16, 1986, 2029],
[1437, 1078, 1767, 1452, 1392, 45, 2010, 1664, 245, 2015, 1416, 1055, 457, 985, 740, 1594, 1562, 1838, 258, 1431, 701, 604, 1813, 352, 792, 632, 21, 895, 70, 609, 850, 1599],
[983, 1961, 54, 135, 846, 711, 473, 1630, 1373, 1094, 251, 525, 632, 1014, 1594, 1594, 1752, 398, 1266, 1357, 942, 1680, 191, 874, 483, 1291, 381, 1873, 1964, 1278, 1477, 122],
[1663, 1969, 1887, 113, 145, 251, 1133, 156, 245, 1641, 209, 1322, 2037, 836, 539, 667, 940, 797, 1758, 1357, 191, 1137, 587, 1699, 27, 701, 395, 99, 1682, 876, 762, 839],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
])
# fmt: on
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
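# Editorial sketch, not part of this diff: the integration tests above compare raw codebook
# tokens; to obtain a waveform one would decode them with the codec. A rough usage sketch,
# assuming the `output_audio` generate flag and the `processor.save_audio` helper added in
# this PR (wrapped in a helper so the test runner never executes it):
def _example_decode_to_audio(model, processor, inputs, path="csm_generation.wav"):
    # returns decoded waveforms instead of raw codebook ids
    audio = model.generate(**inputs, do_sample=False, output_audio=True)
    processor.save_audio(audio, path)
    return audio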

@ -0,0 +1,140 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import shutil
import tempfile
import unittest
import jinja2
import numpy as np
from transformers import CsmProcessor
from transformers.testing_utils import require_torch
from transformers.utils import is_torch_available
from ...test_processing_common import ProcessorTesterMixin
if is_torch_available():
import torch
@require_torch
class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = CsmProcessor
@classmethod
def setUpClass(cls):
# TODO: @eustlb, change to hf-internal-testing/csm-1b
cls.checkpoint = "eustlb/csm-1b"
processor = CsmProcessor.from_pretrained(cls.checkpoint)
cls.audio_token = processor.audio_token
cls.audio_token_id = processor.audio_token_id
cls.pad_token_id = processor.tokenizer.pad_token_id
cls.bos_token_id = processor.tokenizer.bos_token_id
cls.tmpdirname = tempfile.mkdtemp()
processor.save_pretrained(cls.tmpdirname)
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
def prepare_processor_dict(self):
return {"chat_template": "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"} # fmt: skip
def test_chat_template_is_saved(self):
processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
processor_dict_loaded = json.loads(processor_loaded.to_json_string())
# chat templates aren't serialized to json in processors
self.assertFalse("chat_template" in processor_dict_loaded.keys())
# they have to be saved as separate file and loaded back from that file
# so we check if the same template is loaded
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
def test_apply_chat_template(self):
# Messages contain a mix of text and audio content items
messages = [
{
"role": "0",
"content": [
{"type": "text", "text": "This is a test sentence 0."},
{"type": "audio"},
],
},
{
"role": "1",
"content": [
{"type": "text", "text": "This is a test sentence 1."},
{"type": "audio"},
],
},
{
"role": "0",
"content": [
{"type": "text", "text": "This is a prompt."},
],
},
]
processor = CsmProcessor.from_pretrained(self.tmpdirname)
rendered = processor.apply_chat_template(messages, tokenize=False)
expected_rendered = (
"<|begin_of_text|>[0]This is a test sentence 0.<|end_of_text|>"
"<|AUDIO|><|audio_eos|>"
"<|begin_of_text|>[1]This is a test sentence 1.<|end_of_text|>"
"<|AUDIO|><|audio_eos|>"
"<|begin_of_text|>[0]This is a prompt.<|end_of_text|>"
)
self.assertEqual(rendered, expected_rendered)
messages = [
{
"role": "0",
"content": [
{"type": "text", "text": "This is a test sentence."},
],
},
{
"role": "1",
"content": [
{"type": "text", "text": "This is a test sentence."},
],
},
]
# this should raise an error because the CSM processor requires audio content in all messages except the last one
with self.assertRaises(jinja2.exceptions.TemplateError):
input_ids = processor.apply_chat_template(messages, tokenize=False)
# now let's verify that it expands audio tokens correctly
messages = [
{
"role": "0",
"content": [
{"type": "text", "text": "This is a test sentence."},
{"type": "audio", "audio": np.zeros(4096)},
],
},
]
input_ids = processor.apply_chat_template(messages, tokenize=True)
# 4096 audio input values should give 3 audio tokens
expected_ids = torch.tensor(
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
)
torch.testing.assert_close(input_ids, expected_ids)
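# Editorial sketch, not part of this diff: the "3 audio tokens" expectation above follows
# from Mimi's assumed 24 kHz sample rate and 12.5 Hz frame rate (1920 samples per codec frame):
import math
_MIMI_SAMPLES_PER_FRAME = int(24_000 / 12.5)  # 1920; an assumption, not read from the codec config here
assert math.ceil(4096 / _MIMI_SAMPLES_PER_FRAME) == 3  # matches the three 128002 placeholders in expected_ids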

@ -4350,8 +4350,8 @@ class ModelTesterMixin:
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
# comparing softmax-normalized logits:
normalized_0 = F.softmax(out_last_tokens)
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
normalized_0 = F.softmax(out_last_tokens, dim=-1)
normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@slow
@ -4403,7 +4403,7 @@ class ModelTesterMixin:
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, sequence_length = inputs["input_ids"].shape
batch_size, sequence_length = inputs["input_ids"].shape[:2]
vocab_size = config.get_text_config().vocab_size
model = model_class(config).to(device=torch_device).eval()
# some models have labels but `logits_to_keep` should not be used in train mode

@ -159,6 +159,9 @@ IGNORE_NON_TESTED = (
"InternVLVisionModel", # Building part of bigger (tested) model
"JanusVisionModel", # Building part of bigger (tested) model
"TimesFmModel", # Building part of bigger (tested) model
"CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
"CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
"CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
]
)
@ -368,6 +371,10 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"Qwen2_5OmniToken2WavModel", # Building part of a bigger model
"Qwen2_5OmniToken2WavBigVGANModel", # Building part of a bigger model
"Qwen2_5OmniToken2WavDiTModel", # Building part of a bigger model
"CsmBackboneModel", # Building part of a bigger model
"CsmDepthDecoderModel", # Building part of a bigger model
"CsmDepthDecoderForCausalLM", # Building part of a bigger model
"CsmForConditionalGeneration", # Building part of a bigger model
]
# DO NOT edit this list!