Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Add CSM model (#36719)
* draft structure * depth decoder with forward pre hook * full model forward draft * draft update * depth decoder update * ConversationalSpeechModelForCausalLM udpates * add generate * max length criteria small fix * udpate * updates * generation update * update in loss compute * conversion script * update for correct input embeddings * handle interleaved rope * update * update * update * support compile * update training * add doc * update doc * correct inits * ConversationalSpeechModel -> Csm * conf update * name update * tests CsmForCausalLMTest * convert use cached_file * conf + modeling updates * generate utils handle third dim shape * integration test * modeling + conf updates * common test handle more than 2 dims * add nested audio list utils * processing handle nested audio list * csm processing draft * mimi util * init updates * modular update * convert modular * processing update * csm tests update * generate tests handle third dim * generate utils handle third dim * propagate _get_initial_cache_position update * tied_weight_keys update + convert correctly * fix inputs_embeds * revert audio nested list * batch inference update + return audio * audio_utils update * processor update * some more integration tests * remove old test * porcessing output labels * improve * fix * update rope values with equivalent ones * conversion update * udpate tests * handle depth decoder generation config * remove default eos_token_id * make style * revert modeling_mimi * add default generation_config * remove sdpa since handled by default * make * fix conflict * fix conflicts * correct naming * correct imports * make * causal -> conditional naming * causal -> conditional naming * auto update * make * make * add doc * test update * fix weight init * audio tokens offsets as buffer * 4d mask in conditional class * make * doc update * fix causal mask * fix causal mask * doc update * doc update * add processor doc * update doc * fix 4d causal mask * update make_list_of_audio * do not default to mutable * remove duplicates * remove useless reset_parameters * use GradientCheckpointingLayer * use can_return_tuple * formatting * prepend placeholder in _sample * torch compile fix * some more fixies * convert modular * fix * default max_length in convert * handle depth decoder generation config correctly * clearer formulation * handle output_loading_info * handle softmax warning * add doc * propagate _get_initial_cache_position changes * generation in its own module * add processor tests * fix compile witu cuda graphs * fix compile with cuda graphs * add csm.md * include CSM loss * doc nit * doc nit * doc nit * Update docs/source/en/model_doc/csm.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add save_audio to processor * Update src/transformers/models/csm/modular_csm.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * doc update * simplify audio_codes_mask computation * doc update * simplify loss computation * fix static cache test * fix * remove comment * simplify encoded length computation * use hf-internal-testing * doc update * cast to float before numpy * nit * mem efficient codebook head * nit * cat input values with cutoffs --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@@ -825,6 +825,8 @@
         title: Bark
       - local: model_doc/clap
         title: CLAP
+      - local: model_doc/csm
+        title: CSM
       - local: model_doc/dac
         title: dac
       - local: model_doc/encodec
docs/source/en/model_doc/csm.md (new file, 377 lines)
@@ -0,0 +1,377 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Csm

## Overview

The Conversational Speech Model (CSM) is the first open-source contextual text-to-speech model [released by Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). It is designed to generate natural-sounding speech with or without conversational context. This context typically consists of multi-turn dialogue between speakers, represented as sequences of text and corresponding spoken audio.

**Model Architecture:**
CSM is composed of two LLaMA-style auto-regressive transformer decoders: a backbone decoder that predicts the first codebook token and a depth decoder that generates the remaining codebook tokens. It uses the pretrained codec model [Mimi](./mimi.md), introduced by Kyutai, to encode speech into discrete codebook tokens and to decode them back into audio.

The original csm-1b checkpoint is available under the [Sesame](https://huggingface.co/sesame/csm-1b) organization on Hugging Face.

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/eustlb/documentation-images/resolve/main/csm_architecture.png"/>
</div>
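
The same split shows up in the Transformers integration. The snippet below is a minimal, illustrative sketch (sub-module attribute names taken from the conversion script and the API reference at the bottom of this page) that loads the model and inspects its three parts:

```python
from transformers import CsmForConditionalGeneration

model = CsmForConditionalGeneration.from_pretrained("eustlb/csm-1b")

# backbone decoder: predicts the first codebook token from the text and audio context
print(type(model.backbone_model).__name__)
# depth decoder: generates the remaining codebook tokens for each frame
print(type(model.depth_decoder).__name__)
# Mimi codec: turns codebook tokens back into a waveform
print(type(model.codec_model).__name__)
```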

## Usage Tips

### Without Conversational Context

CSM can be used to generate speech directly from a text prompt:

```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
text = "[0]The past is just a story we tell ourselves."  # `[0]` for speaker id 0
inputs = processor(text, add_special_tokens=True).to(device)

# another, equivalent way to prepare the inputs
conversation = [
    {"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
]
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_without_context.wav")
```

### With Conversational Context

CSM can be used to generate speech given a conversation, enabling consistent voices across turns and content-aware generation:

```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []

# 1. context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
    conversation.append(
        {
            "role": f"{speaker_id}",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )

# 2. text prompt
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

# infer the model
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_with_context.wav")
```

### Batched Inference

CSM supports batched inference!

```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
# here, a batch with two prompts
conversation = [
    [
        {
            "role": f"{ds[0]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[0]["text"]},
                {"type": "audio", "path": ds[0]["audio"]["array"]},
            ],
        },
        {
            "role": f"{ds[1]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[1]["text"]},
            ],
        },
    ],
    [
        {
            "role": f"{ds[0]['speaker_id']}",
            "content": [
                {"type": "text", "text": ds[0]["text"]},
            ],
        }
    ],
]
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, [f"speech_batch_idx_{i}.wav" for i in range(len(audio))])
```

### Making The Model Go Brrr

CSM supports full-graph compilation with CUDA graphs!

```python
import torch
import copy
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset

model_id = "eustlb/csm-1b"
device = "cuda"

# set logs to ensure there are no recompilations and no graph breaks
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)

# use a static cache, which automatically enables torch.compile with fullgraph and reduce-overhead
model.generation_config.max_length = 250  # big enough to avoid recompilation
model.generation_config.max_new_tokens = None  # would take precedence over max_length
model.generation_config.cache_implementation = "static"
model.depth_decoder.generation_config.cache_implementation = "static"

# generation kwargs
gen_kwargs = {
    "do_sample": False,
    "depth_decoder_do_sample": False,
    "temperature": 1.0,
    "depth_decoder_temperature": 1.0,
}

# define a timing context manager
class TimerContext:
    def __init__(self, name="Execution"):
        self.name = name
        self.start_event = None
        self.end_event = None

    def __enter__(self):
        # use CUDA events for more accurate GPU timing
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self

    def __exit__(self, *args):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
        print(f"{self.name} time: {elapsed_time:.4f} seconds")

# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")

conversation = [
    {
        "role": f"{ds[0]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[0]["text"]},
            {"type": "audio", "path": ds[0]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[1]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[1]["text"]},
            {"type": "audio", "path": ds[1]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[2]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[2]["text"]},
        ],
    },
]

padded_inputs_1 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "=" * 50)
print("First generation - compiling and recording CUDA graphs...")
with TimerContext("First generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("=" * 50)

print("\n" + "=" * 50)
print("Second generation - fast!!!")
with TimerContext("Second generation"):
    _ = model.generate(**padded_inputs_1, **gen_kwargs)
print("=" * 50)

# now with different inputs
conversation = [
    {
        "role": f"{ds[0]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[2]["text"]},
            {"type": "audio", "path": ds[2]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[1]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[3]["text"]},
            {"type": "audio", "path": ds[3]["audio"]["array"]},
        ],
    },
    {
        "role": f"{ds[2]['speaker_id']}",
        "content": [
            {"type": "text", "text": ds[4]["text"]},
        ],
    },
]
padded_inputs_2 = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
).to(device)

print("\n" + "=" * 50)
print("Generation with other inputs!")
with TimerContext("Generation with different inputs"):
    _ = model.generate(**padded_inputs_2, **gen_kwargs)
print("=" * 50)
```

### Training

The CSM Transformers integration supports training!

```python
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio

model_id = "eustlb/csm-1b"
device = "cuda"

# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
model.train()

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []

# context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
    conversation.append(
        {
            "role": f"{speaker_id}",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
    output_labels=True,
).to(device)

out = model(**inputs)
out.loss.backward()
```
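
From here a standard PyTorch optimization step applies. A minimal sketch (the optimizer choice and learning rate are illustrative, not part of the CSM integration):

```python
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=1e-5)  # illustrative hyperparameters

optimizer.zero_grad()
out = model(**inputs)
out.loss.backward()
optimizer.step()
```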

This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
The original code can be found [here](https://github.com/SesameAILabs/csm).


## CsmConfig

[[autodoc]] CsmConfig

## CsmDepthDecoderConfig

[[autodoc]] CsmDepthDecoderConfig

## CsmProcessor

[[autodoc]] CsmProcessor
    - __call__

## CsmForConditionalGeneration

[[autodoc]] CsmForConditionalGeneration
    - forward
    - generate

## CsmDepthDecoderForCausalLM

[[autodoc]] CsmDepthDecoderForCausalLM

## CsmDepthDecoderModel

[[autodoc]] CsmDepthDecoderModel

## CsmBackboneModel

[[autodoc]] CsmBackboneModel
@@ -24,7 +24,12 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import requests

-from .utils import is_librosa_available, requires_backends
+from .utils import (
+    is_librosa_available,
+    is_numpy_array,
+    is_torch_tensor,
+    requires_backends,
+)


 if is_librosa_available():
@@ -69,6 +74,36 @@ AudioInput = Union[
 ]


+def is_valid_audio(audio):
+    return is_numpy_array(audio) or is_torch_tensor(audio)
+
+
+def is_valid_list_of_audio(audio):
+    return audio and all(is_valid_audio(audio_i) for audio_i in audio)
+
+
+def make_list_of_audio(
+    audio: Union[list[AudioInput], AudioInput],
+) -> AudioInput:
+    """
+    Ensure that the output is a list of audio.
+    Args:
+        audio (`Union[List[AudioInput], AudioInput]`):
+            The input audio.
+    Returns:
+        list: A list of audio.
+    """
+    # If it's a list of audios, it's already in the right format
+    if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio):
+        return audio
+
+    # If it's a single audio, convert it to a list
+    if is_valid_audio(audio):
+        return [audio]
+
+    raise ValueError("Invalid input type. Must be a single audio or a list of audio")
+
+
 def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
     """
     Convert frequency from hertz to mels.
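
The new helpers above normalize the audio argument accepted by the CSM processor. A minimal usage sketch (assuming `transformers.audio_utils` is the module shown in this hunk):

```python
import numpy as np
from transformers.audio_utils import make_list_of_audio  # assumed import path

mono = np.zeros(24000, dtype=np.float32)      # a single 1-second, 24kHz waveform
print(len(make_list_of_audio(mono)))          # 1 -> wrapped into a list
print(len(make_list_of_audio([mono, mono])))  # 2 -> already a list, returned as-is
make_list_of_audio("not audio")               # raises ValueError
```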
@@ -73,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):

     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
-        cur_len = input_ids.shape[-1]
+        cur_len = input_ids.shape[1]
         is_done = cur_len >= self.max_length
         if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
             logger.warning_once(
@@ -563,7 +563,7 @@ class GenerationMixin:
         if model_inputs["inputs_embeds"] is not None:
             batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
         else:
-            batch_size, sequence_length = model_inputs[input_ids_key].shape
+            batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]

         # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
         # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
@@ -1708,7 +1708,7 @@ class GenerationMixin:

         return generation_config, model_kwargs

-    def _get_initial_cache_position(self, input_ids, model_kwargs):
+    def _get_initial_cache_position(self, seq_length, device, model_kwargs):
         """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
         # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
         if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
@@ -1718,7 +1718,7 @@ class GenerationMixin:
                 torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
             )
         else:
-            cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
+            cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1

         past_length = 0
         if model_kwargs.get("past_key_values") is not None:
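
As the code comment above notes, the `ones(...).cumsum(0) - 1` construction is simply a shape-based, `torch.compile`-friendly way of writing `torch.arange`, so the prefill `cache_position` is unchanged. A quick check:

```python
import torch

seq_length = 5
cache_position = torch.ones(seq_length, dtype=torch.int64).cumsum(0) - 1
assert torch.equal(cache_position, torch.arange(seq_length))  # tensor([0, 1, 2, 3, 4])
```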
@@ -2332,7 +2332,7 @@ class GenerationMixin:
             streamer.put(input_ids.cpu())

         # 6. Prepare `max_length` depending on other stopping criteria.
-        input_ids_length = input_ids.shape[-1]
+        input_ids_length = input_ids.shape[1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
         generation_config = self._prepare_generated_length(
@@ -2805,9 +2805,9 @@ class GenerationMixin:
         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

         # keep track of which sequences are already finished
-        batch_size = input_ids.shape[0]
+        batch_size, cur_length = input_ids.shape[:2]
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)

         this_peer_finished = False

@@ -3016,9 +3016,9 @@ class GenerationMixin:
         )

         # keep track of which sequences are already finished
-        batch_size = input_ids.shape[0]
+        batch_size, cur_len = input_ids.shape[:2]
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

         # Create cosine_matrix_mask based on the attention_mask
         cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
@@ -3428,10 +3428,10 @@ class GenerationMixin:
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
+        batch_size, cur_len = input_ids.shape[:2]
         this_peer_finished = False
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

         model_forward = self.__call__
         compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
@@ -3834,7 +3834,7 @@ class GenerationMixin:
         num_beams = generation_config.num_beams
         num_return_sequences = generation_config.num_return_sequences

-        batch_size_unflattened, cur_len = input_ids.shape
+        batch_size_unflattened, cur_len = input_ids.shape[:2]
         batch_size = batch_size_unflattened // num_beams
         # TODO (joao): standardize special cases
         if self.__class__.__name__ == "MoshiDepthDecoder":
@@ -3857,7 +3857,7 @@ class GenerationMixin:
                 dim=0,
             ).to(input_ids.device)

-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

         # (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
         # are newer low-memory alternatives like the offloaded cache)
@@ -4156,7 +4156,7 @@ class GenerationMixin:
         device = input_ids.device

         batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

         if return_dict_in_generate and output_scores:
             beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -4190,7 +4190,7 @@ class GenerationMixin:

         this_peer_finished = False

-        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
+        decoder_prompt_len = input_ids.shape[1]  # record the prompt length of decoder
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # predicted tokens in cur_len step
             current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
@@ -4444,8 +4444,8 @@ class GenerationMixin:
         batch_size = len(constrained_beam_scorer._beam_hyps)
         num_beams = constrained_beam_scorer.num_beams

-        batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        batch_beam_size, cur_len = input_ids.shape[:2]
+        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -4477,7 +4477,7 @@ class GenerationMixin:

         this_peer_finished = False

-        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
+        decoder_prompt_len = input_ids.shape[1]  # record the prompt length of decoder
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

@@ -4698,14 +4698,14 @@ class GenerationMixin:
         )

         # keep track of which sequences are already finished
-        batch_size = input_ids.shape[0]
+        batch_size, cur_len = input_ids.shape[:2]
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

         this_peer_finished = False
         is_first_iteration = True  # to preserve the same API in the output as other generation methods
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            cur_len = input_ids.shape[-1]
+            cur_len = input_ids.shape[1]

             # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
             candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
@@ -4795,7 +4795,7 @@ class GenerationMixin:
             input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
             if streamer is not None:
                 streamer.put(valid_tokens.cpu())
-            new_cur_len = input_ids.shape[-1]
+            new_cur_len = input_ids.shape[1]

             # 4.2. Discard past key values relative to unused assistant tokens
             new_cache_size = new_cur_len - 1
@@ -158,4 +158,5 @@ LOSS_MAPPING = {
     "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
     "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
     "DFineForObjectDetection": DFineForObjectDetectionLoss,
+    "CsmForConditionalGeneration": ForCausalLMLoss,
 }
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
     from .convnextv2 import *
     from .cpm import *
     from .cpmant import *
+    from .csm import *
     from .ctrl import *
     from .cvt import *
     from .d_fine import *
@@ -80,6 +80,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("convnext", "ConvNextConfig"),
         ("convnextv2", "ConvNextV2Config"),
         ("cpmant", "CpmAntConfig"),
+        ("csm", "CsmConfig"),
         ("ctrl", "CTRLConfig"),
         ("cvt", "CvtConfig"),
         ("d_fine", "DFineConfig"),
@@ -437,6 +438,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("convnextv2", "ConvNeXTV2"),
         ("cpm", "CPM"),
         ("cpmant", "CPM-Ant"),
+        ("csm", "CSM"),
         ("ctrl", "CTRL"),
         ("cvt", "CvT"),
         ("d_fine", "D-FINE"),
@@ -78,6 +78,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("convnext", "ConvNextModel"),
         ("convnextv2", "ConvNextV2Model"),
         ("cpmant", "CpmAntModel"),
+        ("csm", "CsmForConditionalGeneration"),
         ("ctrl", "CTRLModel"),
         ("cvt", "CvtModel"),
         ("d_fine", "DFineModel"),
@@ -1446,6 +1447,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Text-To-Waveform mapping
         ("bark", "BarkModel"),
+        ("csm", "CsmForConditionalGeneration"),
         ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
         ("musicgen", "MusicgenForConditionalGeneration"),
         ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
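
With these mappings in place, the Auto classes resolve CSM without importing the model class directly. A minimal sketch (checkpoint id taken from the documentation above):

```python
from transformers import AutoConfig, AutoModelForTextToWaveform, AutoProcessor

model_id = "eustlb/csm-1b"

config = AutoConfig.from_pretrained(model_id)                  # -> CsmConfig, via the "csm" mapping entry
model = AutoModelForTextToWaveform.from_pretrained(model_id)   # -> CsmForConditionalGeneration
processor = AutoProcessor.from_pretrained(model_id)            # -> CsmProcessor
```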
src/transformers/models/csm/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_csm import *
    from .modeling_csm import *
    from .processing_csm import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
src/transformers/models/csm/configuration_csm.py (new file, 440 lines)
@@ -0,0 +1,440 @@
# coding=utf-8
# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
from ..auto.configuration_auto import AutoConfig


logger = logging.get_logger(__name__)


class CsmDepthDecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`CsmDepthDecoderModel`]. It is used to instantiate a CSM depth decoder
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
    a configuration similar to that of csm-1b.

    e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_codebooks (`int`, *optional*, defaults to 32):
            Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
        backbone_hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations of the backbone model used with this depth decoder.
        vocab_size (`int`, *optional*, defaults to 2051):
            Vocabulary size of the CsmDepthDecoder model. Defines the number of different audio tokens that can be represented by each codebook.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 4):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 33):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 2050):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*):
            End of stream token id.
        rope_theta (`float`, *optional*, defaults to 500000):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on a longer `max_position_embeddings`, we recommend you update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_attention_heads

    ```python
    >>> from transformers import CsmDepthDecoderModel, CsmDepthDecoderConfig

    >>> # Initializing a CsmDepthDecoder configuration
    >>> configuration = CsmDepthDecoderConfig()
    >>> model = CsmDepthDecoderModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "csm_depth_decoder_model"
    base_config_key = "depth_decoder_config"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        num_codebooks=32,
        backbone_hidden_size=2048,
        vocab_size=2051,
        hidden_size=1024,
        intermediate_size=8192,
        num_hidden_layers=4,
        num_attention_heads=8,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=33,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None,
        rope_theta=500000,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        **kwargs,
    ):
        if kwargs.pop("tie_word_embeddings", False):
            raise ValueError("`tie_word_embeddings=True` is not supported for CsmDepthDecoderConfig")

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=False,
            **kwargs,
        )
        self.num_codebooks = num_codebooks
        self.vocab_size = vocab_size
        self.backbone_hidden_size = backbone_hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)


class CsmConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`CsmForConditionalGeneration`]. It is used to instantiate a CSM
    model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a configuration similar to that of csm-1b.

    e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_codebooks (`int`, *optional*, defaults to 32):
            Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
        vocab_size (`int`, *optional*, defaults to 2051):
            Vocabulary size of the Csm model. Defines the number of different audio tokens that can be represented by each codebook.
        text_vocab_size (`int`, *optional*, defaults to 128256):
            Vocabulary size of the text input for the Csm model. Defines the number of different text tokens that can be represented.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations of the backbone model.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimension of the MLP representations of the backbone model.
        num_hidden_layers (`int`, *optional*, defaults to 16):
            Number of hidden layers in the backbone model Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the backbone model Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the backbone model Transformer decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 128002):
            Padding token id.
        codebook_pad_token_id (`int`, *optional*, defaults to 2050):
            Padding token id for codebook tokens.
        codebook_eos_token_id (`int`, *optional*, defaults to 0):
            End of stream token id for codebook tokens.
        bos_token_id (`int`, *optional*, defaults to 128000):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*):
            End of stream token id.
        audio_token_id (`int`, *optional*, defaults to 128002):
            Audio token id in the text input.
        audio_eos_token_id (`int`, *optional*, defaults to 128003):
            End of stream token id for audio in the text input.
        rope_theta (`float`, *optional*, defaults to 500000):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*, defaults to `{'factor': 32.0, 'high_freq_factor': 0.5, 'low_freq_factor': 0.125, 'original_max_position_embeddings': 1024, 'rope_type': 'llama3'}`):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on a longer `max_position_embeddings`, we recommend you update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_attention_heads
        tie_codebooks_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie the codebook tokens embeddings of the backbone model to the codebook tokens embeddings of the depth decoder.
        depth_decoder_config (`CsmDepthDecoderConfig`, *optional*):
            Configuration for the depth decoder.
        codec_config (`PretrainedConfig`, *optional*):
            Configuration for the codec.

    ```python
    >>> from transformers import CsmForConditionalGeneration, CsmConfig

    >>> # Initializing a CsmConfig
    >>> configuration = CsmConfig()

    >>> # Initializing a model
    >>> model = CsmForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "csm"
    base_config_key = "csm_config"
    keys_to_ignore_at_inference = ["past_key_values"]
    sub_configs = {
        "codec_config": AutoConfig,
        "depth_decoder_config": CsmDepthDecoderConfig,
    }

    def __init__(
        self,
        num_codebooks=32,
        vocab_size=2051,
        text_vocab_size=128256,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=16,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=128002,
        codebook_pad_token_id=2050,
        codebook_eos_token_id=0,
        bos_token_id=128000,
        eos_token_id=None,
        audio_token_id=128002,
        audio_eos_token_id=128003,
        rope_theta=500000,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        tie_codebooks_embeddings=True,
        depth_decoder_config=None,
        codec_config=None,
        **kwargs,
    ):
        if kwargs.pop("tie_word_embeddings", False):
            raise ValueError("`tie_word_embeddings=True` is not supported for CsmConfig")

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=False,
            **kwargs,
        )

        if depth_decoder_config is None:
            self.depth_decoder_config = CsmDepthDecoderConfig()
            logger.info("depth_decoder_config is None, using default depth decoder config.")
        elif isinstance(depth_decoder_config, dict):
            self.depth_decoder_config = CsmDepthDecoderConfig(**depth_decoder_config)
        elif isinstance(depth_decoder_config, CsmDepthDecoderConfig):
            self.depth_decoder_config = depth_decoder_config

        if codec_config is None:
            self.codec_config = AutoConfig.for_model("mimi")
            logger.info("codec_config is None, using default audio encoder config.")
        elif isinstance(codec_config, dict):
            self.codec_config = AutoConfig.for_model(**codec_config)
        elif isinstance(codec_config, PretrainedConfig):
            self.codec_config = codec_config

        self.text_vocab_size = text_vocab_size
        self.num_codebooks = num_codebooks
        self.audio_token_id = audio_token_id
        self.audio_eos_token_id = audio_eos_token_id
        self.codebook_pad_token_id = codebook_pad_token_id
        self.codebook_eos_token_id = codebook_eos_token_id
        self.tie_codebooks_embeddings = tie_codebooks_embeddings

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)


__all__ = [
    "CsmDepthDecoderConfig",
    "CsmConfig",
]
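
Because `CsmConfig` nests a `CsmDepthDecoderConfig` (and a codec config), custom model sizes can be composed before instantiation. A minimal sketch with illustrative, non-default layer counts:

```python
from transformers import CsmConfig, CsmDepthDecoderConfig, CsmForConditionalGeneration

# illustrative sizes, only to show how the sub-config is passed through
depth_decoder_config = CsmDepthDecoderConfig(num_hidden_layers=2)
config = CsmConfig(num_hidden_layers=4, depth_decoder_config=depth_decoder_config)

model = CsmForConditionalGeneration(config)  # randomly initialized, for prototyping/tests
print(model.config.depth_decoder_config.num_hidden_layers)  # 2
```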
339
src/transformers/models/csm/convert_csm.py
Normal file
339
src/transformers/models/csm/convert_csm.py
Normal file
@ -0,0 +1,339 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from tokenizers.processors import TemplateProcessing
|
||||
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
CsmConfig,
|
||||
CsmDepthDecoderConfig,
|
||||
CsmForConditionalGeneration,
|
||||
CsmProcessor,
|
||||
MimiModel,
|
||||
)
|
||||
from transformers.utils.hub import cached_file
|
||||
|
||||
|
||||
# fmt: off
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
r"backbone\.layers\.(\d+)": r"backbone_model.layers.\1",
|
||||
r"decoder\.layers\.(\d+)": r"depth_decoder.model.layers.\1",
|
||||
|
||||
r"attn": r"self_attn",
|
||||
r"output_proj": r"o_proj",
|
||||
r"w1": r"gate_proj",
|
||||
r"w2": r"down_proj",
|
||||
r"w3": r"up_proj",
|
||||
|
||||
r"text_embeddings": r"embed_text_tokens",
|
||||
r"audio_embeddings": r"backbone_model.embed_tokens.embed_audio_tokens",
|
||||
|
||||
r"codebook0_head": r"lm_head",
|
||||
r"audio_head": r"depth_decoder.codebooks_head.weight",
|
||||
r"projection": r"depth_decoder.model.inputs_embeds_projector",
|
||||
|
||||
r"sa_norm.scale": r"input_layernorm.weight",
|
||||
r"mlp_norm.scale": r"post_attention_layernorm.weight",
|
||||
r"decoder.norm.scale": r"depth_decoder.model.norm.weight",
|
||||
r"backbone.norm.scale": r"backbone_model.norm.weight",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
def permute_for_rope(input_tensor, n_heads, dim1, dim2):
|
||||
"""
|
||||
When you go from the complex ROPE formulation to sin and cos one, you need
|
||||
to permute the query and key weights (to avoid doing it on the fly)
|
||||
"""
|
||||
input_tensor = input_tensor.reshape(dim1, dim2)
|
||||
input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2)
|
||||
return input_tensor
|
||||
|
||||
|
||||
def convert_key(key, mapping):
|
||||
for pattern, replacement in mapping.items():
|
||||
key = re.sub(pattern, replacement, key)
|
||||
return key
|
||||
|
||||
|
||||
def write_model(
|
||||
input_path_or_repo,
|
||||
model_name,
|
||||
codec_model_path_or_repo,
|
||||
output_dir,
|
||||
safe_serialization=True,
|
||||
):
|
||||
print("Converting the model.")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
codec_model = MimiModel.from_pretrained(codec_model_path_or_repo)
|
||||
codec_model.config._attn_implementation_autoset = False
|
||||
|
||||
# prepare rope scaling args: the model originally uses
# 1 - for the depth decoder
# rope_theta=500000,
# rope_scaling={
#     "factor": 32.0,
#     "high_freq_factor": 4.0,
#     "low_freq_factor": 1.0,
#     "original_max_position_embeddings": 8192,
#     "rope_type": "llama3",
# },
# 2 - for the backbone
# rope_theta=500000,
# rope_scaling={
#     "factor": 32.0,
#     "high_freq_factor": 4.0,
#     "low_freq_factor": 1.0,
#     "original_max_position_embeddings": 8192,
#     "rope_type": "llama3",
# },
#
# Yet we want to use max_position_embeddings=32 for the depth decoder, resp. 2048 for the backbone.
# This would trigger a warning since we would have original_max_position_embeddings >= max_position_embeddings.
# Therefore, we convert the values to equivalent ones.
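# Equivalence sketch (my reading of the numbers below, not a comment from the original script):
# the llama3 rope scaling thresholds are original_max_position_embeddings / low_freq_factor and
# original_max_position_embeddings / high_freq_factor, so shrinking original_max_position_embeddings
# while multiplying both frequency factors by the same ratio keeps the thresholds (and the smoothing
# between them) unchanged:
#   depth decoder: 1.0 * 16 / 8192 = 0.001953125 and 4.0 * 16 / 8192 = 0.0078125
#   backbone:      1.0 * 1024 / 8192 = 0.125      and 4.0 * 1024 / 8192 = 0.5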
||||
|
||||
depth_decoder_config = CsmDepthDecoderConfig(
|
||||
rope_scaling={
|
||||
"factor": 32.0,
|
||||
"high_freq_factor": 0.0078125,
|
||||
"low_freq_factor": 0.001953125,
|
||||
"original_max_position_embeddings": 16,
|
||||
"rope_type": "llama3",
|
||||
},
|
||||
)
|
||||
|
||||
config = CsmConfig(
|
||||
codec_config=codec_model.config,
|
||||
depth_decoder_config=depth_decoder_config,
|
||||
rope_scaling={
|
||||
"factor": 32.0,
|
||||
"high_freq_factor": 0.5,
|
||||
"low_freq_factor": 0.125,
|
||||
"original_max_position_embeddings": 1024,
|
||||
"rope_type": "llama3",
|
||||
},
|
||||
)
|
||||
|
||||
params = {
|
||||
"backbone": {
|
||||
"num_attention_heads": config.num_attention_heads,
|
||||
"num_key_value_heads": config.num_key_value_heads,
|
||||
"dim_per_head": config.head_dim,
|
||||
"key_value_dim": config.head_dim * config.num_key_value_heads,
|
||||
"dim": config.hidden_size,
|
||||
},
|
||||
"depth_decoder": {
|
||||
"num_attention_heads": config.depth_decoder_config.num_attention_heads,
|
||||
"num_key_value_heads": config.depth_decoder_config.num_key_value_heads,
|
||||
"dim_per_head": config.depth_decoder_config.head_dim,
|
||||
"key_value_dim": config.depth_decoder_config.head_dim * config.depth_decoder_config.num_key_value_heads,
|
||||
"dim": config.depth_decoder_config.hidden_size,
|
||||
},
|
||||
}
|
||||
|
||||
model_path = cached_file(
|
||||
input_path_or_repo,
|
||||
model_name,
|
||||
)
|
||||
print(f"Fetching all parameters from the checkpoint at {model_path}...")
|
||||
loaded = torch.load(model_path, map_location="cpu")
|
||||
|
||||
print("Converting model...")
|
||||
state_dict = {}
|
||||
|
||||
# -----------------------
|
||||
# convert parameter names
|
||||
# -----------------------
|
||||
|
||||
# Add codec_model. prefix to every key in the codec model state dict
|
||||
codec_state_dict = {f"codec_model.{k}": v for k, v in codec_model.state_dict().items()}
|
||||
state_dict.update(codec_state_dict)
|
||||
|
||||
for key, value in loaded.items():
|
||||
new_key = convert_key(key, ORIGINAL_TO_CONVERTED_KEY_MAPPING)
|
||||
current_parameter = value
|
||||
|
||||
# Post-process the current_parameter.
|
||||
if re.search("(k|q)_proj.weight", new_key):
|
||||
params_keys = "backbone" if "backbone" in new_key else "depth_decoder"
|
||||
if "q_proj" in new_key:
|
||||
num_heads = params[params_keys]["num_attention_heads"]
|
||||
dim_per_head = params[params_keys]["dim_per_head"]
|
||||
param_dim = params[params_keys]["dim"]
|
||||
dim = params[params_keys]["dim"]
|
||||
else:
|
||||
num_heads = params[params_keys]["num_key_value_heads"]
|
||||
dim_per_head = params[params_keys]["dim_per_head"]
|
||||
param_dim = params[params_keys]["key_value_dim"]
|
||||
dim = params[params_keys]["dim"]
|
||||
|
||||
current_parameter = permute_for_rope(value, num_heads, param_dim, dim)
|
||||
state_dict[new_key] = current_parameter.reshape(num_heads * dim_per_head, dim)
|
||||
|
||||
state_dict[new_key] = current_parameter
|
||||
|
||||
# add the depth decoder embed audio tokens weights, latter tied to the backbone embed audio tokens weights
|
||||
state_dict["depth_decoder.model.embed_tokens.weight"] = state_dict[
|
||||
"backbone_model.embed_tokens.embed_audio_tokens.weight"
|
||||
].clone()
|
||||
del loaded
|
||||
gc.collect()
|
||||
|
||||
# -------------------------
|
||||
# load the weights and save
|
||||
# -------------------------
|
||||
|
||||
print("Loading the checkpoint in a Csm model.")
|
||||
with torch.device("meta"):
|
||||
model = CsmForConditionalGeneration(config)
|
||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
print("Checkpoint loaded successfully.")
|
||||
del model.config._name_or_path
|
||||
|
||||
# default generation config
|
||||
model.generation_config._from_model_config = False
|
||||
model.generation_config.max_new_tokens = 125
|
||||
model.generation_config.do_sample = True
|
||||
model.generation_config.top_k = 50
|
||||
model.generation_config.temperature = 0.9
|
||||
model.generation_config.depth_decoder_do_sample = True
|
||||
model.generation_config.depth_decoder_top_k = 50
|
||||
model.generation_config.depth_decoder_temperature = 0.9
|
||||
|
||||
print("Saving the model.")
|
||||
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
|
||||
del state_dict, model
|
||||
|
||||
# Safety check: reload the converted model
|
||||
gc.collect()
|
||||
print("Reloading the model to check if it's saved correctly.")
|
||||
CsmForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
|
||||
print("Model reloaded successfully.")
|
||||
|
||||
|
||||
def write_tokenizer(output_dir):
|
||||
# from https://github.com/SesameAILabs/csm/blob/2d720827843b653c4d67bb4445b1c0a4f59e646f/generator.py#L22-L36
|
||||
def load_llama3_tokenizer():
|
||||
"""
|
||||
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
|
||||
"""
|
||||
tokenizer_name = "meta-llama/Llama-3.2-1B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
bos = tokenizer.bos_token
|
||||
eos = tokenizer.eos_token
|
||||
tokenizer._tokenizer.post_processor = TemplateProcessing(
|
||||
single=f"{bos}:0 $A:0 {eos}:0",
|
||||
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
|
||||
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
|
||||
)
|
||||
|
||||
return tokenizer
|
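# Behaviour sketch (assumed, not asserted output): with this TemplateProcessing post-processor,
# tokenizer("hello")["input_ids"] starts with the BOS id and ends with the EOS id, which is the point
# of the workaround referenced in the linked issue.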
||||
|
||||
tokenizer = load_llama3_tokenizer()
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# manually modify in tokenizer_config.json
|
||||
# "128002": {
|
||||
# "content": "<|AUDIO|>",
|
||||
# ...
|
||||
# }
|
||||
# "128003": {
|
||||
# "content": "<|audio_eos|>",
|
||||
# ...
|
||||
# }
|
||||
print(
|
||||
"Tokenizer saved successfully. Please manually modify in tokenizer_config.json AND tokenizer.json as follows: "
|
||||
)
|
||||
print("""
|
||||
# "128002": {
|
||||
# "content": "<|AUDIO|>",
|
||||
# ...
|
||||
# }
|
||||
# "128003": {
|
||||
# "content": "<|audio_eos|>",
|
||||
# ...
|
||||
# }
|
||||
""")
|
||||
|
||||
|
||||
def write_processor(output_dir, codec_model_path_or_repo):
|
||||
chat_template = "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"
|
||||
tokenizer = AutoTokenizer.from_pretrained(output_dir)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(codec_model_path_or_repo)
|
||||
|
||||
processor = CsmProcessor(
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
chat_template=chat_template,
|
||||
)
|
||||
|
||||
processor.save_pretrained(output_dir)
|
||||
print("Processor saved successfully.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert Csm weights to HuggingFace format")
|
||||
parser.add_argument(
|
||||
"--input_path_or_repo",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or repo containing Csm weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of the model in input_path_or_repo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--codec_model_path_or_repo",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or repo containing the codec model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
write_model(
|
||||
args.input_path_or_repo,
|
||||
args.model_name,
|
||||
args.codec_model_path_or_repo,
|
||||
output_dir=args.output_dir,
|
||||
safe_serialization=args.safe_serialization,
|
||||
)
|
||||
|
||||
write_tokenizer(args.output_dir)
|
||||
|
||||
write_processor(args.output_dir, args.codec_model_path_or_repo)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
src/transformers/models/csm/generation_csm.py (new file, 491 lines)
@ -0,0 +1,491 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...generation import (
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerationConfig,
|
||||
GenerationMixin,
|
||||
GenerationMode,
|
||||
)
|
||||
from ...generation.logits_process import LogitsProcessorList
|
||||
from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
|
||||
from ...generation.utils import GenerateNonBeamOutput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...generation.streamers import BaseStreamer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CsmGenerateOutput(GenerateDecoderOnlyOutput):
|
||||
"""
|
||||
Outputs of CsmForConditionalGeneration.generate.
|
||||
|
||||
Args:
|
||||
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||
if all batches finished early due to the `eos_token_id`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
|
||||
Returns the model cache, used to speed up decoding. Different models have a different cache format, check the model's documentation.
|
||||
audio (`list(torch.FloatTensor)` of length `batch_size`):
|
||||
The generated audio.
|
||||
"""
|
||||
|
||||
audio: Optional[List[torch.Tensor]] = None
|
||||
|
||||
|
||||
class CsmGenerationMixin(GenerationMixin):
|
||||
def _get_stopping_criteria(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> StoppingCriteriaList:
|
||||
criteria = super()._get_stopping_criteria(*args, **kwargs)
|
||||
|
||||
kept_criteria = StoppingCriteriaList()
|
||||
for criterion in criteria:
|
||||
if not isinstance(criterion, MaxLengthCriteria):
|
||||
logger.warning(
|
||||
f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored."
|
||||
)
|
||||
else:
|
||||
kept_criteria.append(criterion)
|
||||
return kept_criteria
|
||||
|
||||
def _prepare_generation_config(
|
||||
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
|
||||
) -> Tuple[GenerationConfig, Dict]:
|
||||
"""
|
||||
This method overrides [`~generation.utils.GenerationMixin._prepare_generation_config`].
It ensures that the depth decoder generation config is initialized and that arguments passed with the depth_decoder_* prefix are properly handled.
|
||||
"""
|
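# Usage sketch (illustrative, not from the original code): calling
# model.generate(**inputs, depth_decoder_temperature=0.6, depth_decoder_do_sample=True)
# updates self.depth_decoder.generation_config with temperature=0.6 and do_sample=True,
# while the backbone generation config only receives the non-prefixed kwargs.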
||||
# extract depth decoder kwargs and remove them from the main kwargs
|
||||
depth_decoder_kwargs = {
|
||||
k[len("depth_decoder_") :]: v for k, v in kwargs.items() if k.startswith("depth_decoder_")
|
||||
}
|
||||
|
||||
# remove the depth decoder keys from the original kwargs
|
||||
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}
|
||||
|
||||
# initialize the generation config
|
||||
generation_config, model_kwargs = super()._prepare_generation_config(
|
||||
generation_config, use_model_defaults, **kwargs
|
||||
)
|
||||
self.depth_decoder.generation_config.update(**depth_decoder_kwargs)
|
||||
|
||||
# ensure the depth decoder generation config is valid
|
||||
depth_decoder_min_new_tokens = getattr(self.depth_decoder.generation_config, "min_new_tokens") or (
|
||||
self.config.num_codebooks - 1
|
||||
)
|
||||
depth_decoder_max_new_tokens = getattr(self.depth_decoder.generation_config, "max_new_tokens") or (
|
||||
self.config.num_codebooks - 1
|
||||
)
|
||||
|
||||
if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {self.config.num_codebooks - 1}:
|
||||
raise ValueError(
|
||||
f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})"
|
||||
)
|
||||
elif self.depth_decoder.generation_config.return_dict_in_generate:
|
||||
logger.warning(
|
||||
"depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate"
|
||||
)
|
||||
self.depth_decoder.generation_config.return_dict_in_generate = False
|
||||
|
||||
self.depth_decoder.generation_config.min_new_tokens = depth_decoder_min_new_tokens
|
||||
self.depth_decoder.generation_config.max_new_tokens = depth_decoder_max_new_tokens
|
||||
|
||||
# Monkey patch the get_generation_mode method to support CSM model
|
||||
original_get_generation_mode = generation_config.get_generation_mode
|
||||
|
||||
def patched_get_generation_mode(assistant_model=None):
|
||||
generation_mode = original_get_generation_mode(assistant_model)
|
||||
if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
|
||||
raise ValueError(
|
||||
f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation."
|
||||
)
|
||||
|
||||
return generation_mode
|
||||
|
||||
generation_config.get_generation_mode = patched_get_generation_mode
|
||||
|
||||
return generation_config, model_kwargs
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
logits_processor: LogitsProcessorList,
|
||||
stopping_criteria: StoppingCriteriaList,
|
||||
generation_config: GenerationConfig,
|
||||
synced_gpus: bool,
|
||||
streamer: Optional["BaseStreamer"],
|
||||
**model_kwargs,
|
||||
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
||||
"""
|
||||
This method overrides [~generation.utils.GenerationMixin._sample].
|
||||
To ease maintenance, modifications are marked with the comment "Csm specific".
|
||||
|
||||
Indeed, Csm model requires a custom generation sampling step:
|
||||
1. Infer the backbone model to sample the first codebook token
|
||||
2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens
|
||||
3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model
|
||||
4. Repeat until stopping criteria is met
|
||||
|
||||
Csm supports two stopping criterias:
|
||||
- stop when the generated sequence is at max_length
|
||||
- stop when all the generated codebook tokens are the codebook_eos_token_id
|
||||
"""
|
||||
# init values
|
||||
# *************** Csm specific ***************
|
||||
pad_token_id = self.config.codebook_pad_token_id
|
||||
has_eos_stopping_criteria = generation_config._eos_token_tensor is not None
|
||||
# ============================================
|
||||
output_attentions = generation_config.output_attentions
|
||||
output_hidden_states = generation_config.output_hidden_states
|
||||
output_scores = generation_config.output_scores
|
||||
output_logits = generation_config.output_logits
|
||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||
do_sample = generation_config.do_sample
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
batch_size, cur_len = input_ids.shape[:2]
|
||||
this_peer_finished = False
|
||||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
||||
|
||||
# *************** Csm specific ***************
|
||||
if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None:
|
||||
# in the case where the passed input_ids correspond to text tokens, i.e. don't have a third dimension for codebook ids,
# we need to subtract the input length from the MaxLengthCriteria stopping criterion, since such inputs are not returned
|
||||
for criterion in stopping_criteria:
|
||||
if isinstance(criterion, MaxLengthCriteria):
|
||||
criterion.max_length -= cur_len
|
||||
# ============================================
|
||||
|
||||
model_forward = self.__call__
|
||||
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
|
||||
if compile_forward:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
||||
model_forward = self.get_compiled_call(generation_config.compile_config)
|
||||
|
||||
is_prefill = True
|
||||
while self._has_unfinished_sequences(
|
||||
this_peer_finished,
|
||||
synced_gpus,
|
||||
device=input_ids.device,
|
||||
):
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
# prepare variable output controls (note: some models won't accept all output controls)
|
||||
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
|
||||
# *************** Csm specific ***************
|
||||
model_inputs.update({"output_hidden_states": True})
|
||||
# ============================================
|
||||
|
||||
if is_prefill:
|
||||
outputs = self(**model_inputs, return_dict=True)
|
||||
is_prefill = False
|
||||
else:
|
||||
outputs = model_forward(**model_inputs, return_dict=True)
|
||||
|
||||
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs,
|
||||
model_kwargs,
|
||||
)
|
||||
if synced_gpus and this_peer_finished:
|
||||
continue
|
||||
|
||||
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
||||
# (the clone itself is always small)
|
||||
next_token_logits = outputs.logits[:, -1, :].clone().float()
|
||||
next_token_logits = next_token_logits.to(input_ids.device)
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
if output_logits:
|
||||
raw_logits += (next_token_logits,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (outputs.attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (outputs.hidden_states,)
|
||||
|
||||
# token selection
|
||||
if do_sample:
|
||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
else:
|
||||
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
||||
|
||||
# *************** Csm specific ***************
|
||||
# infer the depth decoder
|
||||
first_codebook_ids = next_tokens[:, None]
|
||||
# add a placeholder at position 0 that will be replaced by the backbone_last_hidden_state
|
||||
depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0)
|
||||
backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :]
|
||||
|
||||
depth_decoder_outputs = self.depth_decoder.generate(
|
||||
input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone()
|
||||
)
|
||||
codebook_ids = (
|
||||
depth_decoder_outputs
|
||||
if isinstance(depth_decoder_outputs, torch.Tensor)
|
||||
else depth_decoder_outputs.sequences
|
||||
)
|
||||
# remove the placeholder at position 0
|
||||
codebook_ids = codebook_ids[:, 1:]
|
||||
next_tokens = codebook_ids
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if has_eos_stopping_criteria:
|
||||
next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
|
||||
1 - unfinished_sequences.unsqueeze(-1)
|
||||
)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
if input_ids.ndim == 2:
|
||||
input_ids = next_tokens[:, None, :]
|
||||
else:
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
|
||||
# ============================================
|
||||
|
||||
if streamer is not None:
|
||||
streamer.put(next_tokens.cpu())
|
||||
|
||||
# *************** Csm specific ***************
|
||||
# for the eos stopping criterion, it is expected that the eos token is the same for each codebook
|
||||
unfinished_sequences = unfinished_sequences & ~(
|
||||
input_ids[:, -1, :-1] == self.config.codebook_eos_token_id
|
||||
).all(-1)
|
||||
# ============================================
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
cur_len += 1
|
||||
|
||||
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
||||
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
||||
del outputs
|
||||
|
||||
# *************** Csm specific ***************
|
||||
del depth_decoder_outputs
|
||||
# ============================================
|
||||
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
|
||||
if return_dict_in_generate:
|
||||
return GenerateDecoderOnlyOutput(
|
||||
sequences=input_ids,
|
||||
scores=scores,
|
||||
logits=raw_logits,
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
past_key_values=model_kwargs.get("past_key_values"),
|
||||
)
|
||||
else:
|
||||
return input_ids
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
input_values: Optional[torch.Tensor] = None,
|
||||
input_values_cutoffs: Optional[torch.Tensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
output_audio: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
||||
r"""
|
||||
This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
|
||||
Indeed, the Csm model requires a custom generation sampling step:
|
||||
1. Infer the backbone model to sample the first codebook token
|
||||
2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
|
||||
3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
|
||||
4. Repeat until a stopping criterion is met
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
|
||||
model's default generation configuration. You can override any `generation_config` by passing the corresponding
|
||||
parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*):
|
||||
The sequence used as a prompt for the backbone model.
|
||||
input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*):
|
||||
The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
|
||||
These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
|
||||
input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*):
|
||||
Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
|
||||
If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
|
||||
where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
|
||||
the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
|
||||
generation_config ([`~generation.GenerationConfig`], *optional*):
|
||||
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
||||
passed to generate matching the attributes of `generation_config` will override them. If
|
||||
`generation_config` is not provided, the default will be used, which has the following loading
|
||||
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
||||
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
||||
default values, whose documentation should be checked to parameterize generation.
|
||||
logits_processor (`LogitsProcessorList`, *optional*):
|
||||
Custom logits processors that complement the default logits processors built from arguments and
|
||||
generation config. If a logit processor is passed that is already created with the arguments or a
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
||||
Custom stopping criteria that complements the default stopping criteria built from arguments and a
|
||||
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
||||
generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
|
||||
sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
|
||||
intended for advanced users.
|
||||
synced_gpus (`bool`, *optional*):
|
||||
Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
|
||||
to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
|
||||
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
output_audio (`bool`, *optional*):
|
||||
Whether to return the generated audio.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*.
|
||||
|
||||
Return:
|
||||
[`CsmGenerateOutput`] or `torch.LongTensor` or `List[torch.FloatTensor]`: A [`CsmGenerateOutput`]
|
||||
(if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` when `output_audio=False`
|
||||
or a `List[torch.FloatTensor]` otherwise.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
>>> from transformers import CsmProcessor, CsmForConditionalGeneration
>>> from datasets import load_dataset, Audio

>>> model_id = "eustlb/csm-1b"
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

>>> processor = CsmProcessor.from_pretrained(model_id)
|
||||
|
||||
>>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
||||
>>> # ensure the audio is 24kHz
|
||||
>>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
|
||||
|
||||
>>> conversation = []
|
||||
>>> # prepare a conversation with text and corresponding audio
|
||||
>>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
|
||||
... conversation.append(
|
||||
... {
|
||||
... "role": f"{speaker_id}",
|
||||
... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
|
||||
... }
|
||||
... )
|
||||
|
||||
>>> # text prompt
|
||||
>>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
|
||||
|
||||
>>> inputs = processor.apply_chat_template(
|
||||
... conversation,
|
||||
... tokenize=True,
|
||||
... return_dict=True,
|
||||
... ).to(torch_device)
|
||||
|
||||
>>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
||||
>>> audio = model.generate(**inputs, output_audio=True)
|
||||
>>> processor.save_audio(audio, "output.wav")
|
||||
```
|
||||
"""
|
||||
generate_output = super().generate(
|
||||
input_ids=input_ids,
|
||||
input_values=input_values,
|
||||
input_values_cutoffs=input_values_cutoffs,
|
||||
generation_config=generation_config,
|
||||
logits_processor=logits_processor,
|
||||
stopping_criteria=stopping_criteria,
|
||||
synced_gpus=synced_gpus,
|
||||
streamer=streamer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
generate_returned_dict = not isinstance(generate_output, torch.Tensor)
|
||||
audio = None
|
||||
if output_audio:
|
||||
generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output
|
||||
|
||||
# infer the codec model
|
||||
audio = []
|
||||
with torch.no_grad():
|
||||
# =======================================
|
||||
# TODO: @eustlb, this should be batched !!!
|
||||
# but requires making sure batched inference of the codec model works as intended
|
||||
for audio_codes_batch in generated_audio_codes:
|
||||
eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
|
||||
if eos_idxs.numel() != 0:
|
||||
cutoff_idx = eos_idxs.min()
|
||||
else:
|
||||
cutoff_idx = audio_codes_batch.shape[1]
|
||||
|
||||
audio_codes_batch = audio_codes_batch[:cutoff_idx]
|
||||
codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
|
||||
audio.append(codec_decode_output.audio_values[0, 0])
|
||||
# =======================================
|
||||
|
||||
if generate_returned_dict:
|
||||
return CsmGenerateOutput(audio=audio, **generate_output)
|
||||
elif output_audio:
|
||||
return audio
|
||||
else:
|
||||
return generate_output
|
src/transformers/models/csm/modeling_csm.py (new file, 1710 lines): file diff suppressed because it is too large
src/transformers/models/csm/modular_csm.py (new file, 1042 lines): file diff suppressed because it is too large
src/transformers/models/csm/processing_csm.py (new file, 364 lines)
@ -0,0 +1,364 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils import is_soundfile_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_soundfile_available():
|
||||
import soundfile as sf
|
||||
|
||||
from ...audio_utils import AudioInput, make_list_of_audio
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import (
|
||||
PreTokenizedInput,
|
||||
TextInput,
|
||||
)
|
||||
|
||||
|
||||
class CsmAudioKwargs(AudioKwargs, total=False):
|
||||
encoded_length_kwargs: Optional[Dict[str, Any]]
|
||||
|
||||
|
||||
class CsmProcessorKwargs(ProcessingKwargs, total=False):
|
||||
audio_kwargs: CsmAudioKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": True,
|
||||
"padding_side": "left",
|
||||
"add_special_tokens": False,
|
||||
},
|
||||
"audio_kwargs": {
|
||||
"encoded_length_kwargs": {
|
||||
"kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
|
||||
"strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
|
||||
"dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
"use_causal_conv": True,
|
||||
},
|
||||
"sampling_rate": 24000,
|
||||
},
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
}
|
||||
|
||||
|
||||
class CsmProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Csm processor which wraps [`EncodecFeatureExtractor`] and
|
||||
[`PretrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
|
||||
tokenizer functionalities. See the [`~CsmProcessor.__call__`] for more
|
||||
information.
|
||||
The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
|
||||
```python
|
||||
from transformers import CsmProcessor
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
||||
audio = ds[0]["audio"]["array"]
|
||||
|
||||
processor = CsmProcessor.from_pretrained("eustlb/csm-1b")
|
||||
|
||||
processor(
|
||||
text=["<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"],
|
||||
audio=audio,
|
||||
text_kwargs = {"padding": False},
|
||||
audio_kwargs = {"sampling_rate": 16000},
|
||||
common_kwargs = {"return_tensors": "pt"},
|
||||
)
|
||||
# this should error out because EncodecFeatureExtractor expects a 24kHz audio :)
|
||||
```
|
||||
|
||||
Args:
|
||||
feature_extractor ([`EncodecFeatureExtractor`]):
|
||||
The feature extractor is a required input.
|
||||
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
|
||||
"""
|
||||
|
||||
attributes = ["feature_extractor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
feature_extractor_class = "EncodecFeatureExtractor"
|
||||
tokenizer_class = "PreTrainedTokenizerFast"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_extractor,
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
):
|
||||
if not hasattr(tokenizer, "audio_token"):
|
||||
self.audio_token = "<|AUDIO|>"
|
||||
self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
|
||||
else:
|
||||
self.audio_token = tokenizer.audio_token
|
||||
self.audio_token_id = tokenizer.audio_token_id
|
||||
|
||||
if not hasattr(tokenizer, "audio_eos_token"):
|
||||
self.audio_eos_token = "<|audio_eos|>"
|
||||
self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)
|
||||
else:
|
||||
self.audio_eos_token = tokenizer.audio_eos_token
|
||||
self.audio_eos_token_id = tokenizer.audio_eos_token_id
|
||||
|
||||
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
|
||||
|
||||
@staticmethod
|
||||
def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
|
||||
"""
|
||||
Compute the length of the encoded audio sequence.
|
||||
|
||||
Args:
audio_length (int): The length of the audio sequence.
kernel_sizes (List[int]): The kernel sizes for the convolutional layers.
strides (List[int]): The strides for the convolutional layers.
dilations (List[int]): The dilations for the convolutional layers.
use_causal_conv (bool): Whether to use causal convolutions.
|
||||
"""
|
||||
cur_length = audio_length
|
||||
|
||||
if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
|
||||
return cur_length
|
||||
|
||||
for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
|
||||
effective_kernel_size = (kernel_size - 1) * dilation + 1
|
||||
padding_total = kernel_size - stride
|
||||
padding_right = padding_total // 2
|
||||
padding_left = padding_total - padding_right
|
||||
|
||||
n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
|
||||
n_frames = math.ceil(n_frames) - 1
|
||||
ideal_length = n_frames * stride + kernel_size - padding_total
|
||||
extra_padding = ideal_length - cur_length
|
||||
|
||||
if use_causal_conv:
|
||||
padding_left = padding_total
|
||||
padding_right = extra_padding
|
||||
else:
|
||||
padding_left = padding_left
|
||||
padding_right = padding_right + extra_padding
|
||||
|
||||
cur_length = cur_length + padding_left + padding_right
|
||||
cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1
|
||||
|
||||
return cur_length
|
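# Rough order of magnitude (derived from the default strides above, stated as an assumption): the strides
# multiply to 1920, so a 24 kHz waveform maps to roughly audio_length / 1920 frames, i.e. about 12.5 audio
# tokens per second of audio.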
||||
|
||||
def save_audio(
|
||||
self,
|
||||
audio: AudioInput,
|
||||
saving_path: Union[str, Path, List[Union[str, Path]]],
|
||||
**kwargs: Unpack[CsmProcessorKwargs],
|
||||
):
|
||||
# TODO: @eustlb, this should be in AudioProcessor
|
||||
if not is_soundfile_available():
|
||||
raise ImportError("Please install `soundfile` to save audio files.")
|
||||
|
||||
# ensure correct audio input
|
||||
audio = make_list_of_audio(audio)
|
||||
|
||||
# ensure correct saving path
|
||||
if isinstance(saving_path, (str, Path)):
|
||||
saving_path = [saving_path]
|
||||
elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
|
||||
raise ValueError("Invalid input path. Please provide a string, or a list of strings")
|
||||
|
||||
if len(audio) != len(saving_path):
|
||||
raise ValueError("The number of audio and saving paths must be the same")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
CsmProcessorKwargs,
|
||||
**kwargs,
|
||||
)
|
||||
audio_kwargs = output_kwargs["audio_kwargs"]
|
||||
sampling_rate = audio_kwargs["sampling_rate"]
|
||||
|
||||
for audio_value, p in zip(audio, saving_path):
|
||||
if isinstance(audio_value, torch.Tensor):
|
||||
audio_value = audio_value.cpu().float().numpy()
|
||||
sf.write(p, audio_value, sampling_rate)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]],
|
||||
audio: Optional[AudioInput] = None,
|
||||
output_labels: Optional[bool] = False,
|
||||
depth_decoder_labels_ratio: Optional[float] = 1.0,
|
||||
**kwargs: Unpack[CsmProcessorKwargs],
|
||||
):
|
||||
r"""
|
||||
Main method to prepare text(s) and audio to be fed as input to the model. This method forwards the `text`
|
||||
arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode
|
||||
the text. To prepare the audio, this method forwards the `audio` arguments to
|
||||
EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`]. Please refer
|
||||
to the docstring of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
|
||||
tensor.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
output_labels (bool, *optional*, default=False):
|
||||
Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`.
|
||||
- `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
|
||||
- `-100` will be ignored in the loss computation
|
||||
- `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
|
||||
depth_decoder_labels_ratio (float, *optional*, default=1.0):
|
||||
The ratio of audio frames to keep for the depth decoder labels.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
|
||||
"""
|
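# Labeling sketch (illustrative, based on the docstring above): with depth_decoder_labels_ratio=0.5,
# all text positions get -100, audio frame positions keep audio_token_id, and a random half of the
# audio frames are set to -101 so that only the backbone (first codebook) is trained on them.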
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
CsmProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
text_kwargs = output_kwargs["text_kwargs"]
|
||||
audio_kwargs = output_kwargs["audio_kwargs"]
|
||||
common_kwargs = output_kwargs["common_kwargs"]
|
||||
|
||||
return_tensors = common_kwargs.pop("return_tensors", None)
|
||||
if return_tensors != "pt":
|
||||
raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
n_audio_in_text = [t.count(self.audio_token) for t in text]
|
||||
|
||||
n_audio = 0
|
||||
if audio is not None:
|
||||
audio = make_list_of_audio(audio)
|
||||
n_audio = len(audio)
|
||||
|
||||
if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text):
|
||||
if audio is None:
|
||||
raise ValueError("No audio were provided, but there are audio tokens in the prompt")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the "
|
||||
f"number of provided audios ({n_audio})."
|
||||
)
|
||||
|
||||
if audio is not None:
|
||||
encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {})
|
||||
num_audio_tokens_list = [
|
||||
self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
|
||||
]
|
||||
num_audio_tokens_list_copy = num_audio_tokens_list.copy()
|
||||
|
||||
# expand the text to repeat the audio token for the corresponding number of frames
|
||||
expanded_text = []
|
||||
for sample in text:
|
||||
replace_str = []
|
||||
while self.audio_token in sample:
|
||||
num_audio_tokens = num_audio_tokens_list_copy.pop(0)
|
||||
expanded_audio_token = self.audio_token * num_audio_tokens
|
||||
|
||||
replace_str.append(expanded_audio_token)
|
||||
sample = sample.replace(self.audio_token, "<placeholder>", 1)
|
||||
|
||||
while "<placeholder>" in sample:
|
||||
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
|
||||
expanded_text.append(sample)
|
||||
|
||||
text = expanded_text
|
||||
|
||||
encoding = self.tokenizer(text, **text_kwargs)
|
||||
data = {}
|
||||
data.update(encoding)
|
||||
|
||||
if audio is not None:
|
||||
audio_kwargs.pop("return_attention_mask", None) # not supported by the feature extractor
|
||||
|
||||
concatenated_audio, input_values_cutoffs = [], []
|
||||
offset = 0
|
||||
for n_audio in n_audio_in_text:
|
||||
if n_audio == 0:
|
||||
concatenated_audio.append(np.zeros(0))
|
||||
input_values_cutoffs.append(torch.tensor([-1]))
|
||||
else:
|
||||
concatenated_audio.append(
|
||||
np.concatenate(
|
||||
[
|
||||
el.cpu().numpy() if isinstance(el, torch.Tensor) else el
|
||||
for el in audio[offset : offset + n_audio]
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
)
|
||||
input_values_cutoffs.append(
|
||||
torch.tensor([el.shape[-1] for el in audio[offset : offset + n_audio]]).cumsum(dim=-1)
|
||||
)
|
||||
offset += n_audio
|
||||
|
||||
audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
|
||||
audio_inputs.pop("padding_mask", None) # not applicable here
|
||||
data.update(audio_inputs)
|
||||
|
||||
# pad and stack the audio cut idxs
|
||||
max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
|
||||
input_values_cutoffs = [
|
||||
torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
|
||||
for cut_idxs in input_values_cutoffs
|
||||
]
|
||||
data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)
|
||||
|
||||
if output_labels:
|
||||
audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
|
||||
n_audio_frames = audio_frame_idxs.shape[0]
|
||||
|
||||
if depth_decoder_labels_ratio <= 1.0:
|
||||
rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
|
||||
skip_frames_idxs = audio_frame_idxs[rand_idxs]
|
||||
else:
|
||||
skip_frames_idxs = audio_frame_idxs
|
||||
|
||||
labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
|
||||
labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
|
||||
|
||||
data["labels"] = labels
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["CsmProcessor"]
|
@ -1111,7 +1111,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def _get_initial_cache_position(self, input_ids, model_kwargs):
|
||||
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
|
||||
"""
|
||||
Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.
|
||||
Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`.
|
||||
@ -1125,8 +1125,8 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
else:
|
||||
cur_len = input_ids.shape[-1]
|
||||
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
|
||||
cur_len = seq_length
|
||||
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=device)
|
||||
return model_kwargs
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
|
||||
|
@ -1563,7 +1563,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_tokens)
|
||||
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
|
||||
|
||||
if model_kwargs.get("past_key_values", None) is None:
|
||||
# Prepare cache if not provided.
|
||||
|
@ -1378,7 +1378,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_tokens)
|
||||
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
|
||||
|
||||
if model_kwargs.get("past_key_values", None) is None:
|
||||
# Prepare cache if not provided.
|
||||
|
@ -216,6 +216,32 @@ class MimiConv1d(nn.Module):
|
||||
end = padded.shape[-1] - extra_pad
|
||||
return padded[..., :end]
|
||||
|
||||
def _get_output_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
|
||||
"""
|
||||
Return the length of the output of the MimiConv1d.
|
||||
"""
|
||||
# padding size
|
||||
n_frames = (input_length - self.kernel_size + self.padding_total) / self.stride + 1
|
||||
n_frames = torch.ceil(n_frames).to(torch.int64) - 1
|
||||
ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
|
||||
extra_padding = ideal_length - input_length
|
||||
|
||||
if self.causal:
|
||||
padding_left = self.padding_total
|
||||
padding_right = extra_padding
|
||||
else:
|
||||
padding_left = self.padding_left
|
||||
padding_right = self.padding_right + extra_padding
|
||||
|
||||
# padding
|
||||
input_length = input_length + padding_left + padding_right
|
||||
|
||||
# conv
|
||||
output_length = (
input_length + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1
) // self.conv.stride[0] + 1
return output_length
|
||||
|
||||
def forward(self, hidden_states):
|
||||
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
|
||||
|
||||
@ -331,21 +357,28 @@ class MimiEncoder(nn.Module):
|
||||
model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
|
||||
scaling = 1
|
||||
|
||||
# keep track of MimiConv1d submodule layer names for easy encoded length computation
|
||||
mimiconv1d_layer_names = ["layers.0"]
|
||||
|
||||
# Downsample to raw audio scale
|
||||
for ratio in reversed(config.upsampling_ratios):
|
||||
current_scale = scaling * config.num_filters
|
||||
# Add residual layers
|
||||
for j in range(config.num_residual_layers):
|
||||
mimiconv1d_layer_names.extend([f"layers.{len(model)}.block.1", f"layers.{len(model)}.block.3"])
|
||||
model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
|
||||
# Add downsampling layers
|
||||
model += [nn.ELU()]
|
||||
mimiconv1d_layer_names.append(f"layers.{len(model)}")
|
||||
model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
|
||||
scaling *= 2
|
||||
|
||||
model += [nn.ELU()]
|
||||
mimiconv1d_layer_names.append(f"layers.{len(model)}")
|
||||
model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]
|
||||
|
||||
self.layers = nn.ModuleList(model)
|
||||
self._mimiconv1d_layer_names = mimiconv1d_layer_names
|
||||
|
||||
# Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward
|
||||
def forward(self, hidden_states):
|
||||
@@ -1567,6 +1600,38 @@ class MimiModel(MimiPreTrainedModel):
            codes = codes.transpose(0, 1)
        return codes, past_key_values

    def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
        """
        Return the number of frames of the encoded audio waveform.
        """
        output_length = input_length

        # encoder
        for layer_name in self.encoder._mimiconv1d_layer_names:
            output_length = self.encoder.get_submodule(layer_name)._get_output_length(output_length)

        # downsample
        output_length = self.downsample._get_output_length(output_length)

        return output_length

    def get_audio_codes_mask(self, padding_mask: torch.Tensor, padding_side: str = "right"):
        """
        Get the mask for the audio codes from the original padding mask.
        """
        encoded_lengths = self.get_encoded_length(padding_mask.sum(dim=-1))

        audio_codes_mask = torch.arange(encoded_lengths.max(), device=encoded_lengths.device).expand(
            len(encoded_lengths), -1
        )
        audio_codes_mask = audio_codes_mask < encoded_lengths.unsqueeze(1)
        audio_codes_mask = audio_codes_mask.to(padding_mask.device)

        if padding_side == "right":
            return audio_codes_mask
        else:
            return audio_codes_mask.flip(dims=[-1])
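The mask construction above relies on the usual arange-broadcasting trick to turn per-example lengths into a boolean mask. A standalone illustration with hypothetical lengths, not tied to a real model instance:

import torch

encoded_lengths = torch.tensor([3, 5, 2])
audio_codes_mask = torch.arange(encoded_lengths.max()).expand(len(encoded_lengths), -1) < encoded_lengths.unsqueeze(1)
print(audio_codes_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False]])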

    def encode(
        self,
        input_values: torch.Tensor,
@@ -3084,10 +3084,10 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCon
            thinker_reply_part=thinker_reply_part,
        )

-    def _get_initial_cache_position(self, input_ids, model_kwargs):
+    def _get_initial_cache_position(self, seq_length, device, model_kwargs):
        # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
        inputs_embeds = model_kwargs.pop("inputs_embeds")
-        model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs)
        model_kwargs["inputs_embeds"] = inputs_embeds
        return model_kwargs
@@ -2771,10 +2771,10 @@ class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCon
            thinker_reply_part=thinker_reply_part,
        )

-    def _get_initial_cache_position(self, input_ids, model_kwargs):
+    def _get_initial_cache_position(self, seq_length, device, model_kwargs):
        # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily
        inputs_embeds = model_kwargs.pop("inputs_embeds")
-        model_kwargs = super()._get_initial_cache_position(input_ids, model_kwargs)
+        model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs)
        model_kwargs["inputs_embeds"] = inputs_embeds
        return model_kwargs
@@ -1058,7 +1058,7 @@ class ProcessorMixin(PushToHubMixin):
            # update defaults with arguments from tokenizer init
            for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__.keys():
                # init with tokenizer init kwargs if necessary
-                if modality_key in tokenizer_init_kwargs:
+                if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs:
                    value = (
                        getattr(self.tokenizer, modality_key)
                        if hasattr(self.tokenizer, modality_key)
@@ -501,9 +501,9 @@ class GenerationTesterMixin:
        output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)

        if model.config.is_encoder_decoder:
-            self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
+            self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
        else:
-            self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
+            self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])

    @pytest.mark.generate
    def test_greedy_generate_dict_outputs(self):
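The recurring shape[-1] to shape[1] change in these test hunks is what lets the common generation tests handle sequences with more than two dimensions. A small sketch of why the distinction matters, with hypothetical sizes:

import torch

# For a standard LM, generated ids are (batch, seq), so dims 1 and -1 coincide.
# For CSM they are (batch, seq, num_codebooks); only dim 1 is the sequence length.
csm_like_ids = torch.zeros(2, 12, 32, dtype=torch.long)  # hypothetical: batch=2, seq=12, 32 codebooks
print(csm_like_ids.shape[1])   # 12 -> generated sequence length
print(csm_like_ids.shape[-1])  # 32 -> codebook dimension, not the length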
@@ -525,13 +525,13 @@ class GenerationTesterMixin:
        )

        if model.config.is_encoder_decoder:
-            self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
+            self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
            # Retrocompatibility check
            self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
        else:
            self.assertTrue(
-                output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+                output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
            )
            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
            # Retrocompatibility check
@ -565,10 +565,10 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
|
||||
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||
@ -582,9 +582,9 @@ class GenerationTesterMixin:
|
||||
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_sample_generate_dict_output(self):
|
||||
@ -607,13 +607,13 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@ -632,9 +632,9 @@ class GenerationTesterMixin:
|
||||
output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
@ -657,13 +657,13 @@ class GenerationTesterMixin:
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@ -706,10 +706,10 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
|
||||
self._check_generate_outputs(
|
||||
@ -759,9 +759,9 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
@ -786,13 +786,13 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@ -840,9 +840,9 @@ class GenerationTesterMixin:
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
# check `group_beam_search` for higher than 1 `num_return_sequences`
|
||||
num_return_sequences = 2
|
||||
@ -853,9 +853,9 @@ class GenerationTesterMixin:
|
||||
beam_kwargs=beam_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
@ -878,13 +878,13 @@ class GenerationTesterMixin:
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@ -923,9 +923,9 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
@ -947,9 +947,9 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
for generation_output in output_generate:
|
||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||
@ -987,13 +987,13 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
@ -1031,9 +1031,9 @@ class GenerationTesterMixin:
|
||||
use_cache=True, # Enable cache
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||
self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
@ -1067,10 +1067,10 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
|
||||
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||
@@ -1499,7 +1499,7 @@ class GenerationTesterMixin:
                position_ids.masked_fill_(attention_mask == 0, 1)
                model_kwargs["position_ids"] = position_ids
            if "cache_position" in signature:
-                cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
+                cache_position = torch.arange(input_ids.shape[1], device=torch_device)
                model_kwargs["cache_position"] = cache_position
            return model_kwargs

@@ -1525,10 +1525,12 @@ class GenerationTesterMixin:
        pad_token_id = (
            config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
        )
-        pad_size = (input_ids.shape[0], 32)
+        pad_size = (input_ids.shape[0], 32, *input_ids.shape[2:])
        padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
        padded_input_ids = torch.cat((padding, input_ids), dim=1)
-        padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+        padded_attention_mask = torch.cat(
+            (torch.zeros(pad_size[:2], dtype=input_ids.dtype, device=torch_device), attention_mask), dim=1
+        )
        model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
        next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
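The second hunk keeps any extra trailing dimensions (for CSM, the codebook dimension) when building the left-padding block, while the attention mask stays 2D. A standalone sketch with hypothetical shapes:

import torch

# Hypothetical shapes: batch=2, seq=5, num_codebooks=3; left-pad 4 timesteps.
input_ids = torch.randint(0, 10, (2, 5, 3))
attention_mask = torch.ones(2, 5, dtype=torch.long)
pad_size = (input_ids.shape[0], 4, *input_ids.shape[2:])  # (2, 4, 3); empty extra dims for 2D inputs
padding = torch.zeros(pad_size, dtype=input_ids.dtype)
padded_input_ids = torch.cat((padding, input_ids), dim=1)  # (2, 9, 3)
padded_attention_mask = torch.cat((torch.zeros(pad_size[:2], dtype=torch.long), attention_mask), dim=1)  # (2, 9)
print(padded_input_ids.shape, padded_attention_mask.shape)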
@@ -1587,7 +1589,7 @@ class GenerationTesterMixin:
                else text_config.num_attention_heads
            )
            encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
-            batch_size, seq_length = inputs["decoder_input_ids"].shape
+            batch_size, seq_length = inputs["decoder_input_ids"].shape[:2]
            # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
            # autoregressive generation, we're keeping the test general and not checking the 3rd dim
            default_cross_attention_shape = (

@@ -1606,7 +1608,7 @@ class GenerationTesterMixin:
                for _ in range(num_decoder_layers)
            ]
        else:
-            batch_size, seq_length = inputs["input_ids"].shape
+            batch_size, seq_length = inputs["input_ids"].shape[:2]
            default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
            all_cache_shapes = [
                [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
@@ -1727,7 +1729,7 @@ class GenerationTesterMixin:
            "min_new_tokens": 5,  # generate exactly 5 tokens
        }
        outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
-        self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
+        self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))

        # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
        # The output of the two calls should be the same.
@ -2262,11 +2264,11 @@ class GenerationTesterMixin:
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
@ -2408,7 +2410,7 @@ class GenerationTesterMixin:
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
|
||||
generated_length = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
|
||||
output.sequences.shape[1] - 1 if config.is_encoder_decoder else output.sequences.shape[1] - prompt_length
|
||||
)
|
||||
decoder_past_key_values = getattr(output, "past_key_values", None)
|
||||
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache):
|
||||
@ -2441,7 +2443,7 @@ class GenerationTesterMixin:
|
||||
batch_size=internal_batch_size,
|
||||
attentions=output.decoder_attentions,
|
||||
prompt_length=1, # the BOS token
|
||||
output_length=output.sequences.shape[-1],
|
||||
output_length=output.sequences.shape[1],
|
||||
config=config,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
)
|
||||
@ -2450,7 +2452,7 @@ class GenerationTesterMixin:
|
||||
batch_size=internal_batch_size,
|
||||
attentions=output.attentions,
|
||||
prompt_length=prompt_length,
|
||||
output_length=output.sequences.shape[-1],
|
||||
output_length=output.sequences.shape[1],
|
||||
config=config,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
)
|
||||
@ -2469,7 +2471,7 @@ class GenerationTesterMixin:
|
||||
batch_size=internal_batch_size,
|
||||
hidden_states=output.decoder_hidden_states,
|
||||
prompt_length=1, # the BOS token
|
||||
output_length=output.sequences.shape[-1],
|
||||
output_length=output.sequences.shape[1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
@ -2478,7 +2480,7 @@ class GenerationTesterMixin:
|
||||
batch_size=internal_batch_size,
|
||||
hidden_states=output.hidden_states,
|
||||
prompt_length=prompt_length,
|
||||
output_length=output.sequences.shape[-1],
|
||||
output_length=output.sequences.shape[1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
@ -2506,7 +2508,7 @@ class GenerationTesterMixin:
|
||||
)
|
||||
if has_standard_cache:
|
||||
if use_cache:
|
||||
cache_length = output.sequences.shape[-1] - 1
|
||||
cache_length = output.sequences.shape[1] - 1
|
||||
self._check_past_key_values_for_generate(
|
||||
batch_size=internal_batch_size,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
|
tests/models/csm/__init__.py (new file, 0 lines)
tests/models/csm/test_modeling_csm.py (new file, 693 lines)
@@ -0,0 +1,693 @@
# coding=utf-8
# Copyright 2024, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ConversationalSpeechModel model."""

import collections
import copy
import re
import unittest

import pytest
from parameterized import parameterized

from transformers import (
    AutoProcessor,
    CsmConfig,
    CsmForConditionalGeneration,
    is_torch_available,
)
from transformers.testing_utils import (
    cleanup,
    require_torch_gpu,
    slow,
    torch_device,
)
from transformers.utils.import_utils import is_datasets_available

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    ids_tensor,
)


if is_datasets_available():
    from datasets import load_dataset

if is_torch_available():
    import torch

    from transformers.pytorch_utils import id_tensor_storage

class CsmModelTester:
    def __init__(
        self,
        parent,
        ignore_index=-100,
        batch_size=3,
        seq_length=7,
        is_training=True,
        depth_decoder_config={
            "num_codebooks": 10,
            "backbone_hidden_size": 64,
            "vocab_size": 6,
            "hidden_size": 64,
            "intermediate_size": 128,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "num_key_value_heads": 2,
            "hidden_act": "silu",
            "max_position_embeddings": 10,
        },
        codec_config={
            "model_type": "mimi",
            "audio_channels": 1,
            "chunk_in_sec": None,
            "hidden_size": 32,
            "num_filters": 8,
            "num_residual_layers": 1,
            "upsampling_ratios": [8, 4],
            "codebook_size": 64,
            "vector_quantization_hidden_dimension": 64,
            "upsample_groups": 32,
            "num_hidden_layers": 2,
            "num_attention_heads": 2,
            "num_key_value_heads": 2,
            "sliding_window": 4,
            "codebook_dim": 64,
            "use_cache": False,
        },
        config={
            "num_codebooks": 10,
            "vocab_size": 6,
            "text_vocab_size": 99,
            "hidden_size": 64,
            "intermediate_size": 64,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "num_key_value_heads": 2,
            "hidden_act": "silu",
            "max_position_embeddings": 10,
            "bos_token_id": 1,
            "pad_token_id": 2,
            "eos_token_id": 3,
            "codebook_pad_token_id": 2,
            "codebook_eos_token_id": 3,
        },
    ):
        self.parent = parent
        self.is_training = is_training
        self.ignore_index = ignore_index
        self.depth_decoder_config = depth_decoder_config
        self.codec_config = codec_config
        self.config = config
        self.seq_length = seq_length
        self.batch_size = batch_size

        self.num_hidden_layers = config["num_hidden_layers"]
        self.vocab_size = config["vocab_size"]
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        self.pad_token_id = config["pad_token_id"]

    def get_config(self):
        return CsmConfig(
            depth_decoder_config=self.depth_decoder_config,
            codec_config=self.codec_config,
            **self.config,
        )

    def prepare_config_and_inputs(self):
        config = self.get_config()
        input_ids = ids_tensor([self.batch_size, self.seq_length, config.num_codebooks], config.vocab_size - 1) + 1
        attention_mask = input_ids[..., -1].ne(1).to(torch_device)
        return config, input_ids, attention_mask

    def prepare_config_and_inputs_for_common(self):
        config, input_ids, attention_mask = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        return config, inputs_dict

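Note that prepare_config_and_inputs builds 3D token ids of shape (batch_size, seq_length, num_codebooks) rather than the usual 2D shape; the trailing dimension carries one token per codebook, which is exactly why the common generation tests above switch from dim -1 to dim 1. A hedged re-creation of what the helper produces with the tester's defaults:

import torch

# Hypothetical standalone version: 3D ids with values in [1, vocab_size - 1] and a 2D
# attention mask derived from the last codebook channel, mirroring the tester defaults.
batch_size, seq_length, num_codebooks, vocab_size = 3, 7, 10, 6
input_ids = torch.randint(1, vocab_size, (batch_size, seq_length, num_codebooks))
attention_mask = input_ids[..., -1].ne(1)
print(input_ids.shape)       # torch.Size([3, 7, 10])
print(attention_mask.shape)  # torch.Size([3, 7])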
class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (CsmForConditionalGeneration,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
test_resize_embeddings_untied = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = CsmModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=CsmConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
"""
|
||||
Overrides [ModelTesterMixin._prepare_for_class] to handle third input_ids dimension.
|
||||
"""
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(
|
||||
self.model_tester.batch_size,
|
||||
self.model_tester.seq_length,
|
||||
self.model_tester.config["num_codebooks"],
|
||||
),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
"""
|
||||
Overrides [GenerationTesterMixin._get_logits_processor_kwargs] to restrict to top_k, top_p, and temperature sampling.
|
||||
"""
|
||||
logits_processor_kwargs = {}
|
||||
if do_sample:
|
||||
logits_processor_kwargs.update(
|
||||
{
|
||||
"top_k": 10,
|
||||
"top_p": 0.7,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
)
|
||||
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_initialization(self):
|
||||
"""
|
||||
Overrides [ModelTesterMixin.test_initialization] because of specificities of Mimi codec model.
|
||||
See https://github.com/huggingface/transformers/blob/1077603410cd73ba71d64a522033574d66d64b55/tests/models/mimi/test_modeling_mimi.py#L384-L397
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
uniform_init_parms = ["conv", "input_proj", "output_proj"]
|
||||
if param.requires_grad:
|
||||
if any(x in name for x in uniform_init_parms):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
|
||||
"""
|
||||
Overrides [GenerationTesterMixin._check_similar_generate_outputs] to handle third input_ids dimension.
|
||||
Here we only look at the first codebook (index 0 on the last dimension of the generated sequences) since the returned scores
|
||||
are for this token.
|
||||
"""
|
||||
# scores doesn't include data regarding decoder input tokens
|
||||
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
|
||||
output_matches = output_1.sequences[..., 0] == output_2.sequences[..., 0]
|
||||
has_matching_outputs = output_matches.all()
|
||||
has_matching_scores = None
|
||||
if not has_matching_outputs:
|
||||
for batch_idx in range(output_1.sequences.shape[0]):
|
||||
batch_matches = output_matches[batch_idx]
|
||||
if batch_matches.all():
|
||||
continue
|
||||
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
|
||||
first_mismatch_idx -= decoder_input_length
|
||||
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
|
||||
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
|
||||
has_matching_scores = torch.allclose(
|
||||
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
|
||||
)
|
||||
if not has_matching_scores:
|
||||
break
|
||||
self.assertTrue(has_matching_outputs or has_matching_scores)
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support assisted decoding.")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support assisted decoding.")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support Dola decoding.")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support beam search.")
|
||||
def test_beam_sample_generate(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support beam search.")
|
||||
def test_beam_search_generate(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support beam search.")
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support beam search.")
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support beam search.")
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support group beam search.")
|
||||
def test_group_beam_search_generate(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support group beam search.")
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support constrained beam search.")
|
||||
def test_constrained_beam_search_generate(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support constrained beam search.")
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support contrastive search.")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support contrastive search.")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support contrastive search.")
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support prompt lookup decoding.")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support prompt lookup decoding.")
|
||||
def test_prompt_lookup_decoding_stops_at_eos(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="CSM has custom embedding approach (text and audio embeddings).")
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support beam search.")
|
||||
def test_generate_from_inputs_embeds_1_beam_search(self, _, num_beams):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="CSM does not support beam search.")
|
||||
def test_model_parallel_beam_search(self):
|
||||
pass
|
||||
|
||||
def test_tied_weights_keys(self):
|
||||
"""
|
||||
Overrides [ModelTesterMixin.test_tied_weights_keys] to not test for text config (not applicable to CSM).
|
||||
"""
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(config)
|
||||
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model_tied.state_dict().items():
|
||||
ptrs[id_tensor_storage(tensor)].append(name)
|
||||
|
||||
# These are all the pointers of shared tensors.
|
||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||
|
||||
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||
# Detect we get a hit for each key
|
||||
for key in tied_weight_keys:
|
||||
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
|
||||
|
||||
# Removed tied weights found from tied params -> there should only be one left after
|
||||
for key in tied_weight_keys:
|
||||
for i in range(len(tied_params)):
|
||||
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
|
||||
|
||||
tied_params = [group for group in tied_params if len(group) > 1]
|
||||
self.assertListEqual(
|
||||
tied_params,
|
||||
[],
|
||||
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
|
||||
)
|
||||
|
||||
def _get_custom_4d_mask_test_data(self):
|
||||
"""
|
||||
Overrides [ModelTesterMixin._get_custom_4d_mask_test_data] to handle third input_ids dimension.
|
||||
"""
|
||||
# Sequence in which all but the last token is the same
|
||||
input_ids = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 4], [0, 1, 2, 5]], device=torch_device, dtype=torch.int64)
|
||||
input_ids = input_ids.unsqueeze(-1).expand(-1, -1, self.model_tester.config["num_codebooks"])
|
||||
position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
|
||||
|
||||
# Combining common prefix with the unique ending tokens:
|
||||
input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)
|
||||
|
||||
# Creating a 4D mask where each of the last 3 tokens do not attend to each other.
|
||||
mask_shared_prefix = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 0, 1, 0],
|
||||
[1, 1, 1, 0, 0, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
)
|
||||
# inverting the attention mask
|
||||
mask_dtype = torch.float32
|
||||
min_dtype = torch.finfo(mask_dtype).min
|
||||
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype
|
||||
|
||||
# Creating a position_ids tensor. note the repeating figures in the end.
|
||||
position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
|
||||
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
|
||||
class CsmForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# TODO: @eustlb, update with correct sesame's repo
|
||||
self.model_checkpoint = "eustlb/csm-1b"
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def _load_conversation(self):
|
||||
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
||||
ds = ds.filter(lambda x: x["conversation_id"] == 0)
|
||||
ds = ds.sort("turn_id")
|
||||
return ds[0]
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_1b_model_integration_generate(self):
|
||||
"""
|
||||
Tests the generated tokens match the ones from the original model implementation.
|
||||
Such tokens are to be retrieved using https://gist.github.com/eustlb/d25577a357ddcf8f4a8cd0d00baca551, which is a script that runs inference with the original model.
|
||||
"""
|
||||
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||
prompt = "<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"
|
||||
|
||||
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
||||
audio = ds[0]["audio"]["array"]
|
||||
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
|
||||
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
|
||||
[1140, 10, 37, 1180, 1100, 1319, 601, 1482, 1918, 1739, 372, 856, 674, 1, 854, 459, 1843, 1191, 347, 349, 1087, 846, 759, 1690, 947, 1280, 580, 1909, 1192, 487, 1302, 1601],
|
||||
[1494, 1412, 1824, 1852, 150, 928, 91, 326, 623, 1632, 1163, 1221, 1949, 999, 1779, 248, 693, 1149, 1423, 1503, 598, 80, 223, 1798, 251, 385, 1391, 1692, 1228, 1631, 1101, 866],
|
||||
[778, 645, 830, 1812, 524, 1704, 1805, 1289, 74, 1069, 243, 1622, 1755, 1281, 1397, 620, 1962, 1995, 253, 1124, 1007, 518, 89, 559, 1304, 1482, 523, 1747, 1979, 1003, 1707, 1578],
|
||||
[1356, 481, 642, 989, 287, 1819, 171, 1115, 824, 1253, 1488, 1074, 1019, 342, 279, 513, 1275, 1364, 893, 2007, 553, 407, 882, 1170, 1586, 485, 762, 559, 100, 542, 911, 1460],
|
||||
[1860, 593, 1944, 404, 575, 545, 862, 830, 1002, 125, 2010, 268, 1779, 804, 811, 809, 255, 373, 387, 1756, 259, 822, 1191, 700, 1686, 390, 1676, 844, 2006, 286, 1376, 719],
|
||||
[1165, 1047, 848, 212, 1018, 1470, 93, 1709, 1487, 1691, 1190, 275, 1278, 2018, 121, 1023, 485, 463, 39, 1825, 1936, 1817, 569, 209, 1553, 1599, 1137, 769, 968, 558, 1957, 265],
|
||||
[902, 1608, 719, 850, 371, 1920, 75, 1917, 2005, 1238, 562, 1743, 713, 95, 1107, 1463, 696, 840, 8, 487, 1950, 1171, 1004, 1516, 1130, 303, 1866, 1728, 2046, 238, 265, 153],
|
||||
[1932, 839, 334, 1167, 134, 2025, 40, 505, 1244, 1238, 1840, 800, 697, 72, 216, 486, 940, 1312, 510, 361, 549, 583, 1364, 844, 397, 1181, 1779, 962, 457, 1782, 1316, 465],
|
||||
[31, 1558, 1048, 404, 354, 7, 827, 414, 1082, 807, 243, 1517, 801, 1364, 99, 1276, 1655, 1488, 1313, 464, 828, 1612, 774, 1558, 745, 1496, 960, 1874, 995, 1943, 255, 213],
|
||||
[355, 1270, 413, 1519, 1659, 1904, 690, 552, 1279, 1821, 2022, 458, 1779, 2003, 604, 832, 661, 1295, 305, 1701, 173, 869, 230, 539, 1188, 669, 117, 692, 250, 388, 1995, 294],
|
||||
[629, 199, 1899, 1123, 1070, 344, 578, 1795, 1451, 1257, 168, 1410, 1120, 1270, 316, 983, 1245, 1870, 165, 471, 966, 1337, 308, 1118, 746, 67, 1767, 1480, 1517, 1585, 871, 1110],
|
||||
[1281, 1173, 784, 404, 368, 403, 580, 526, 853, 1692, 792, 895, 1286, 573, 1368, 896, 931, 1958, 1912, 644, 583, 1706, 1176, 1262, 1637, 315, 524, 1629, 795, 1211, 915, 533],
|
||||
[9, 1783, 621, 1954, 1212, 993, 197, 977, 1662, 1340, 618, 1997, 1689, 1001, 74, 1765, 1865, 797, 1219, 1609, 671, 1491, 950, 1849, 1301, 2031, 875, 323, 203, 1063, 1490, 1538],
|
||||
[1944, 1578, 1256, 1169, 790, 1444, 1382, 1616, 1100, 1264, 214, 1646, 488, 573, 1333, 285, 1954, 74, 1333, 674, 1303, 266, 622, 1290, 402, 109, 1331, 1666, 1347, 780, 106, 605],
|
||||
[221, 161, 1322, 1, 565, 1507, 1403, 1091, 1557, 932, 1664, 1165, 1828, 1647, 2008, 1616, 648, 1113, 1870, 22, 734, 1458, 1940, 1756, 1689, 925, 1318, 1095, 985, 473, 604, 1974],
|
||||
[1178, 597, 1804, 747, 1383, 360, 1497, 406, 1053, 1023, 1901, 56, 1221, 628, 75, 1729, 575, 1681, 840, 410, 650, 794, 1171, 1889, 187, 54, 1364, 1390, 505, 1285, 1814, 90],
|
||||
[1432, 1221, 1800, 1873, 1255, 627, 41, 9, 630, 896, 1469, 1195, 1098, 145, 442, 1460, 13, 57, 2039, 1015, 149, 461, 1084, 1288, 1099, 910, 63, 157, 906, 111, 1394, 460],
|
||||
[1352, 593, 307, 780, 1614, 1675, 1491, 1253, 723, 1793, 1032, 1486, 1805, 1904, 777, 398, 1791, 951, 770, 499, 1858, 244, 1372, 1514, 1858, 1200, 69, 181, 673, 1144, 1938, 1191],
|
||||
[905, 403, 1626, 1529, 581, 1443, 976, 754, 1561, 1370, 1048, 253, 194, 1271, 853, 959, 1532, 30, 286, 1594, 1255, 1135, 1410, 1699, 1423, 2002, 260, 69, 941, 1640, 895, 722],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
]])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_1b_model_integration_generate_no_audio(self):
|
||||
"""
|
||||
Tests the generated tokens match the ones from the original model implementation.
|
||||
Such tokens are to be retrieved using https://gist.github.com/eustlb/aed822f765e928b9612e01b0d8836d69, which is a script that runs inference with the original model.
|
||||
"""
|
||||
|
||||
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||
|
||||
conversation = [
|
||||
{"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(conversation, tokenize=True, return_dict=True).to(torch_device)
|
||||
|
||||
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
|
||||
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
|
||||
|
||||
print(output_tokens)
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
|
||||
[1656, 629, 723, 1785, 206, 1873, 1059, 1190, 1833, 240, 618, 350, 156, 109, 2010, 452, 435, 1764, 77, 654, 1133, 908, 1095, 74, 804, 494, 1760, 1343, 1312, 1464, 1657, 324],
|
||||
[366, 1532, 1945, 21, 145, 1428, 1417, 1987, 1793, 1444, 356, 1491, 849, 333, 788, 426, 1423, 1004, 414, 1823, 1169, 257, 1892, 696, 1572, 998, 1098, 523, 390, 1977, 546, 1692],
|
||||
[1343, 1382, 1288, 1744, 1685, 1154, 1837, 1156, 1680, 1641, 1479, 1548, 632, 824, 694, 2010, 671, 1251, 1822, 343, 638, 1372, 696, 1272, 144, 125, 1332, 579, 936, 77, 159, 357],
|
||||
[456, 1534, 349, 274, 1956, 1502, 1268, 1038, 1911, 523, 1360, 1159, 761, 293, 718, 1143, 63, 705, 168, 550, 413, 1372, 1771, 787, 631, 693, 784, 1789, 2039, 1131, 1601, 918],
|
||||
[456, 829, 2026, 1108, 1649, 207, 1308, 1440, 1192, 1394, 426, 546, 590, 36, 1682, 1827, 1387, 1425, 1909, 1500, 1438, 1297, 5, 888, 948, 1745, 1304, 1364, 1692, 131, 300, 1908],
|
||||
[2027, 1431, 1037, 1789, 1296, 1264, 1331, 1787, 1235, 1902, 1161, 1591, 590, 561, 1633, 1218, 510, 148, 1962, 118, 212, 608, 565, 1869, 583, 598, 532, 658, 1416, 9, 1172, 493],
|
||||
[1215, 460, 1722, 317, 1423, 716, 1589, 1177, 1927, 1860, 1756, 1552, 1674, 643, 74, 1256, 587, 1742, 771, 2028, 469, 1070, 1683, 1614, 699, 494, 2020, 139, 1365, 1171, 171, 904],
|
||||
[1615, 339, 323, 317, 469, 714, 104, 2015, 1407, 278, 468, 77, 2007, 650, 1630, 269, 168, 934, 1544, 58, 1487, 1373, 705, 874, 1252, 2031, 1995, 254, 1334, 1171, 1911, 1607],
|
||||
[1259, 693, 666, 1700, 1115, 607, 982, 769, 1106, 1500, 101, 88, 1698, 1864, 1358, 1594, 192, 153, 1868, 1654, 604, 1948, 526, 778, 172, 1664, 1966, 99, 1334, 1030, 1349, 1209],
|
||||
[1211, 579, 1369, 492, 1725, 203, 1125, 778, 701, 1982, 1420, 155, 736, 1145, 2018, 609, 658, 561, 1147, 923, 1794, 1753, 116, 1374, 612, 956, 1587, 392, 1062, 2047, 901, 1931],
|
||||
[460, 1093, 1346, 1917, 1223, 470, 271, 390, 547, 112, 143, 1633, 1030, 643, 96, 1759, 920, 1959, 75, 1280, 1630, 999, 333, 853, 1110, 1291, 1911, 57, 171, 1658, 1704, 1508],
|
||||
[908, 500, 393, 184, 1437, 482, 2008, 1834, 356, 1435, 1550, 1407, 1236, 109, 1167, 452, 1141, 934, 207, 957, 660, 670, 28, 1066, 1252, 1932, 669, 906, 1904, 1820, 2043, 881],
|
||||
[1599, 1031, 1474, 336, 1540, 571, 437, 1440, 1616, 1365, 1412, 1246, 400, 405, 1776, 96, 296, 38, 1597, 466, 1630, 1256, 1940, 887, 1769, 294, 285, 842, 1756, 1619, 451, 1529],
|
||||
[1615, 339, 1722, 525, 942, 105, 1365, 670, 785, 1316, 465, 1860, 438, 968, 547, 1938, 1816, 1429, 1065, 1942, 660, 1446, 1093, 1066, 931, 121, 688, 1033, 1178, 754, 1783, 94],
|
||||
[912, 1354, 598, 254, 341, 1980, 1166, 585, 1302, 473, 554, 242, 174, 2030, 2011, 325, 978, 1690, 258, 396, 1831, 1768, 1291, 1699, 2001, 433, 1414, 2012, 1045, 511, 533, 1104],
|
||||
[80, 1791, 1062, 1136, 391, 568, 1651, 101, 959, 2043, 1683, 760, 794, 181, 570, 540, 1599, 20, 1017, 973, 1654, 396, 586, 778, 2044, 1664, 1911, 929, 66, 897, 510, 643],
|
||||
[1161, 1093, 161, 1296, 589, 54, 906, 981, 1927, 605, 516, 1731, 1461, 1204, 1902, 920, 1488, 177, 805, 1402, 610, 1446, 1154, 1067, 2025, 645, 762, 1715, 415, 1658, 1713, 1607],
|
||||
[374, 1444, 1577, 792, 1450, 628, 604, 1729, 322, 514, 1725, 540, 1070, 575, 653, 800, 250, 187, 569, 349, 354, 1573, 176, 793, 897, 359, 536, 276, 1224, 23, 145, 1287],
|
||||
[1184, 415, 1644, 1737, 1788, 385, 784, 1861, 1172, 1118, 367, 1156, 234, 1946, 1742, 981, 828, 1798, 1821, 361, 1148, 670, 518, 1288, 761, 1050, 1642, 1006, 1747, 840, 1599, 720],
|
||||
[1141, 1731, 1670, 1542, 1347, 1907, 683, 753, 1347, 68, 2031, 153, 556, 719, 736, 1759, 1131, 1073, 1747, 1730, 1487, 1137, 1869, 1624, 699, 1900, 748, 49, 1312, 735, 726, 1268],
|
||||
[1141, 1383, 405, 1033, 490, 488, 1102, 471, 713, 1630, 447, 703, 1495, 1001, 1855, 354, 456, 411, 786, 853, 168, 407, 116, 699, 605, 128, 532, 1076, 208, 447, 1448, 1071],
|
||||
[345, 1013, 948, 1728, 1837, 337, 930, 1226, 1643, 1729, 983, 1688, 2009, 435, 1358, 721, 42, 1779, 1332, 1077, 1873, 128, 1327, 125, 1226, 1704, 705, 1459, 1449, 862, 155, 1870],
|
||||
[336, 904, 684, 184, 1542, 714, 1752, 1180, 1373, 1816, 504, 1716, 1066, 1086, 1212, 530, 1413, 1278, 75, 1347, 82, 1623, 1307, 1717, 1861, 494, 888, 1589, 670, 1999, 905, 1430],
|
||||
[578, 554, 14, 523, 1016, 300, 1589, 1017, 356, 1583, 1654, 414, 449, 376, 1413, 58, 706, 963, 388, 1626, 131, 352, 1024, 1054, 2025, 1561, 77, 1589, 1486, 431, 1249, 1508],
|
||||
[184, 2043, 169, 1673, 580, 162, 1752, 397, 1119, 2009, 697, 150, 1475, 157, 1523, 1402, 575, 86, 1373, 1230, 1564, 1308, 626, 1093, 1603, 1446, 1390, 1543, 1778, 1142, 1357, 1831],
|
||||
[1484, 1987, 932, 1728, 1504, 1618, 291, 1865, 1151, 460, 1792, 141, 234, 2043, 829, 513, 435, 791, 1037, 1541, 65, 424, 1589, 1711, 312, 1306, 212, 686, 673, 984, 1914, 1549],
|
||||
[513, 1536, 1844, 1319, 572, 1069, 121, 735, 1949, 1211, 1362, 1027, 105, 1379, 315, 1782, 706, 1658, 1510, 1989, 1443, 1690, 822, 1614, 1194, 1460, 992, 2040, 1178, 1474, 1110, 1326],
|
||||
[1858, 194, 1594, 1935, 1622, 1892, 1577, 137, 1907, 2015, 757, 414, 1823, 836, 496, 530, 1385, 1503, 1065, 1554, 664, 525, 1031, 433, 69, 466, 1016, 1846, 1609, 1658, 911, 94],
|
||||
[1134, 1744, 323, 691, 1837, 347, 1871, 172, 811, 91, 1883, 436, 1912, 23, 1336, 1684, 519, 1612, 1219, 1402, 728, 1953, 1658, 641, 27, 1340, 436, 139, 2008, 1030, 159, 324],
|
||||
[1270, 1536, 1639, 414, 1387, 1170, 1067, 1701, 1414, 505, 1122, 36, 1731, 350, 1552, 1214, 1444, 30, 107, 172, 480, 1858, 655, 168, 1107, 691, 1272, 797, 1656, 548, 1407, 1375],
|
||||
[1270, 286, 1371, 1552, 1622, 1739, 1348, 2018, 345, 1537, 1941, 2024, 1423, 740, 284, 513, 91, 1228, 2015, 385, 992, 39, 813, 803, 2025, 497, 663, 462, 1609, 334, 927, 1470],
|
||||
[1718, 994, 265, 1421, 1622, 1098, 845, 1868, 832, 459, 447, 619, 1970, 929, 513, 63, 1448, 1509, 1219, 1942, 285, 1373, 1259, 1004, 11, 1040, 1984, 57, 188, 1687, 1475, 805],
|
||||
[1157, 832, 480, 1225, 1019, 347, 326, 999, 125, 1542, 118, 1383, 1343, 1077, 1821, 1602, 1978, 1642, 618, 808, 692, 1953, 1353, 963, 619, 1291, 1016, 1458, 1995, 1688, 1872, 1718],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
]])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_1b_model_integration_generate_multiple_audio(self):
|
||||
"""
|
||||
Test the generated tokens match the ones from the original model implementation.
|
||||
Such tokens are to be retrieved using https://gist.github.com/eustlb/0c94de002e1325abb61d32217f74c0f8, which is a script that runs inference with the original model.
|
||||
"""
|
||||
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
||||
conversation = []
|
||||
|
||||
# context
|
||||
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
|
||||
conversation.append(
|
||||
{
|
||||
"role": f"{speaker_id}",
|
||||
"content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
|
||||
}
|
||||
)
|
||||
|
||||
# text prompt
|
||||
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
).to(torch_device)
|
||||
|
||||
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
|
||||
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT_TOKENS = torch.tensor([[
|
||||
[420, 1189, 1311, 318, 359, 694, 1550, 1044, 1614, 1437, 1978, 537, 554, 1681, 147, 1225, 422, 1357, 1681, 1619, 165, 641, 1132, 1975, 1568, 406, 756, 503, 1673, 1428, 762, 781],
|
||||
[1848, 1412, 957, 1656, 871, 540, 1999, 175, 711, 1383, 1814, 104, 742, 1285, 733, 1251, 1165, 1915, 1392, 645, 1804, 913, 1772, 632, 376, 1507, 1132, 725, 716, 1121, 1769, 1509],
|
||||
[429, 1138, 895, 1018, 1099, 257, 1395, 1015, 576, 1599, 497, 19, 1858, 1437, 282, 357, 1143, 828, 1481, 70, 985, 551, 935, 278, 1102, 1453, 1902, 755, 526, 498, 1441, 1733],
|
||||
[546, 343, 1547, 879, 2039, 692, 1999, 1150, 1969, 1866, 1178, 199, 1913, 1738, 1530, 1728, 1193, 74, 695, 612, 1095, 1597, 1381, 683, 1385, 2045, 1069, 865, 438, 70, 1437, 318],
|
||||
[1741, 1621, 733, 1580, 1006, 1790, 1031, 1563, 569, 1822, 1229, 854, 142, 1554, 792, 741, 147, 552, 731, 772, 908, 831, 1291, 1819, 296, 290, 1871, 100, 1904, 1420, 1903, 1653],
|
||||
[1264, 1576, 963, 12, 1403, 453, 259, 1359, 1270, 466, 1744, 1579, 1081, 1691, 1495, 1293, 110, 1020, 2042, 189, 1358, 955, 784, 1317, 2, 1794, 388, 376, 327, 511, 866, 1308],
|
||||
[1407, 1412, 1665, 1683, 284, 874, 1859, 326, 1491, 1343, 777, 695, 1424, 396, 274, 202, 178, 747, 470, 1805, 1414, 2000, 127, 1884, 531, 215, 1322, 1098, 1674, 1227, 1092, 204],
|
||||
[584, 637, 1665, 1683, 1136, 1201, 212, 310, 1441, 1619, 190, 1611, 1629, 2011, 1754, 1587, 413, 1287, 1251, 1382, 1904, 444, 1665, 1047, 1982, 1169, 1200, 809, 117, 327, 958, 1877],
|
||||
[471, 1469, 1679, 1184, 343, 974, 1442, 897, 1888, 1468, 1092, 1398, 1714, 963, 1577, 1797, 766, 565, 403, 920, 1806, 466, 1193, 446, 825, 775, 1886, 1095, 159, 1085, 858, 504],
|
||||
[28, 1511, 1510, 1580, 447, 1934, 1031, 1439, 202, 1435, 474, 1731, 724, 1080, 1121, 421, 625, 1410, 95, 605, 815, 1825, 127, 785, 900, 1673, 178, 1242, 2033, 1230, 350, 139],
|
||||
[20, 1215, 253, 955, 871, 1689, 1986, 24, 1648, 423, 562, 1937, 1146, 26, 1266, 346, 188, 318, 179, 1164, 1100, 1978, 478, 1192, 715, 392, 1837, 425, 1492, 766, 1651, 822],
|
||||
[1879, 1401, 1444, 723, 1754, 732, 1307, 702, 1768, 2013, 1284, 577, 1287, 1532, 647, 189, 903, 587, 800, 152, 898, 182, 2016, 639, 1074, 1220, 1934, 264, 250, 745, 1652, 536],
|
||||
[1874, 1526, 232, 1580, 1980, 988, 1623, 341, 1768, 956, 1430, 1667, 1687, 1289, 826, 1378, 173, 1466, 479, 835, 1786, 1671, 328, 131, 815, 871, 379, 1329, 440, 1117, 392, 272],
|
||||
[1762, 426, 1350, 1590, 314, 190, 1514, 344, 1926, 822, 534, 523, 703, 36, 379, 494, 464, 1886, 1555, 1318, 1654, 1469, 1976, 304, 218, 655, 1826, 958, 502, 326, 1898, 861],
|
||||
[1577, 386, 503, 1492, 698, 405, 1031, 349, 1804, 2012, 1450, 996, 1140, 26, 449, 33, 1917, 354, 702, 1255, 1942, 1184, 864, 2045, 514, 744, 466, 54, 37, 486, 362, 525],
|
||||
[1109, 1920, 445, 1719, 1670, 1220, 745, 40, 171, 1921, 999, 104, 489, 1911, 883, 306, 649, 1751, 762, 1183, 1085, 1112, 1912, 2035, 1940, 1129, 1592, 1276, 1570, 1236, 738, 209],
|
||||
[1837, 990, 1063, 318, 1398, 1838, 1678, 906, 754, 802, 562, 353, 1389, 207, 1319, 1188, 2013, 1079, 888, 1706, 1042, 657, 482, 953, 94, 2007, 871, 485, 1596, 275, 410, 1855],
|
||||
[872, 974, 1344, 1798, 655, 805, 1604, 1913, 455, 615, 1827, 966, 1330, 1826, 1285, 359, 544, 221, 1538, 1658, 374, 1352, 1714, 1925, 235, 65, 350, 931, 1009, 1164, 218, 736],
|
||||
[1547, 617, 1622, 740, 655, 265, 1324, 1265, 1449, 482, 1037, 105, 1128, 701, 1866, 1674, 1999, 1302, 985, 1942, 663, 449, 1881, 698, 805, 1446, 1742, 1192, 1623, 605, 948, 2],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
]])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_1b_model_integration_generate_batched(self):
|
||||
"""
|
||||
Test the generated tokens match the ones from the original model implementation.
|
||||
Such tokens are to be retrieved using https://gist.github.com/eustlb/bcc532b53161bc31da3d66cb07ae193f, which is a script that runs inference with the original model.
|
||||
"""
|
||||
processor = AutoProcessor.from_pretrained(self.model_checkpoint)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
|
||||
conversation = [
|
||||
[
|
||||
{
|
||||
"role": f"{ds[0]['speaker_id']}",
|
||||
"content": [
|
||||
{"type": "text", "text": ds[0]["text"]},
|
||||
{"type": "audio", "path": ds[0]["audio"]["array"]},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": f"{ds[1]['speaker_id']}",
|
||||
"content": [
|
||||
{"type": "text", "text": ds[1]["text"]},
|
||||
],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": f"{ds[0]['speaker_id']}",
|
||||
"content": [
|
||||
{"type": "text", "text": ds[0]["text"]},
|
||||
],
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
).to(torch_device)
|
||||
|
||||
model = CsmForConditionalGeneration.from_pretrained(self.model_checkpoint, device_map=torch_device)
|
||||
output_tokens = model.generate(**inputs, do_sample=False, depth_decoder_do_sample=False)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT_TOKENS = torch.tensor([
|
||||
[
|
||||
[1140, 10, 37, 1180, 1100, 1319, 601, 1482, 1918, 1739, 372, 856, 674, 1, 854, 459, 1843, 1191, 347, 349, 1087, 846, 759, 1690, 947, 1280, 580, 1909, 1192, 487, 1302, 1601],
|
||||
[1494, 1412, 1824, 1852, 150, 928, 91, 326, 623, 1632, 1163, 1221, 1949, 999, 1779, 248, 693, 1149, 1423, 1503, 1656, 80, 1947, 1666, 933, 1950, 1544, 1577, 1612, 1791, 1883, 765],
|
||||
[778, 645, 830, 1051, 524, 1704, 1805, 1438, 211, 906, 691, 814, 1798, 1642, 1042, 284, 1906, 1513, 520, 137, 1052, 1548, 423, 1564, 330, 873, 1381, 188, 317, 1503, 1707, 1744],
[1416, 864, 242, 1653, 604, 1577, 202, 1808, 926, 1867, 204, 134, 1096, 1765, 496, 1680, 268, 1796, 2024, 1989, 583, 183, 952, 105, 765, 1534, 669, 895, 2008, 11, 1199, 195],
[1356, 796, 25, 1580, 15, 344, 1730, 99, 1330, 315, 955, 1964, 1731, 543, 1159, 1860, 671, 732, 63, 382, 143, 395, 1749, 1421, 1640, 1340, 650, 100, 171, 1346, 41, 806],
[1860, 1835, 823, 388, 254, 1734, 1135, 324, 1508, 983, 937, 1703, 1541, 875, 1319, 799, 1259, 1175, 1295, 807, 261, 760, 1916, 1606, 1616, 1894, 1605, 441, 387, 167, 2016, 222],
[1165, 919, 1318, 54, 1727, 1766, 777, 1128, 623, 353, 1840, 241, 977, 424, 1055, 898, 395, 655, 1695, 1084, 1346, 616, 1028, 1927, 603, 858, 758, 1539, 0, 1655, 1853, 1661],
[902, 1746, 1318, 298, 1982, 1184, 775, 328, 1676, 871, 133, 1374, 1927, 1984, 698, 1037, 100, 1884, 1596, 429, 1794, 2046, 105, 2037, 1767, 178, 176, 1293, 1893, 1780, 1832, 1382],
[1932, 714, 1084, 1167, 624, 509, 1213, 651, 1000, 1686, 1537, 555, 461, 623, 1433, 1089, 1212, 1628, 834, 1111, 943, 1816, 1947, 1063, 354, 1843, 1741, 2015, 404, 928, 1488, 168],
[1437, 314, 1356, 404, 1274, 2016, 998, 1350, 155, 553, 368, 1501, 1431, 1563, 1105, 1353, 535, 908, 1305, 1214, 1656, 65, 1469, 1517, 480, 252, 1289, 696, 302, 632, 246, 72],
[724, 848, 1140, 927, 1669, 296, 447, 1708, 1898, 685, 1041, 1685, 708, 1510, 1623, 876, 11, 99, 43, 586, 1705, 1753, 1477, 1191, 583, 1249, 1613, 992, 1319, 677, 418, 668],
[925, 54, 1810, 674, 1306, 848, 573, 1772, 105, 301, 1753, 989, 440, 1057, 823, 1313, 1663, 750, 1477, 102, 1437, 1114, 399, 1440, 319, 118, 1827, 295, 1429, 139, 1594, 55],
[629, 149, 784, 838, 984, 604, 685, 1229, 1432, 859, 1526, 1336, 1949, 281, 988, 1260, 52, 6, 1216, 1542, 1426, 1938, 253, 280, 1319, 794, 901, 843, 615, 437, 814, 20],
[1281, 502, 1237, 404, 625, 1444, 397, 1999, 2016, 1686, 533, 1785, 1152, 1245, 579, 1906, 1204, 549, 1334, 536, 1351, 1979, 208, 111, 2011, 751, 677, 1948, 1772, 1525, 2038, 419],
[9, 490, 869, 2026, 1928, 1489, 587, 549, 1241, 460, 1458, 1636, 924, 222, 1246, 480, 706, 398, 75, 1717, 604, 1446, 333, 237, 805, 1446, 421, 1343, 78, 1260, 1872, 1116],
[1944, 755, 375, 332, 1464, 828, 1273, 579, 1457, 353, 1510, 1910, 1609, 705, 400, 1666, 227, 1544, 1270, 136, 1857, 1975, 1762, 2006, 1102, 221, 1965, 151, 2041, 198, 1830, 287],
[221, 502, 440, 247, 181, 1912, 42, 357, 1883, 596, 919, 953, 1774, 772, 915, 188, 438, 1226, 544, 1313, 726, 1298, 85, 677, 566, 1581, 30, 341, 878, 1732, 591, 1446],
[1178, 1690, 320, 1746, 1798, 685, 1941, 666, 832, 623, 1907, 128, 337, 1779, 824, 923, 1041, 287, 1165, 437, 1803, 1222, 870, 646, 358, 220, 2009, 735, 468, 1908, 1349, 1603],
[1432, 1286, 540, 1687, 1741, 951, 299, 1233, 1061, 1128, 985, 953, 1917, 198, 2031, 1559, 1096, 1455, 780, 437, 163, 1268, 649, 1029, 1081, 1518, 304, 1638, 814, 364, 140, 1385],
[905, 463, 1739, 1063, 351, 936, 1652, 101, 1323, 1731, 298, 1193, 266, 1554, 1837, 1659, 409, 1739, 1012, 725, 851, 1909, 213, 1918, 1759, 1561, 1250, 970, 1571, 352, 911, 195],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
],
[
[1375, 203, 265, 164, 200, 1867, 976, 924, 1972, 1637, 1048, 271, 1912, 1430, 853, 1942, 260, 1642, 400, 57, 1376, 1626, 1821, 1163, 619, 777, 1076, 951, 389, 1820, 84, 1417],
[914, 527, 286, 968, 305, 1314, 805, 1703, 87, 559, 1980, 1124, 1726, 36, 1139, 618, 1628, 519, 1943, 781, 400, 1265, 438, 113, 87, 856, 465, 162, 1099, 352, 1141, 274],
[1408, 6, 126, 2009, 90, 996, 934, 134, 1857, 126, 602, 876, 1092, 1962, 1205, 828, 707, 1063, 393, 1533, 123, 1086, 1749, 1324, 1, 1763, 1707, 1191, 34, 1323, 1017, 1787],
[1000, 683, 1630, 703, 1574, 587, 25, 1049, 213, 1270, 1641, 1072, 1892, 1634, 1603, 90, 867, 2037, 1021, 715, 206, 507, 1138, 959, 1822, 1785, 280, 1100, 1660, 251, 1903, 988],
[1657, 1981, 246, 1048, 1952, 451, 305, 423, 2000, 416, 756, 1748, 7, 748, 1866, 1795, 1682, 1832, 338, 212, 1685, 518, 154, 1407, 416, 765, 776, 25, 55, 458, 612, 262],
[1034, 564, 667, 1474, 1212, 350, 712, 941, 1151, 1182, 1280, 640, 924, 1722, 1816, 458, 226, 359, 1518, 102, 1203, 459, 676, 1788, 1110, 393, 1974, 1721, 795, 1459, 798, 1723],
[742, 1616, 119, 653, 441, 679, 246, 1432, 486, 1615, 1191, 500, 650, 223, 687, 1765, 1875, 963, 1385, 863, 151, 1771, 458, 1170, 737, 1932, 785, 1954, 1067, 16, 1986, 2029],
[1437, 1078, 1767, 1452, 1392, 45, 2010, 1664, 245, 2015, 1416, 1055, 457, 985, 740, 1594, 1562, 1838, 258, 1431, 701, 604, 1813, 352, 792, 632, 21, 895, 70, 609, 850, 1599],
[983, 1961, 54, 135, 846, 711, 473, 1630, 1373, 1094, 251, 525, 632, 1014, 1594, 1594, 1752, 398, 1266, 1357, 942, 1680, 191, 874, 483, 1291, 381, 1873, 1964, 1278, 1477, 122],
[1663, 1969, 1887, 113, 145, 251, 1133, 156, 245, 1641, 209, 1322, 2037, 836, 539, 667, 940, 797, 1758, 1357, 191, 1137, 587, 1699, 27, 701, 395, 99, 1682, 876, 762, 839],
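# NOTE (assumption, not from the original diff): the all-zero frames below appear to be padding for this shorter batch element, whose generation ended earlier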
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
])
# fmt: on

torch.testing.assert_close(output_tokens.cpu(), EXPECTED_OUTPUT_TOKENS)

tests/models/csm/test_processor_csm.py (new file, 140 lines)
@ -0,0 +1,140 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
import tempfile
import unittest

import jinja2
import numpy as np

from transformers import CsmProcessor
from transformers.testing_utils import require_torch
from transformers.utils import is_torch_available

from ...test_processing_common import ProcessorTesterMixin


if is_torch_available():
import torch


@require_torch
class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = CsmProcessor

@classmethod
def setUpClass(cls):
# TODO: @eustlb, change for hf-internal-testing/csm-1b
cls.checkpoint = "eustlb/csm-1b"
processor = CsmProcessor.from_pretrained(cls.checkpoint)
cls.audio_token = processor.audio_token
cls.audio_token_id = processor.audio_token_id
cls.pad_token_id = processor.tokenizer.pad_token_id
cls.bos_token_id = processor.tokenizer.bos_token_id
cls.tmpdirname = tempfile.mkdtemp()
processor.save_pretrained(cls.tmpdirname)

@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)

def prepare_processor_dict(self):
return {"chat_template": "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"} # fmt: skip

def test_chat_template_is_saved(self):
processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
processor_dict_loaded = json.loads(processor_loaded.to_json_string())
# chat templates aren't serialized to json in processors
self.assertFalse("chat_template" in processor_dict_loaded.keys())

# they have to be saved as a separate file and loaded back from that file
# so we check that the same template is loaded
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

def test_apply_chat_template(self):
# messages contain a mix of text and audio content
messages = [
{
"role": "0",
"content": [
{"type": "text", "text": "This is a test sentence 0."},
{"type": "audio"},
],
},
{
"role": "1",
"content": [
{"type": "text", "text": "This is a test sentence 1."},
{"type": "audio"},
],
},
{
"role": "0",
"content": [
{"type": "text", "text": "This is a prompt."},
],
},
]
processor = CsmProcessor.from_pretrained(self.tmpdirname)
rendered = processor.apply_chat_template(messages, tokenize=False)

expected_rendered = (
"<|begin_of_text|>[0]This is a test sentence 0.<|end_of_text|>"
"<|AUDIO|><|audio_eos|>"
"<|begin_of_text|>[1]This is a test sentence 1.<|end_of_text|>"
"<|AUDIO|><|audio_eos|>"
"<|begin_of_text|>[0]This is a prompt.<|end_of_text|>"
)
self.assertEqual(rendered, expected_rendered)

messages = [
{
"role": "0",
"content": [
{"type": "text", "text": "This is a test sentence."},
],
},
{
"role": "1",
"content": [
{"type": "text", "text": "This is a test sentence."},
],
},
]

# this should raise an error because the CSM processor requires audio content in all messages except the last one
with self.assertRaises(jinja2.exceptions.TemplateError):
input_ids = processor.apply_chat_template(messages, tokenize=False)

# now let's verify that it expands audio tokens correctly
messages = [
{
"role": "0",
"content": [
{"type": "text", "text": "This is a test sentence."},
{"type": "audio", "audio": np.zeros(4096)},
],
},
]

input_ids = processor.apply_chat_template(messages, tokenize=True)

# 4096 audio input values should give 3 audio tokens
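# (assumption, not from the original diff: Mimi encodes 24 kHz audio at ~12.5 Hz, i.e. 1920 samples per codec frame, so ceil(4096 / 1920) = 3 audio placeholder tokens)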
expected_ids = torch.tensor(
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
)
torch.testing.assert_close(input_ids, expected_ids)

@ -4350,8 +4350,8 @@ class ModelTesterMixin:
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens

# comparing softmax-normalized logits:
- normalized_0 = F.softmax(out_last_tokens)
- normalized_1 = F.softmax(out_shared_prefix_last_tokens)
+ normalized_0 = F.softmax(out_last_tokens, dim=-1)
+ normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

@slow

@ -4403,7 +4403,7 @@ class ModelTesterMixin:
self.skipTest(reason="This model does not support `logits_to_keep` argument.")

config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
- batch_size, sequence_length = inputs["input_ids"].shape
+ batch_size, sequence_length = inputs["input_ids"].shape[:2]
vocab_size = config.get_text_config().vocab_size
model = model_class(config).to(device=torch_device).eval()
# some models have labels but `logits_to_keep` should not be used in train mode

@ -159,6 +159,9 @@ IGNORE_NON_TESTED = (
"InternVLVisionModel", # Building part of bigger (tested) model
"JanusVisionModel", # Building part of bigger (tested) model
"TimesFmModel", # Building part of bigger (tested) model
"CsmDepthDecoderForCausalLM", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
"CsmDepthDecoderModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
"CsmBackboneModel", # Building part of bigger (tested) model. Tested implicitly through CsmForConditionalGenerationIntegrationTest.
]
)

@ -368,6 +371,10 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"Qwen2_5OmniToken2WavModel", # Building part of a bigger model
"Qwen2_5OmniToken2WavBigVGANModel", # Building part of a bigger model
"Qwen2_5OmniToken2WavDiTModel", # Building part of a bigger model
"CsmBackboneModel", # Building part of a bigger model
"CsmDepthDecoderModel", # Building part of a bigger model
"CsmDepthDecoderForCausalLM", # Building part of a bigger model
"CsmForConditionalGeneration", # Building part of a bigger model
]

# DO NOT edit this list!