Mirror of https://github.com/huggingface/transformers.git (synced 2025-11-14 22:55:41 +08:00)

Compare commits: fix_export ... cb-prefix- (26 commits)

Commit SHAs:
6f4d984732, ce80e15615, ab8bd76292, dcff594a54, 4730839b68, ae98ce3044, 44fdf20213, a84f9e8c69, da80c4f945, 5a88f41243, 7ade82d550, ea68a1c07b, 877df4c36f, d7f21267c0, 5a0e4d4c7c, 3ed7936e58, 3e1b4f3c10, 69ef1e56cb, fc18b3b3aa, 56d0030ce9, 6dc28d9726, 371bcb3c8b, 1a8f01885f, d44d737af1, 3a5e3d74fa, 399d943af8
@ -46,8 +46,8 @@ jobs:
- run: uv pip install -U -e .
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
- run: mkdir -p test_preparation
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt || true
- run: python utils/tests_fetcher.py --filter_tests || true
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt
- run: python utils/tests_fetcher.py --filter_tests
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
- run: |
  if [ ! -s test_preparation/generated_config.yml ]; then

@ -98,8 +98,8 @@ jobs:
- run: uv pip install -U -e .
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
- run: mkdir -p test_preparation
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt || true
- run: python utils/tests_fetcher.py --filter_tests || true
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt
- run: python utils/tests_fetcher.py --filter_tests
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
- run: |
  if [ ! -s test_preparation/generated_config.yml ]; then
@ -125,9 +125,8 @@ If you're contributing a **vision-language model** (or any multimodal model that

All new models should use the modular architecture pattern. Create a `modular_<model_name>.py` file using the modular model converter:

- Use the CLI, [`transformers add-new-model-like`](https://github.com/huggingface/transformers/blob/main/src/transformers/cli/add_new_model_like.py), to generate a modular skeleton and get started
- All code should be in the modular file if possible. Modeling must be in it; it's better if the configuration is in it as well. The [Modular guide](./modular_transformers#implementing-a-modular-file) shows a quick way to set up a modular file.
- All code should be in the modular file if possible. Modeling must be in it; it's better if the configuration is in it as well.
- Reuse existing patterns from similar models as much as possible
- You can make the model compatible with inference engines such as vLLM or SGLang and enable zero-effort integration. See the specific requirements for model implementation in ["Transformers modeling backend"](./transformers_as_backend#multimodal-models)

To verify your modular file is correct, run:
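The repo-consistency tooling runs the modular conversion check for exactly this purpose; the command below is the one from the Makefile hunk later in this diff (flags to target a single `modular_<model_name>.py` file may exist, so verify against the current script):

```
python utils/check_modular_conversion.py
```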
Makefile
@ -45,7 +45,6 @@ repo-consistency:
python utils/check_modular_conversion.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_init_weights_data.py
python utils/check_inits.py
python utils/check_pipeline_typing.py
python utils/check_config_docstrings.py
@ -1,4 +1,4 @@
FROM rocm/pytorch:rocm7.1_ubuntu22.04_py3.10_pytorch_release_2.8.0
FROM rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
LABEL maintainer="Hugging Face"

ARG DEBIAN_FRONTEND=noninteractive
@ -508,16 +508,16 @@ BERT `_init_weights` Methode:
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
```
|
||||
|
||||
You can use further custom schemes if you need a special initialization for some modules. For example, in
|
||||
@ -533,9 +533,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
```
|
||||
|
||||
The `_is_hf_initialized` flag is used internally to make sure we only initialize a submodule once. If you set it to
|
||||
|
||||
@ -118,7 +118,7 @@
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: transformers_as_backend
|
||||
title: Transformers as modeling backend
|
||||
title: Inference server backends
|
||||
- local: continuous_batching
|
||||
title: Continuous Batching
|
||||
title: Inference
|
||||
|
||||
@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
```
|
||||
|
||||
The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.
|
||||
@ -339,9 +339,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
```
|
||||
|
||||
### Convert checkpoints to Transformers
|
||||
|
||||
@ -136,7 +136,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
fps=1,
|
||||
video_fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen2-5-OmniProcessor`
|
||||
padding=True,
|
||||
@ -245,7 +245,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
fps=1,
|
||||
video_fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen2-5-OmniProcessor`
|
||||
padding=True,
|
||||
|
||||
@ -54,7 +54,7 @@ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_co
|
||||
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
|
||||
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
|
||||
audio, sr = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
|
||||
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(model.device)
|
||||
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||||
@ -63,7 +63,7 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
|
||||
|
||||
# We can also omit the audio_bos and audio_eos tokens
|
||||
prompt = "<|AUDIO|>Generate the caption in English:"
|
||||
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(model.device)
|
||||
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||||
@ -106,7 +106,7 @@ for message in conversation:
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
|
||||
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
|
||||
inputs.input_ids = inputs.input_ids.to(model.device)
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
@ -156,7 +156,7 @@ for message in conversation:
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
|
||||
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
|
||||
inputs.input_ids = inputs.input_ids.to(model.device)
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
@ -213,7 +213,7 @@ for conversation in conversations:
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
|
||||
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
|
||||
inputs['input_ids'] = inputs['input_ids'].to(model.device)
|
||||
inputs.input_ids = inputs.input_ids.to(model.device)
|
||||
|
||||
|
||||
@ -80,7 +80,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
fps=1,
|
||||
video_fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen3OmniMoeProcessor`
|
||||
padding=True,
|
||||
@ -136,7 +136,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
fps=1,
|
||||
video_fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen3OmniMoeProcessor`
|
||||
padding=True,
|
||||
@ -245,7 +245,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
fps=1,
|
||||
video_fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen3OmniMoeProcessor`
|
||||
padding=True,
|
||||
|
||||
@ -61,7 +61,7 @@ Here is how to use the processor to process text and audio:
|
||||
>>> audio_sample = next(iter(dataset))["audio"]
|
||||
|
||||
>>> # now, process it
|
||||
>>> audio_inputs = processor(audio=audio_sample["array"], return_tensors="pt")
|
||||
>>> audio_inputs = processor(audios=audio_sample["array"], return_tensors="pt")
|
||||
|
||||
>>> # now, process some English test as well
|
||||
>>> text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")
|
||||
|
||||
@ -61,7 +61,7 @@ Here is how to use the processor to process text and audio:
|
||||
>>> audio_sample = next(iter(dataset))["audio"]
|
||||
|
||||
>>> # now, process it
|
||||
>>> audio_inputs = processor(audio=audio_sample["array"], return_tensors="pt")
|
||||
>>> audio_inputs = processor(audios=audio_sample["array"], return_tensors="pt")
|
||||
|
||||
>>> # now, process some English text as well
|
||||
>>> text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Contributing a new model to Transformers
|
||||
|
||||
Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance. We recommend going through the [general contribution guidelines for new models](./contributing#do-you-want-to-implement-a-new-model) before diving into the details here.
|
||||
Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance.
|
||||
|
||||
One of Transformers' core design features is the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy) policy. Model components - such as attention layers - are repeated across many files, and independent implementations tend to diverge as fixes and changes are applied to specific parts of the code.
|
||||
|
||||
|
||||
@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m

```python
class Llama4TextExperts(nn.Module):
    ...
    self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
    self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```

Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
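As an illustration only (not the file's actual implementation; the routing, activation, and `down_proj` shape are assumptions), a batched-matmul forward over the fused parameter could look like this:

```python
import torch
import torch.nn.functional as F

def experts_forward(hidden_states, gate_up_proj, down_proj, num_experts, hidden_size):
    # hidden_states: tokens already grouped per expert, shape (num_experts * tokens_per_expert, hidden_size)
    hidden_states = hidden_states.view(num_experts, -1, hidden_size)
    # one batched matmul applies every expert's fused gate/up projection at once
    gate_up = torch.bmm(hidden_states, gate_up_proj)   # (num_experts, tokens, 2 * expert_dim)
    gate, up = gate_up.chunk(2, dim=-1)                # split the fused projection back into gate and up
    activated = up * F.silu(gate)                      # assumed SiLU gating
    # down_proj assumed to have shape (num_experts, expert_dim, hidden_size)
    return torch.bmm(activated, down_proj).view(-1, hidden_size)
```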
|
||||
|
||||
@ -14,9 +14,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Transformers as modeling backend
|
||||
# Inference server backends
|
||||
|
||||
Transformers' models are compatible with different inference servers like vLLM and SGLang. Instead of implementing a new model architecture from scratch for each inference server, you only need a model definition in `transformers`, which can be plugged into any inference server. It simplifies maintenance and makes it easy for users to use different inference servers for different use cases.
|
||||
Transformers' models are compatible with different inference servers like vLLM and SGLang. Instead of implementing a model for each inference server, you only need one model, which can be plugged into any inference server. It simplifies maintenance and makes it easy for users to use different inference servers for different use cases.
|
||||
|
||||
With Transformers as a backend, you can also serve any model - including custom and Hub-hosted models - without waiting for native support.
|
||||
|
||||
@ -157,13 +157,57 @@ class MyConfig(PreTrainedConfig):
|
||||
|
||||
### Multimodal models
|
||||
|
||||
For multimodal models, you need to include a few more changes on top of the general recommendations outlined in ["contributing a model"](./contributing#vision-language-model-contribution-checklist). These rules ensure that your model integrates properly and can process multimodal data.
|
||||
For multimodal models, you need to include a few more changes on top of the general recommendations. These rules ensure that your model integrates properly with multimodal data.
|
||||
|
||||
1. A multimodal model's processing class must have the `self.image_token` and `self.image_token_ids` attributes. These are placeholder tokens used to indicate image positions in the input. The placeholder is the same token used in the input prompt to denote images and the one the modeling code uses to scatter image features.
|
||||
1. A multimodal model requires a base `MyMultiModalModel` class to handle multimodal fusion without a language modeling head and a separate generative class that adds a head.
|
||||
|
||||
2. The processing class needs `self._get_num_multimodal_tokens` method to compute the number of placeholder tokens needed for multimodal inputs with given sizes and to return a [`MultiModalData`] object. The placeholders between `<image>` tokens such as row or column tokens don't count as image placeholders. Only tokens that are actually replaced by image features later in modeling should be counted!
|
||||
The base model needs to implement the `get_image_features()` method to accept image pixel values and return encoded outputs. These are later merged with the language embeddings and don't require any postprocessing. The shape of the returned features must match the number of input images. If a vision encoder returns variable-length outputs (patch-based), return a list of 2D tensors of size `(image_seq_len, image_dim)` for each image.
|
||||
|
||||
3. The processor needs to check the value of `return_mm_token_type_ids` and return `mm_token_type_ids` to indicate whether each position is a text token (`0`), image placeholder token (`1`) or video placeholder token (`2`). Each multimodal token type ID sequence must be contiguous without breaks between consecutive tokens, therefore special tokens for begin/end/row/column must be treated as placeholders.
|
||||
Expand the code below for an example.
|
||||
|
||||
<details>
|
||||
<summary>modeling_my_multimodal_model.py</summary>
|
||||
|
||||
```python
|
||||
from transformers.generation import GenerationMixin
|
||||
|
||||
class MyMultimodalModel(MyMultimodalPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.language_model = AutoModel.from_config(config.text_config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
self.multimodal_projection = nn.Linear(vision_dim, text_dim)
|
||||
|
||||
def get_image_features(self, pixel_values):
|
||||
return self.vision_tower(pixel_values).last_hidden_states
|
||||
|
||||
def forward(self, input_ids, pixel_values, **kwargs):
|
||||
# process your inputs
|
||||
return MyModelOutputWithPast(
|
||||
last_hidden_state=last_hidden_state,
|
||||
image_hidden_states=image_features,
|
||||
[...]
|
||||
)
|
||||
|
||||
class MyMultimodalModelForConditionalGeneration(MyMultimodalPreTrainedModel, GenerationMixin):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = MyMultimodalModel(config)
|
||||
self.lm_head = nn.Linear(hidden_dim, vocab_size)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
2. A multimodal model config must be nested with the following fields (a short sketch follows this list).
* text_config: decoder language model config
* vision_config: vision encoder config
* image_token_id: ID of the image placeholder token used in the input to indicate image position
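A minimal sketch of such a nested config (hypothetical model name; the `sub_configs` mapping, defaults, and dict handling are simplified assumptions):

```python
from transformers import AutoConfig, PreTrainedConfig

class MyMultimodalConfig(PreTrainedConfig):   # hypothetical, for illustration
    model_type = "my_multimodal_model"
    sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}  # assumed nested-config pattern

    def __init__(self, text_config=None, vision_config=None, image_token_id=32000, **kwargs):
        self.text_config = text_config        # decoder language model config
        self.vision_config = vision_config    # vision encoder config
        self.image_token_id = image_token_id  # ID of the image placeholder token
        super().__init__(**kwargs)
```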
|
||||
|
||||
3. A multimodal model's processing class must have the `self.image_token` and `self.image_token_ids` attributes. These are placeholder tokens used to indicate image positions in the input. The placeholder is the same token used in the input prompt and the one used to masked-scatter image features into the language embeddings.

The processing class also needs a `self._get_num_multimodal_tokens` method to compute the number of placeholder tokens needed for multimodal inputs of given sizes and to return a [`MultiModalData`] object. Placeholders for row and column tokens don't count as image placeholders; only the tokens that are actually replaced by image features are counted.

Finally, when `return_mm_token_type_ids=True`, the class has to return `mm_token_type_ids` indicating whether each position is a text token (`0`) or an image placeholder token (`1`). Each image's token type IDs must be contiguous, with no breaks between consecutive ones.
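A rough sketch of that check (hypothetical helper inside the processor's `__call__`, heavily simplified):

```python
import numpy as np

def add_mm_token_type_ids(text_inputs, image_token_id, return_mm_token_type_ids=False):
    # text_inputs["input_ids"]: list[list[int]] produced by the tokenizer
    if return_mm_token_type_ids:
        input_ids = np.asarray(text_inputs["input_ids"])
        mm_token_type_ids = np.zeros_like(input_ids)        # 0 = text token
        mm_token_type_ids[input_ids == image_token_id] = 1  # 1 = image placeholder token
        text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
    return text_inputs
```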
|
||||
|
||||
Expand the code below for an example.
|
||||
|
||||
@ -202,5 +246,5 @@ class MyMultimodalProcessor(ProcessorMixin):
|
||||
|
||||
## Resources
|
||||
|
||||
* Read the [Transformers modeling backend integration in vLLM](https://blog.vllm.ai/2025/04/11/transformers-backend.html) blog post for more details about the Transformers modeling backend in vLLM.
|
||||
* Read the [Transformers modeling backend integration in SGLang](https://huggingface.co/blog/transformers-backend-sglang) blog post for more details about the Transformers modeling backend in SGLang.
|
||||
* Read the [Transformers backend integration in vLLM](https://blog.vllm.ai/2025/04/11/transformers-backend.html) blog post for more details about the Transformers backend in vLLM.
|
||||
* Read the [Transformers backend integration in SGLang](https://huggingface.co/blog/transformers-backend-sglang) blog post for more details about the Transformers backend in SGLang.
|
||||
|
||||
@ -170,7 +170,7 @@ Per quanto riguarda la classe `TrainingArguments`:
|
||||
- The `evaluate_during_training` argument of `TrainingArguments` is deprecated in favor of `eval_strategy`.

Regarding the Transfo-XL model:
- The Transfo-XL configuration attribute `tie_weight` becomes `tie_word_embeddings`.
- The Transfo-XL configuration attribute `tie_weight` becomes `tie_words_embeddings`.
- The Transfo-XL modeling method `reset_length` becomes `reset_memory_length`.

Regarding pipelines:
|
||||
|
||||
@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
```
|
||||
|
||||
If certain modules need special initialization, you can additionally use custom schemes. For example,
|
||||
@ -431,9 +431,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
```
|
||||
|
||||
The `_is_hf_initialized` flag is used internally to ensure that a submodule is initialized only once.
|
||||
|
||||
@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
```
|
||||
|
||||
You can also use a custom scheme if some modules need special initialization. For example, in `Wav2Vec2ForPreTraining` the last two linear layers must use the regular PyTorch `nn.Linear` initialization, while all other layers should use the initialization above. This is coded as follows:
|
||||
@ -371,9 +371,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
```
|
||||
|
||||
The `_is_hf_initialized` flag is used internally to make sure a submodule is initialized only once. Setting it to `True` for `module.project_q` and `module.project_hid` ensures that the custom initialization we performed is not overwritten later, i.e. the `_init_weights` function is not applied to them.
|
||||
|
||||
@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
|
||||
|
||||
@ -502,10 +502,16 @@ class DummyBertLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -530,18 +536,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, DummyBertLMPredictionHead):
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -265,7 +265,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
|
||||
if "RMSNorm" in module.__class__.__name__:
|
||||
module.weight.zero_()
|
||||
module.weight.data.zero_()
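# Illustration (assumed RMSNorm variant): the layer scales the normalized activations by (1 + weight),
# e.g. output = normed * (1.0 + self.weight), so zero-initialized weights start as a unit-gain RMSNorm.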
|
||||
|
||||
|
||||
class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):
|
||||
|
||||
@ -104,9 +104,9 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
def token_type_ids_mask_function(
|
||||
@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
@ -440,15 +440,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
prefix = "model.language_model."
|
||||
prefixed_mapping = {
|
||||
f"{prefix}{target}": f"{prefix}{source}"
|
||||
for target, source in self.language_model._tied_weights_keys.items()
|
||||
}
|
||||
if isinstance(self._tied_weights_keys, dict):
|
||||
self._tied_weights_keys.update(prefixed_mapping)
|
||||
else:
|
||||
self._tied_weights_keys = prefixed_mapping
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
||||
@ -505,10 +505,16 @@ class RobertaLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -533,18 +539,18 @@ class RobertaPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, RobertaLMPredictionHead):
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -846,11 +846,11 @@ class TestDetrPreTrainedModel(PreTrainedModel):
|
||||
nn.init.xavier_uniform_(module.output_proj.weight.data)
|
||||
nn.init.constant_(module.output_proj.bias.data, 0.0)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if hasattr(module, "reference_points") and not self.config.two_stage:
|
||||
|
||||
@ -19,15 +19,7 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
prefix = "model.language_model."
|
||||
prefixed_mapping = {
|
||||
f"{prefix}{target}": f"{prefix}{source}"
|
||||
for target, source in self.language_model._tied_weights_keys.items()
|
||||
}
|
||||
if isinstance(self._tied_weights_keys, dict):
|
||||
self._tied_weights_keys.update(prefixed_mapping)
|
||||
else:
|
||||
self._tied_weights_keys = prefixed_mapping
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
self.post_init()
|
||||
|
||||
|
||||
@ -27,6 +27,7 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
@ -179,11 +180,29 @@ class ModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
freeze_feature_extractor: Optional[bool] = field(
|
||||
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
|
||||
)
|
||||
ignore_mismatched_sizes: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
|
||||
warnings.warn(
|
||||
"The argument `--freeze_feature_extractor` is deprecated and "
|
||||
"will be removed in a future version. Use `--freeze_feature_encoder` "
|
||||
"instead. Setting `freeze_feature_encoder==True`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if self.freeze_feature_extractor and not self.freeze_feature_encoder:
|
||||
raise ValueError(
|
||||
"The argument `--freeze_feature_extractor` is deprecated and "
|
||||
"should not be used in combination with `--freeze_feature_encoder`. "
|
||||
"Only make use of `--freeze_feature_encoder`."
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
|
||||
@ -17,6 +17,7 @@ import contextlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from itertools import cycle
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
@ -29,42 +30,32 @@ from transformers.generation import GenerationConfig
|
||||
from transformers.generation.continuous_batching.requests import logger
|
||||
|
||||
|
||||
# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
|
||||
SLIDING_WINDOW = 0
|
||||
MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
|
||||
FORCE_MAX_LENGTH = False # should be False unless you are debugging sliding window features
|
||||
SKIP_SPECIAL_TOKENS = False
|
||||
|
||||
|
||||
def generate_simple(
|
||||
attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
|
||||
def generate_without_cb(
|
||||
model_id: str, sliding_window: int, attn_impl: str, batched_inputs: list[int], generation_config: GenerationConfig
|
||||
) -> dict[str, str]:
|
||||
attn_impl = {
|
||||
"sdpa": "sdpa",
|
||||
"eager": "eager",
|
||||
"paged_attention": "eager", # TODO: this does not work on AMD docker
|
||||
"flash_paged": "flash_attention_2", # TODO: this does not work on AMD docker
|
||||
"kernels-community/flash-attn": "eager",
|
||||
}[attn_impl]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
|
||||
# Setup model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, attn_implementation=attn_impl)
|
||||
model = model.cuda().eval()
|
||||
if getattr(model.config, "sliding_window", None) is not None:
|
||||
model.config.sliding_window = SLIDING_WINDOW
|
||||
|
||||
if sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
|
||||
model.config.sliding_window = sliding_window
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
# Generate one by one
|
||||
decoded_outputs = {}
|
||||
for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
|
||||
for input_ids in tqdm(batched_inputs, desc="Generating outputs without CB"):
|
||||
key = " ".join(map(str, input_ids)) # This will be used to identify the output after batched generation
|
||||
input_ids = torch.tensor([input_ids]).to("cuda")
|
||||
# attention_mask = torch.ones_like(input_ids)
|
||||
outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
outputs = model.generate(
|
||||
input_ids, attention_mask=attention_mask, generation_config=generation_config, use_model_defaults=False
|
||||
)
|
||||
generated_tokens = outputs[0][input_ids.shape[1] :]
|
||||
decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
|
||||
decoded_outputs[key] = decoded_output
|
||||
decoded_outputs[key] = tokenizer.decode(generated_tokens, skip_special_tokens=False)
|
||||
return decoded_outputs
|
||||
|
||||
|
||||
def setup_metrics():
|
||||
def maybe_setup_metrics(use_metrics: bool) -> None:
|
||||
if not use_metrics:
|
||||
return
|
||||
try:
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
@ -119,16 +110,14 @@ def batch_generate(
|
||||
token_count = 0
|
||||
data = []
|
||||
for i, request in enumerate(batch_outputs):
|
||||
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
|
||||
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
|
||||
# The key is used to tie back to the output of unbatched generation
|
||||
key = " ".join(map(str, batch_outputs[request].prompt_ids))
|
||||
data.append({"input": input_text, "key": key})
|
||||
|
||||
# Try to decode the output
|
||||
try:
|
||||
output_text = tokenizer.decode(
|
||||
batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
|
||||
)
|
||||
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
|
||||
token_count += len(batch_outputs[request].generated_tokens[1:])
|
||||
data[-1]["cb_outputs"] = output_text
|
||||
except Exception as e:
|
||||
@ -138,14 +127,7 @@ def batch_generate(
|
||||
|
||||
# Display sample if asked
|
||||
if i < displayed_samples:
|
||||
if len(output_text) > 0:
|
||||
print("-" * 20)
|
||||
print(f"{request} Input: {input_text}")
|
||||
print(f"{request} Output: {output_text}")
|
||||
else:
|
||||
print(f"{request} Input: {input_text}")
|
||||
print("[WARN]")
|
||||
print(f"{request} Output was empty!")
|
||||
print("-" * 20, f"{request} Input: {input_text}", f"{request} Output: {output_text}", sep="\n")
|
||||
|
||||
# Compare with classic generate if asked
|
||||
if expected_outputs is not None:
|
||||
@ -182,75 +164,102 @@ def batch_generate(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse args
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Continuous batching parameters
|
||||
parser.add_argument("--num-blocks", "-n", type=int, default=None)
|
||||
parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
|
||||
|
||||
# Model parameters
|
||||
parser.add_argument("--sliding-window", type=int, default=0)
|
||||
parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
|
||||
|
||||
# Performance parameters
|
||||
parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable
|
||||
parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
|
||||
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
|
||||
parser.add_argument("--do-sample", action="store_true", help="Activate sampling")
|
||||
|
||||
# Benchmark parameters
|
||||
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
|
||||
parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
|
||||
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
|
||||
parser.add_argument("--profile", type=str, default=None)
|
||||
parser.add_argument("--metrics", action="store_true")
|
||||
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")
|
||||
|
||||
# Display parameters
|
||||
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
|
||||
parser.add_argument("--log-level", type=str, default="INFO")
|
||||
parser.add_argument("--output-file", type=str, default=None)
|
||||
parser.add_argument("--compare", action="store_true")
|
||||
parser.add_argument("--metrics", action="store_true")
|
||||
parser.add_argument("--profile", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set log level
|
||||
# Create model
|
||||
model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
|
||||
has_system_role = args.sliding_window == 0
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn, dtype=torch.bfloat16)
|
||||
model = model.cuda().eval()
|
||||
|
||||
if args.sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
|
||||
print(f"Setting sliding window from {model.config.sliding_window} to {args.sliding_window}")
|
||||
model.config.sliding_window = args.sliding_window
|
||||
|
||||
# Set up diagnostics
|
||||
logger.setLevel(args.log_level.upper())
|
||||
maybe_setup_metrics(args.metrics)
|
||||
|
||||
# If turned on, we setup metrics
|
||||
if args.metrics:
|
||||
setup_metrics()
|
||||
|
||||
# Set matmul precision if not none
|
||||
# Set up performance
|
||||
if args.matmul_precision != "none":
|
||||
torch.set_float32_matmul_precision(args.matmul_precision)
|
||||
# Parse cuda graph argument
|
||||
if args.cuda_graph is not None:
|
||||
use_cuda_graph = {
|
||||
"none": None,
|
||||
"yes": True, "y": True, "true": True, "t": True, "1": True,
|
||||
"no": False, "n": False, "false": False, "f": False, "0": False,
|
||||
}[args.cuda_graph.lower()] # fmt: skip
|
||||
else:
|
||||
use_cuda_graph = None
|
||||
|
||||
# Prepare model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
attn_implementation=args.attn,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
model = model.cuda().eval()
|
||||
if getattr(model.config, "sliding_window", None) is not None:
|
||||
print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
|
||||
model.config.sliding_window = SLIDING_WINDOW
|
||||
cuda_graph_arg = args.cuda_graph.lower() if args.cuda_graph is not None else None
|
||||
use_cuda_graph = {
|
||||
"none": None, None: None,
|
||||
"yes": True, "y": True, "true": True, "t": True, "1": True,
|
||||
"no": False, "n": False, "false": False, "f": False, "0": False,
|
||||
}[cuda_graph_arg] # fmt: skip
|
||||
|
||||
# If turned on, we compile the model
|
||||
if args.compile:
|
||||
model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
|
||||
|
||||
# Prepare tokenizer and dataset
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
||||
|
||||
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
|
||||
dataset = dataset.select(range(args.samples))
|
||||
|
||||
simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
|
||||
if args.add_prefix:
|
||||
possible_prefixes = [
|
||||
None,
|
||||
"You are a bot that solves math problems.",
|
||||
"You are a bot who solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning.",
|
||||
"You are a bot with the aim to solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning. No loud words or emojis, all responses must be readable by a child. Here is now the problem:",
|
||||
] # fmt: skip
|
||||
else:
|
||||
possible_prefixes = [None]
|
||||
|
||||
batched_inputs = []
|
||||
for item, prefix in zip(dataset, cycle(possible_prefixes)):
|
||||
messages = []
|
||||
question = item["question"]
|
||||
if prefix is not None:
|
||||
if has_system_role:
|
||||
messages.append({"role": "system", "content": prefix})
|
||||
else:
|
||||
question = prefix + "\n\n" + question
|
||||
messages.append({"role": "user", "content": question})
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
batched_inputs.append(inputs["input_ids"])
|
||||
|
||||
# Prepare generation config
|
||||
generation_config = GenerationConfig(
|
||||
generation_cfg = GenerationConfig(
|
||||
max_new_tokens=512,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
|
||||
eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=not args.compare,
|
||||
do_sample=args.do_sample,
|
||||
temperature=0.8,
|
||||
top_p=0.9,
|
||||
num_blocks=args.num_blocks,
|
||||
@ -258,7 +267,12 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# If we need to compare, we need to generate the reference outputs
|
||||
expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None
|
||||
if args.compare:
|
||||
expected_outputs = generate_without_cb(
|
||||
model_id, args.sliding_window, args.attn, batched_inputs, generation_cfg
|
||||
)
|
||||
else:
|
||||
expected_outputs = None
|
||||
|
||||
# If no output file is provided, we pick a name based on the args
|
||||
if args.output_file is None:
|
||||
@ -271,8 +285,8 @@ if __name__ == "__main__":
|
||||
# Run warmup batch generation # TODO: understand why warmup incurs a large overhead during cache creation
|
||||
batch_generate(
|
||||
model,
|
||||
simple_batch_inputs[: min(5, args.samples)],
|
||||
generation_config,
|
||||
batched_inputs[: min(5, args.samples)],
|
||||
generation_cfg,
|
||||
tokenizer,
|
||||
displayed_samples=-1,
|
||||
)
|
||||
@ -285,8 +299,8 @@ if __name__ == "__main__":
|
||||
# Run batch generation
|
||||
gen_time, tok_per_sec = batch_generate(
|
||||
model,
|
||||
simple_batch_inputs,
|
||||
generation_config,
|
||||
batched_inputs,
|
||||
generation_cfg,
|
||||
tokenizer,
|
||||
displayed_samples=args.displayed,
|
||||
output_file=args.output_file,
|
||||
@ -297,5 +311,5 @@ if __name__ == "__main__":
|
||||
prof.export_chrome_trace(filename)
|
||||
|
||||
# Example usage:
|
||||
# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
|
||||
# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
|
||||
# python examples/pytorch/continuous_batching.py --attn sdpa --add-prefix --samples 10 --compare
|
||||
# python examples/pytorch/continuous_batching.py --attn flash_attention_2 -mp none --add-prefix --samples 500
|
||||
|
||||
@ -876,7 +876,7 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
if hasattr(self, "quantization_config"):
|
||||
serializable_config_dict["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
else self.quantization_config
|
||||
)
|
||||
self.dict_dtype_to_str(serializable_config_dict)
|
||||
@ -910,7 +910,7 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
if hasattr(self, "quantization_config"):
|
||||
output["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
else self.quantization_config
|
||||
)
|
||||
self.dict_dtype_to_str(output)
|
||||
|
||||
@ -1,136 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
|
||||
from .utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def _build_checkpoint_conversion_mapping():
|
||||
mapping = {
|
||||
"mixtral": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w1.weight",
|
||||
"block_sparse_moe.experts.*.w3.weight",
|
||||
], # you give me a list of 2 keys, I collect a list of a list of tensors
|
||||
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w2.weight",
|
||||
],
|
||||
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
# WeightConverter(
|
||||
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
|
||||
# "self_attn.qkv_proj",
|
||||
# operations=[Concatenate(dim=0)], # more like stack?
|
||||
# ),
|
||||
WeightConverter("*.block_sparse_moe.", "*.mlp."),
|
||||
],
|
||||
"qwen2_moe": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"mlp.experts.*.gate_proj.weight",
|
||||
"mlp.experts.*.up_proj.weight",
|
||||
],
|
||||
target_keys="mlp.experts.gate_up_proj",
|
||||
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=["mlp.experts.*.down_proj.weight"],
|
||||
target_keys="mlp.experts.down_proj",
|
||||
operations=[MergeModulelist(dim=0)],
|
||||
),
|
||||
],
|
||||
"legacy": [
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.gamma",
|
||||
target_keys="LayerNorm.weight",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.beta",
|
||||
target_keys="LayerNorm.bias",
|
||||
),
|
||||
],
|
||||
}
|
||||
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="weight_g",
|
||||
target_keys="parametrizations.weight.original0",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="weight_v",
|
||||
target_keys="parametrizations.weight.original1",
|
||||
),
|
||||
]
|
||||
else:
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original0",
|
||||
target_keys="weight_g",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original1",
|
||||
target_keys="weight_v",
|
||||
),
|
||||
]
|
||||
|
||||
mapping["phimoe"] = mapping["mixtral"].copy()
|
||||
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
|
||||
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
|
||||
mapping["dot1"] = mapping["qwen2_moe"].copy()
|
||||
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["jamba"] = mapping["qwen2_moe"].copy()
|
||||
mapping["lfm2_moe"] = mapping["mixtral"].copy()
|
||||
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["minimax"] = mapping["mixtral"].copy()
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
_checkpoint_conversion_mapping_cache = None
|
||||
|
||||
|
||||
def get_checkpoint_conversion_mapping(model_type):
|
||||
global _checkpoint_conversion_mapping_cache
|
||||
_checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
|
||||
globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
|
||||
return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type, None))
|
||||
@ -1,732 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Core helpers for loading model checkpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping, MutableSet, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
|
||||
from .utils import is_torch_greater_or_equal, logging
|
||||
|
||||
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
|
||||
if _is_dtensor_available:
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .quantizers import HfQuantizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
str_to_torch_dtype = {
|
||||
"BOOL": torch.bool,
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I32": torch.int32,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
|
||||
"""
|
||||
Convert a glob with '*' into a regex *source* string. We intentionally do not use `glob.translate` here.
'*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing.
"""
star = r"(\d+)" if digits_only else r"(.+)"
return re.escape(glob).replace(r"\*", star)
|
||||
|
||||
|
||||
def build_glob_alt(
|
||||
globs: list[str],
|
||||
) -> tuple[re.Pattern, dict[str, str]]:
|
||||
r"""
|
||||
Build one compiled regex alternation with a named group per glob. This allows running a single
re.match and using the name of the matched group to recover which glob pattern matched.
Returns (compiled_regex, name->glob map).
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"])
|
||||
>>> print(reg)
|
||||
(re.compile(r'(?P<g0>.*mlp\.(\d+)\.w1)|(?P<g1>.*mlp\.(\d+)\.w2)', re.UNICODE),
|
||||
>>> print(map_)
|
||||
{'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'})
|
||||
>>> match_ = reg.match("model.layers.0.mlp.0.w1.weight")
|
||||
>>> print(match_.lastgroup)
|
||||
'g0'
|
||||
>>> print(map_[match_.lastgroup])
|
||||
mlp.*.w1
|
||||
```
|
||||
"""
|
||||
name_map: dict[str, str] = {}
|
||||
parts: list[str] = []
|
||||
|
||||
for i, g in enumerate(globs):
|
||||
name = f"g{i}"
|
||||
name_map[name] = g
|
||||
pat_src = _glob_to_regex_src(g)
|
||||
prefix_src = ""
|
||||
if pat_src.startswith("*"):
|
||||
prefix_src = "."
|
||||
elif not pat_src.startswith(r"\^") and not pat_src.startswith(r".*"):
|
||||
prefix_src = ".*"
|
||||
|
||||
parts.append(f"(?P<{name}>{prefix_src}{pat_src}.*)")
|
||||
|
||||
alt_src = "|".join(parts).replace("\\^", "^").replace("\\.", r"\.")
|
||||
try:
|
||||
reg = re.compile(alt_src)
|
||||
except re.error as e:
|
||||
logger.error(f"Error compiling regex for alternation: {alt_src}")
|
||||
raise e
|
||||
|
||||
return reg, name_map
|
||||
|
||||
|
||||
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Match the key against the alternation; return the original glob string that matched.
|
||||
"""
|
||||
m = alt.match(key)
|
||||
if not m:
|
||||
return None
|
||||
return name_map.get(m.lastgroup)
|
||||
|
||||
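# Quick illustrative check of the two helpers above (the checkpoint key below is a
# hypothetical example, not taken from the diff):
_alt, _names = build_glob_alt(["model.layers.*.self_attn.q_proj.weight"])
print(match_glob("model.layers.7.self_attn.q_proj.weight", _alt, _names))
# -> 'model.layers.*.self_attn.q_proj.weight'
print(match_glob("lm_head.weight", _alt, _names))
# -> None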
|
||||
class ConversionOps:
|
||||
"""Base class for weight conversion operations."""
|
||||
|
||||
# The inverse operation class, will be used when saving the checkpoint
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Chunk(ConversionOps):
|
||||
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
|
||||
if chunks is None and sizes is None:
|
||||
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
|
||||
if chunks is not None and chunks <= 0:
|
||||
raise ValueError("`chunks` must be a strictly positive integer.")
|
||||
self.dim = dim
|
||||
self.chunks = chunks
|
||||
self.sizes = list(sizes) if sizes is not None else None
|
||||
self.reverse_op = Concatenate
|
||||
|
||||
def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
|
||||
# chunk requires a single tensor input
|
||||
if len(value) != 1 or len(value[0]) != 1:
|
||||
raise ValueError("Chunk operation requires a single tensor input.")
|
||||
return list(torch.chunk(value[0][0], self.chunks, dim=self.dim))
|
||||
|
||||
|
||||
class Concatenate(ConversionOps):
|
||||
"""Concatenate tensors along `dim` using a reusable buffer."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
self.dim = dim
|
||||
self.reverse_op = Chunk
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
|
||||
if isinstance(value[0], list):
|
||||
value = [v[0] for v in value]
|
||||
tensors = value
|
||||
if not tensors:
|
||||
raise ValueError("Fuse requires at least one tensor to concatenate.")
|
||||
|
||||
return torch.cat(tuple(tensors), dim=self.dim)
|
||||
|
||||
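# Minimal sketch (illustrative, not from the diff): Concatenate and Chunk are each
# other's reverse_op, so a tensor fused at load time can be split back when saving.
# Note that the loader passes values in the nested [[tensor], ...] layout used below.
import torch

_q, _k, _v = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
_fused = Concatenate(dim=0).convert([[_q], [_k], [_v]])      # shape (12, 8)
_q2, _k2, _v2 = Chunk(dim=0, chunks=3).convert([[_fused]])   # three (4, 8) tensors again
assert torch.equal(_q, _q2) and torch.equal(_v, _v2)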
|
||||
class MergeModulelist(Concatenate):
|
||||
"""
|
||||
Merge a list of tensors into a single tensor along the first dimension.
|
||||
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
super().__init__(dim=dim)
|
||||
self.reverse_op = SplitModulelist
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
|
||||
merged = []
|
||||
for group in value:
|
||||
if not isinstance(group, Sequence) or len(group) == 0:
|
||||
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
|
||||
group = [k for k in group if k.ndim]
|
||||
merged.append(torch.stack(group, dim=self.dim))
|
||||
return merged
|
||||
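# Illustrative sketch: MergeModulelist stacks each group of per-expert 2D weights into
# a single 3D tensor (here 4 hypothetical experts with (8, 16) weights -> (4, 8, 16)).
import torch

_experts = [torch.randn(8, 16) for _ in range(4)]
_merged = MergeModulelist(dim=0).convert([_experts])
print(_merged[0].shape)   # torch.Size([4, 8, 16])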
|
||||
|
||||
class SplitModulelist(ConversionOps):
|
||||
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
|
||||
|
||||
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
|
||||
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
|
||||
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
|
||||
self.sizes = [list(sub) for sub in sizes]
|
||||
self.dim = dim
|
||||
self.reverse_op = MergeModulelist
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
|
||||
if not isinstance(value, Sequence):
|
||||
raise TypeError("SplitModulelist expects a sequence of tensors.")
|
||||
if len(value) != len(self.sizes):
|
||||
raise ValueError("Number of tensors does not match the provided split specifications.")
|
||||
|
||||
result: list[list[torch.Tensor]] = []
|
||||
for tensor, split_sizes in zip(value, self.sizes):
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
|
||||
splits = torch.split(tensor, split_sizes, dim=self.dim)
|
||||
result.append(list(splits))
|
||||
return result
|
||||
|
||||
|
||||
class PermuteForRope(ConversionOps):
|
||||
"""
|
||||
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
dim1, dim2 = tensor.shape
|
||||
n_heads = getattr(self.config, "num_attention_heads", 1)
|
||||
|
||||
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
|
||||
return tensor
|
||||
|
||||
@torch.no_grad
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config
|
||||
) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]:
|
||||
self.config = config
|
||||
out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value]
|
||||
return out
|
||||
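# Hedged sketch of what the permutation does on a tiny single-head projection: rows
# stored interleaved (x0, y0, x1, y1, ...) are reordered into the split layout
# (x0, x1, ..., y0, y1, ...). The config object below is a hypothetical stand-in that
# only carries the one attribute read above.
import torch

class _TinyConfig:
    num_attention_heads = 1

_rows = torch.arange(8.0).reshape(8, 1)
_permuted = PermuteForRope().convert([_rows], _TinyConfig())[0]
print(_permuted.flatten().tolist())   # [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]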
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WeightConverter:
|
||||
r"""
|
||||
A weight converter that acts on a pattern of source keys.
The tensors to collect are grouped according to the target keys.

With wildcards, glob patterns are matched, so be precise about what should match. If you use
`model.layers.*.experts.*` -> it will act on all layers and all experts:
{"model.layers.*.experts.*": []}
but
`model.layers.1.experts.*` -> will be specific to layer 1:
{"model.layers.1.experts.*": []}
|
||||
- source_keys: str | list[str] (wildcards '*' match digits)
|
||||
- target_keys: str | list[str] | None
|
||||
- operations is ALWAYS a list; distributed_operation / quantization_operation are optional.
|
||||
|
||||
TODO: for BNB we need to collect model.weight.quant_state_keys
|
||||
"""
|
||||
|
||||
source_keys: Union[str, list[str]]
|
||||
target_keys: Optional[Union[str, list[str]]] = None
|
||||
operations: list[ConversionOps] = field(default_factory=list, repr=False)
|
||||
|
||||
distributed_operation: Optional[TensorParallelLayer] = None
|
||||
quantization_operation: Optional[ConversionOps] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.source_keys, list):
|
||||
self.source_keys = [self.source_keys]
|
||||
targets_were_none = False
|
||||
if not isinstance(self.target_keys, list):
|
||||
if self.target_keys is None:
|
||||
self.target_keys = list(self.source_keys)
|
||||
targets_were_none = True
|
||||
else:
|
||||
self.target_keys = [self.target_keys]
|
||||
|
||||
if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2:
|
||||
raise ValueError(
|
||||
f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one."
|
||||
)
|
||||
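# Hedged sketch of how a conversion mapping entry might combine the pieces above
# (the glob patterns and module names below are hypothetical, not taken from the diff):
_example_converters = [
    # one-to-one rename, no tensor op needed
    WeightConverter("*.gamma", "*.weight"),
    # many-to-one fusion: stack every expert's gate/up projection into a 3D tensor,
    # then concatenate the two stacks along the feature dimension
    WeightConverter(
        source_keys=["mlp.experts.*.gate_proj.weight", "mlp.experts.*.up_proj.weight"],
        target_keys="mlp.experts.gate_up_proj",
        operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
    ),
]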
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConversionEntry:
|
||||
weight_converter: WeightConverter
|
||||
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
|
||||
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
|
||||
|
||||
|
||||
# Factory function to create LoadedParameter subclasses dynamically
|
||||
def get_loaded_parameter_class(base_cls):
|
||||
"""
|
||||
base_cls: an nn.Parameter subclass (or nn.Parameter) or a Tensor
|
||||
Returns a new class that combines the base_cls with LoadedParameterMixin
|
||||
|
||||
"""
|
||||
|
||||
class LoadedParam(base_cls):
|
||||
_inplace_methods = [
|
||||
"add_",
|
||||
"mul_",
|
||||
"clamp_",
|
||||
"zero_",
|
||||
"fill_",
|
||||
"normal_",
|
||||
"uniform_",
|
||||
"copy_",
|
||||
"erfinv_",
|
||||
"log_",
|
||||
"__getitem__",
|
||||
"neg_",
|
||||
"exp_",
|
||||
"sub_",
|
||||
]
|
||||
|
||||
def __new__(cls, from_existing, **kwargs):
|
||||
if isinstance(from_existing, torch.nn.Parameter):
|
||||
inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
|
||||
else:
|
||||
inst = super().__new__(cls, from_existing)
|
||||
# we store the original object to get it back later on
|
||||
inst._original = from_existing
|
||||
# Explicitly override all in-place methods per instance
|
||||
for method_name in inst._inplace_methods:
|
||||
setattr(inst, method_name, MethodType(inst._skip, inst))
|
||||
|
||||
return inst
|
||||
|
||||
def _skip(self, *args, **kwargs):
|
||||
"""Helper to skip in-place operations."""
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return f"LoadedParameter(data={self.data})"
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return super().data
|
||||
|
||||
@data.setter
|
||||
def data(self, new):
|
||||
pass
|
||||
|
||||
def __lt__(self, other):
|
||||
return torch.Tensor.__lt__(self, other)
|
||||
|
||||
def __le__(self, other):
|
||||
return torch.Tensor.__le__(self, other)
|
||||
|
||||
def __gt__(self, other):
|
||||
return torch.Tensor.__gt__(self, other)
|
||||
|
||||
def __ge__(self, other):
|
||||
return torch.Tensor.__ge__(self, other)
|
||||
|
||||
def __eq__(self, other):
|
||||
return torch.Tensor.__eq__(self, other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return torch.Tensor.__ne__(self, other)
|
||||
|
||||
def __iadd__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __isub__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __imul__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __imatmul__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __itruediv__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ifloordiv__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __imod__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ipow__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __iand__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ior__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ixor__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ilshift__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __irshift__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
return LoadedParam
|
||||
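# Hedged usage sketch: wrapping an existing parameter makes later in-place writes
# (e.g. from _init_weights) no-ops, so freshly loaded weights cannot be clobbered.
import torch

_p = torch.nn.Parameter(torch.ones(2, 2))
_loaded = get_loaded_parameter_class(_p.__class__)(from_existing=_p)
_loaded.zero_()            # skipped silently, see _skip above
print(_loaded.data)        # still all ones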
|
||||
|
||||
def _materialize_copy(tensor, dtype=None):
|
||||
tensor = tensor[...]
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
|
||||
def _job():
|
||||
return _materialize_copy(tensor, dtype)
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
|
||||
def _job():
|
||||
return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def dot_natural_key(s: str):
|
||||
parts = s.split(".")
|
||||
for i, p in enumerate(parts):
|
||||
# whole-segment digits -> int; otherwise leave as str
|
||||
if p.isdigit():
|
||||
parts[i] = int(p)
|
||||
return parts
|
||||
|
||||
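# Small illustrative check: natural sort orders "layers.2" before "layers.10", which a
# plain lexicographic sort on the full key would not.
_keys = ["model.layers.10.w", "model.layers.2.w", "model.embed_tokens.weight"]
print(sorted(_keys, key=dot_natural_key))
# ['model.embed_tokens.weight', 'model.layers.2.w', 'model.layers.10.w']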
|
||||
@contextmanager
|
||||
def log_to_misc(
|
||||
layer_name: str,
|
||||
misc: MutableMapping[str, str],
|
||||
extras: Any = None,
|
||||
op: Union[list[ConversionOps], ConversionOps, None] = None,
|
||||
):
|
||||
# A simple helper to handle errors with contextual messages.
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
|
||||
def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
|
||||
if curr_op is None:
|
||||
return None
|
||||
if isinstance(curr_op, (list, tuple, set)):
|
||||
names = [o.__class__.__name__ for o in curr_op if o is not None]
|
||||
if not names:
|
||||
return None
|
||||
return ", ".join(names)
|
||||
return curr_op.__class__.__name__
|
||||
|
||||
op_name = _format_op_name(op)
|
||||
if isinstance(extras, tuple) and len(extras) == 2:
|
||||
values, target_keys = extras
|
||||
descriptor = f"{op_name} " if op_name else ""
|
||||
misc[layer_name] = (
|
||||
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
|
||||
)
|
||||
elif isinstance(extras, str):
|
||||
suffix = f" via {op_name}" if op_name else ""
|
||||
misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
|
||||
elif extras is None and op_name:
|
||||
misc[layer_name] = f"{op_name}: {e}"
|
||||
else:
|
||||
misc[layer_name] = f"{extras} |Error: {e}"
|
||||
raise SkipLayer()
|
||||
|
||||
|
||||
def set_param_for_module(
|
||||
model: PreTrainedModel,
|
||||
layer_name: str,
|
||||
param_value: torch.Tensor,
|
||||
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
|
||||
missing_keys: MutableSet[str],
|
||||
misc: MutableMapping[str, Any],
|
||||
distributed_operation: Optional[TensorParallelLayer],
|
||||
):
|
||||
with log_to_misc(layer_name, misc, layer_name):
|
||||
module_path, _, param_name = layer_name.rpartition(".")
|
||||
module_obj = model.get_submodule(module_path) if module_path else model
|
||||
param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
|
||||
ref = getattr(module_obj, param_name)
|
||||
|
||||
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
|
||||
if not isinstance(param_value, torch.nn.Parameter):
|
||||
if distributed_operation is not None:
|
||||
param_value = DTensor.from_local(
|
||||
param_value,
|
||||
distributed_operation.device_mesh,
|
||||
getattr(distributed_operation, "shard", Replicate()),
|
||||
run_check=False,
|
||||
shape=ref.size(),
|
||||
stride=ref.stride(),
|
||||
)
|
||||
if not use_dtensor:
|
||||
# we convert to local
|
||||
param_value = param_value.to_local()
|
||||
if param_name not in module_obj._buffers:
|
||||
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
|
||||
param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)
|
||||
|
||||
# Remove from missing keys (it's either mismatched, or all good)
|
||||
missing_keys.discard(layer_name)
|
||||
if ref is not None and ref.shape != param_value.shape:
|
||||
mismatch_keys.add((layer_name, param_value.shape, ref.shape))
|
||||
ref._is_hf_initialized = False  # the module's parameter still needs to be initialized
else:
param_value._is_hf_initialized = True  # super important, otherwise _init_weights re-initializes it if e.g. a bias is missing
|
||||
setattr(module_obj, param_name, param_value)
|
||||
|
||||
|
||||
class SkipLayer(Exception):
|
||||
"""Control-flow sentinel: abort processing of the current layer only."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def convert_and_load_state_dict_in_model(
|
||||
model: PreTrainedModel,
|
||||
state_dict: dict[str, Any],
|
||||
weight_mapping: dict[str, WeightConverter] | None,
|
||||
tp_plan: dict[str, str] | None,
|
||||
quantizer: HfQuantizer | None,
|
||||
dtype: torch.dtype | None = None,
|
||||
device_map: dict | None = None,
|
||||
dtype_plan: dict | None = None,
|
||||
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
|
||||
):
|
||||
"""
|
||||
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
|
||||
collecting tensors per *layer instance* (the concrete indices captured from '*').
|
||||
"""
|
||||
|
||||
prefix = model.base_model_prefix
|
||||
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
|
||||
device_map = device_map or {} # {exact_target_key: device}
|
||||
dtype_plan = dtype_plan or {} # {glob_pattern: dtype}
|
||||
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
|
||||
meta_model_state_dict = model.state_dict()
|
||||
missing_keys = set(meta_model_state_dict.keys())
|
||||
|
||||
misc = {}
|
||||
mismatch_keys = set()
|
||||
unexpected_keys = set()
|
||||
# Global thread_pool
|
||||
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
|
||||
|
||||
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
|
||||
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
|
||||
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
|
||||
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
|
||||
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
|
||||
|
||||
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
|
||||
# 1. Create the conversion entries
|
||||
by_conversion_pattern: dict[str, ConversionEntry] = {}
|
||||
for original_key, tensor in state_dict:
|
||||
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
|
||||
if matched_pattern is not None:
|
||||
converter = source_to_target[matched_pattern]  # TODO make sure it's the ref
|
||||
sub_with_extractor = partial(re.sub, matched_pattern.replace("*", r"(\d+)"), string=original_key)
|
||||
entry_key = "|".join(converter.target_keys)
|
||||
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
|
||||
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
|
||||
converter_key = sub_with_extractor(matched_pattern)
|
||||
else:
|
||||
converter = WeightConverter(original_key)
|
||||
converter_key = entry_key = target_key = original_key
|
||||
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
|
||||
|
||||
_dtype = dtype
|
||||
new_target_key = [] # test_load_with_mismatched_shapes for AutoModel.from_pretrained(AutoForCausal, vocab=10)
|
||||
for t in target_key.split("|"):
|
||||
if t.startswith(prefix) and meta_model_state_dict.get(re.sub(f"^{prefix}.", "", t, count=1)) is not None:
|
||||
t = re.sub(f"^{prefix}.", "", t, count=1)
|
||||
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
|
||||
t = f"{prefix}.{t}"
|
||||
new_target_key.append(t)
|
||||
empty_param = meta_model_state_dict.get(t)
|
||||
# If it does not exist, it's unexpected
|
||||
if empty_param is None:
|
||||
unexpected_keys.add(t)
|
||||
continue
|
||||
|
||||
if quantizer is not None and quantizer.param_needs_quantization(model, t):
|
||||
if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
|
||||
from .integrations.finegrained_fp8 import Fp8Quantize
|
||||
|
||||
converter.quantization_operation = Fp8Quantize() # TODO support other methods
|
||||
else:
|
||||
raise ValueError("This quantization method is gonna be supported SOOOON")
|
||||
else:
|
||||
_dtype = dtype
|
||||
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
|
||||
if matched_dtype_pattern is not None:
|
||||
_dtype = dtype_plan[matched_dtype_pattern]
|
||||
elif empty_param.dtype != _dtype:
|
||||
_dtype = empty_param.dtype
|
||||
|
||||
first_target_key = new_target_key[0]
|
||||
target_key = "|".join(new_target_key)
|
||||
|
||||
future = None
|
||||
if device_mesh:
|
||||
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
|
||||
empty_param = meta_model_state_dict.get(first_target_key)
|
||||
if getattr(converter, "distributed_operation", {}) is None:
|
||||
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
|
||||
converter.distributed_operation = tp_layer(
|
||||
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
|
||||
)
|
||||
# VERY IMPORTANT: this tells us whether we have already collected shards for this key or not.
|
||||
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
|
||||
future = spawn_tp_materialize(
|
||||
thread_pool,
|
||||
tensor,
|
||||
converter.distributed_operation,
shard_index,
_dtype,
|
||||
)
|
||||
|
||||
if future is None: # If not TP, async materialize the tensors. TODO handle disk offload?
|
||||
future = spawn_materialize(thread_pool, tensor, _dtype)
|
||||
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
|
||||
|
||||
# 2. Actually convert the ckpt
|
||||
inverse_converters = {}
|
||||
keys = list(by_conversion_pattern.keys())
|
||||
|
||||
with logging.tqdm(total=len(keys), desc="Loading weights") as pbar:
|
||||
for key in keys[::-1]:  # reversed so that the simple keys are processed first
|
||||
group = by_conversion_pattern.pop(key)
|
||||
converter = group.weight_converter
|
||||
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
|
||||
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
|
||||
pbar.update(1)
|
||||
pbar.set_postfix({"Materializing param": layer_name})
|
||||
pbar.refresh()
|
||||
concrete_target_keys = layer_name.split("|")
|
||||
try:
|
||||
if bool(set(concrete_target_keys) - unexpected_keys):
|
||||
with log_to_misc(layer_name, misc):
|
||||
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
|
||||
|
||||
for op in operations:
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
values = op.convert(values, model.config)
|
||||
|
||||
values = [values] if not isinstance(values, list) else values
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
realized_value = {
|
||||
k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
|
||||
}
|
||||
|
||||
for k in list(realized_value.keys()).copy():
|
||||
if op := converter.quantization_operation:
|
||||
with log_to_misc(layer_name, misc, op=op):
|
||||
realized_value.update(
|
||||
op.convert(
|
||||
{k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
|
||||
)
|
||||
)
|
||||
|
||||
for k, output_value in realized_value.items():
|
||||
for src in converter.source_keys: # what should happen to k when we meet k at saving
|
||||
inverse_converters[k] = {src: converter}
|
||||
set_param_for_module(
|
||||
model,
|
||||
k,
|
||||
output_value,
|
||||
mismatch_keys,
|
||||
missing_keys,
|
||||
misc,
|
||||
converter.distributed_operation,
|
||||
)
|
||||
|
||||
except SkipLayer:
|
||||
continue
|
||||
del group
|
||||
|
||||
model.inverse_converters = inverse_converters
|
||||
thread_pool.shutdown(wait=False)
|
||||
return missing_keys, unexpected_keys, mismatch_keys, misc
|
||||
|
||||
|
||||
# TODO this is not done yet!
|
||||
def revert_weight_conversion(model, state_dict):
|
||||
mapping = getattr(model, "_checkpoint_conversion_mapping", {}) # IDK why but setting this will fail all llava.
|
||||
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
|
||||
original_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
for pattern, inverse_converter in reverse_key_mapping:
|
||||
# TODO FIXME you name it
|
||||
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
|
||||
replacement = re.sub(r"\(.*\)", "", replacement)
|
||||
key, n_replace = re.subn(pattern, replacement, key)
|
||||
# Early exit of the loop
|
||||
if n_replace > 0:
|
||||
break
|
||||
original_state_dict[key] = value
|
||||
state_dict = original_state_dict
|
||||
return state_dict
|
||||
@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import deque
|
||||
from math import floor, gcd, sqrt
|
||||
from typing import Optional
|
||||
|
||||
@ -21,8 +20,8 @@ import torch
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...generation.configuration_utils import GenerationConfig
|
||||
from ...utils.metrics import attach_tracer, traced
|
||||
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
|
||||
from .requests import get_device_and_memory_breakdown, logger
|
||||
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
|
||||
from .requests import RequestState, get_device_and_memory_breakdown, logger
|
||||
|
||||
|
||||
def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
|
||||
@ -32,7 +31,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
|
||||
- All groups have the same number of layers
|
||||
|
||||
For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
|
||||
We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
|
||||
We would get four groups: [0, 3], [1, 2], [4,5] and [6,7].
|
||||
"""
|
||||
# If the config has no layer_type attribute, it means all layers are the same attention type
|
||||
layer_types = getattr(config, "layer_types", None)
|
||||
@ -116,7 +115,6 @@ class PagedAttentionCache:
|
||||
for the sliding-attention group, although it is not needed.
|
||||
"""
|
||||
|
||||
# TODO: this init is quite long, maybe a refactor is in order
|
||||
def __init__(
|
||||
self,
|
||||
config: PreTrainedConfig,
|
||||
@ -124,8 +122,10 @@ class PagedAttentionCache:
|
||||
device: torch.device,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
tp_size: Optional[int] = None,
|
||||
allow_prefix_sharing: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a paged attention cache for efficient memory usage.
|
||||
"""Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
|
||||
only full attention layers.
|
||||
|
||||
Args:
|
||||
config: Model configuration
|
||||
@ -133,6 +133,7 @@ class PagedAttentionCache:
|
||||
device: Device for the cache tensors
|
||||
dtype: Data type of the cache
|
||||
tp_size: Tensor parallelism size
|
||||
allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers.
|
||||
"""
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
@ -173,10 +174,12 @@ class PagedAttentionCache:
|
||||
page_size = self.head_dim * self.num_key_value_heads
|
||||
|
||||
if "flash" in self.config._attn_implementation:
|
||||
num_attention_masks = 1 # only used to compute the default meme args
|
||||
else:
|
||||
num_attention_masks = 0 # only used to compute the default memory footprint args
|
||||
elif "sliding_attention" in group_types:
|
||||
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
|
||||
num_attention_masks = 2 if "sliding_attention" in group_types else 1
|
||||
num_attention_masks = 2
|
||||
else:
|
||||
num_attention_masks = 1
|
||||
|
||||
memory_handler = PagedAttentionMemoryHandler(
|
||||
block_size=self.block_size,
|
||||
@ -218,7 +221,6 @@ class PagedAttentionCache:
|
||||
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
|
||||
|
||||
# Block management data structures
|
||||
self._free_blocks = deque(range(num_blocks))
|
||||
self.group_cache_managers: list[CacheAllocator] = []
|
||||
for i, group_type in enumerate(group_types):
|
||||
if group_type == "full_attention":
|
||||
@ -229,13 +231,19 @@ class PagedAttentionCache:
|
||||
raise ValueError(f"Invalid group type: {group_type}")
|
||||
self.group_cache_managers.append(cm)
|
||||
|
||||
# We only use prefix sharing if the whole model has only full attention layers
|
||||
self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
|
||||
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
|
||||
self.blocks_to_complete: dict[str, int] = {}
|
||||
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests
|
||||
|
||||
@traced
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
|
||||
def allocate_blocks(self, n_blocks: int, state: RequestState) -> int:
|
||||
"""Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
|
||||
managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
|
||||
max_allocated = 0
|
||||
for cm in self.group_cache_managers:
|
||||
allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
|
||||
allocated = cm.allocate_blocks(n_blocks, state.request_id, self._block_manager)
|
||||
if allocated is None:
|
||||
return None
|
||||
max_allocated = max(max_allocated, allocated)
|
||||
@ -246,11 +254,11 @@ class PagedAttentionCache:
|
||||
"""Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
|
||||
by the cache managers."""
|
||||
for cm in self.group_cache_managers:
|
||||
cm.free_blocks(request_id, self._free_blocks)
|
||||
cm.free_blocks(request_id, self._block_manager)
|
||||
|
||||
def get_num_free_blocks(self) -> int:
|
||||
"""Get the current number of unallocated blocks available for new requests."""
|
||||
return len(self._free_blocks)
|
||||
return self._block_manager.num_free_blocks
|
||||
|
||||
@traced
|
||||
def extend_read_indices(
|
||||
@ -337,6 +345,44 @@ class PagedAttentionCache:
|
||||
# Return the new KV values
|
||||
return key_states_with_cache, value_states_with_cache
|
||||
|
||||
def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
|
||||
"""Searches for a prefix match in the cache for the given (prompts_ids). If one is found, we reference the
|
||||
matching blocks in the (request_id), increase the reference count of the blocks and return the number of blocks
|
||||
that match. If no prefix match is found, we return 0."""
|
||||
current_hash = None
|
||||
allocated_blocks = []
|
||||
for b in range(len(prompt_ids) // self.block_size):
|
||||
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
|
||||
current_hash = self._block_manager.compute_hash(current_hash, tokens)
|
||||
block_id = self._block_manager._hash_to_id.get(current_hash)
|
||||
if block_id is not None:
|
||||
allocated_blocks.append(block_id)
|
||||
self._block_manager.increase_ref_count(block_id)
|
||||
else:
|
||||
break
|
||||
# If we found a matching prefix, we reference the blocks in the request
|
||||
if allocated_blocks:
|
||||
logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
|
||||
cm = self.group_cache_managers[0]
|
||||
cm.block_table[request_id] = allocated_blocks
|
||||
|
||||
prefix_length = len(allocated_blocks) * self.block_size
|
||||
self._total_prefix_length += prefix_length
|
||||
return prefix_length
|
||||
|
||||
def mark_blocks_as_complete(self, state: RequestState) -> None:
|
||||
"""Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is
|
||||
a no-op."""
|
||||
num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
|
||||
if num_complete_blocks == 0:
|
||||
return None
|
||||
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
|
||||
self._block_manager.mark_blocks_as_complete(
|
||||
num_complete_blocks=num_complete_blocks,
|
||||
allocated_blocks=cm.block_table[state.request_id],
|
||||
prompt_ids=(state.full_prompt_ids + state.static_outputs),
|
||||
)
|
||||
|
||||
|
||||
# TODO: rework computation with the groups and their sizes
|
||||
class PagedAttentionMemoryHandler:
|
||||
@ -471,6 +517,8 @@ class PagedAttentionMemoryHandler:
|
||||
2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
|
||||
m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
|
||||
])
|
||||
|
||||
If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
|
||||
"""
|
||||
cache_memory = self.get_available_memory(max_memory_percent)
|
||||
logger.info(f"Cache memory: {cache_memory}")
|
||||
@ -482,11 +530,16 @@ class PagedAttentionMemoryHandler:
|
||||
c = -cache_memory
|
||||
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
|
||||
|
||||
# Compute discriminant and greatest solution
|
||||
discriminant = b**2 - 4 * a * c
|
||||
if discriminant < 0:
|
||||
raise ValueError(f"Discriminant is negative: {discriminant = }")
|
||||
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
|
||||
# If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
|
||||
if self.num_attention_masks == 0:
|
||||
greatest_solution = -c / b
|
||||
# Otherwise, we solve the quadratic equation
|
||||
else:
|
||||
discriminant = b**2 - 4 * a * c
|
||||
if discriminant < 0:
|
||||
raise ValueError(f"Discriminant is negative: {discriminant = }")
|
||||
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
|
||||
|
||||
if greatest_solution < 0:
|
||||
raise ValueError(f"Greatest solution is negative: {greatest_solution = }")
|
||||
|
||||
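# Standalone sketch of the root-finding logic above (the coefficient names mirror the
# docstring; the helper name and numeric values are made up for illustration):
from math import sqrt

def _max_num_tokens(a: float, b: float, c: float) -> float:
    """Largest positive root of a*N^2 + b*N + c = 0, falling back to -c/b when a == 0."""
    if a == 0:
        return -c / b
    discriminant = b**2 - 4 * a * c
    if discriminant < 0:
        raise ValueError(f"Discriminant is negative: {discriminant = }")
    return (-b + sqrt(discriminant)) / (2 * a)

print(_max_num_tokens(0.0, 2.0, -10.0))   # 5.0 (linear case, num_attention_masks == 0)
print(_max_num_tokens(1.0, 2.0, -8.0))    # 2.0 (quadratic case)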
|
||||
@ -14,29 +14,211 @@
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Iterator
|
||||
from math import ceil
|
||||
from typing import Optional
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from .requests import logger
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
|
||||
index = len(xs) - 1
|
||||
for x in xs[::-1]:
|
||||
yield index, x
|
||||
index -= 1
|
||||
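# Tiny illustrative check: reverse_enumerate walks the list from the end while keeping
# the original indices.
print(list(reverse_enumerate(["a", "b", "c"])))   # [(2, 'c'), (1, 'b'), (0, 'a')]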
|
||||
|
||||
class Block:
|
||||
"""A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
|
||||
cache it points to is fully computed. A block can have a parent, which is the block that came before in the
|
||||
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block and
|
||||
its parent's hash (if there is a parent)."""
|
||||
|
||||
def __init__(self, id_: int, parent_id: int | None) -> None:
|
||||
self.id: int = id_
|
||||
self.parent_id: int | None = parent_id
|
||||
self.hash: int | None = None
|
||||
self.ref_count: int = 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
return self.hash is not None
|
||||
|
||||
|
||||
class BlockManager:
|
||||
"""A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
|
||||
simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
|
||||
- in use: one or more requests references this block, thus it cannot be written over. The number of requests
|
||||
referencing this block is stored as ref_count in the Block object.
|
||||
- un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
|
||||
be given as free blocks to new requests without any overhead.
|
||||
- initialized: the block is complete and was used by one or more request that are finished. It contains KV cache
|
||||
data and its hash is stored in the hash table. If a new request needs a block with the same hash, we increase
|
||||
the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
|
||||
Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
|
||||
hash table.
|
||||
There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
|
||||
it is in use.
|
||||
"""
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
|
||||
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
|
||||
can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
|
||||
layers."""
|
||||
self.num_blocks = num_blocks
|
||||
self.block_size = block_size
|
||||
self._uninit_block_ids = deque(range(num_blocks))
|
||||
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
|
||||
self._use_prefix_sharing = use_prefix_sharing
|
||||
self._hash_to_id: dict[int, int] = {}
|
||||
self._id_to_block: dict[int, Block] = {}
|
||||
|
||||
@property
|
||||
def num_free_blocks(self) -> int:
|
||||
"""Returns the number of free blocks left. Both initialized and uninitialized blocks are considered free."""
|
||||
return len(self._uninit_block_ids) + len(self._init_block_ids)
|
||||
|
||||
def is_enough_free_blocks(self, n_blocks: int) -> bool:
|
||||
"""Checks if there are enough free blocks to allocate the requested number of blocks (n_blocks). If there are
|
||||
not enough uninitialized blocks, we uninitialize the required number of initialized blocks."""
|
||||
# Exit early if there are enough uninitialized blocks
|
||||
if len(self._uninit_block_ids) >= n_blocks:
|
||||
return True
|
||||
# Exit early if even after uninitializing all initialized blocks, there are not enough free blocks
|
||||
blocks_to_uninitialize = n_blocks - len(self._uninit_block_ids)
if len(self._init_block_ids) < blocks_to_uninitialize:
return False
# Uninitialize the required number of blocks
for _ in range(blocks_to_uninitialize):
id_to_uninitialize = self._init_block_ids.popitem()[0]
block = self._id_to_block[id_to_uninitialize]
self._hash_to_id.pop(block.hash)
self._uninit_block_ids.append(id_to_uninitialize)
|
||||
return True
|
||||
|
||||
def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
|
||||
"""Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. One
|
||||
can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
|
||||
the parent block. If the manager cannot find enough free blocks, it returns None."""
|
||||
if not self.is_enough_free_blocks(n_blocks):
|
||||
return None
|
||||
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
|
||||
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
|
||||
if self._use_prefix_sharing:
|
||||
for block_id in allocated_block_ids:
|
||||
block = Block(block_id, last_block_id)
|
||||
self._id_to_block[block_id] = block
|
||||
last_block_id = block_id
|
||||
# In both cases, we return the allocated block ids
|
||||
return allocated_block_ids
|
||||
|
||||
def increase_ref_count(self, block_id: int) -> None:
|
||||
"""Increases the reference count of a given (block_id)."""
|
||||
block = self._id_to_block[block_id]
|
||||
block.ref_count += 1
|
||||
if block.ref_count == 1:
|
||||
self._init_block_ids.pop(block_id)
|
||||
|
||||
def decrease_ref_count(self, block_id: int) -> None:
|
||||
"""Decreases the reference count of a given (block_id). If the reference count reaches 0, the block is no longer
|
||||
in use, and becomes initialized (if it was complete) or uninitialized (if it was incomplete)."""
|
||||
block = self._id_to_block[block_id]
|
||||
block.ref_count -= 1
|
||||
if block.ref_count == 0:
|
||||
if block.is_complete:
|
||||
self._init_block_ids[block_id] = None
|
||||
else:
|
||||
self._id_to_block.pop(block_id)
|
||||
self._uninit_block_ids.append(block_id)
|
||||
|
||||
def free_blocks(self, blocks: list[int]) -> None:
|
||||
"""Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized
|
||||
blocks queue. Otherwise, their new state depends on whether they are complete."""
|
||||
if self._use_prefix_sharing:
|
||||
for block_id in blocks:
|
||||
self.decrease_ref_count(block_id)
|
||||
else:
|
||||
self._uninit_block_ids.extend(blocks)
|
||||
|
||||
def mark_blocks_as_complete(
|
||||
self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int]
|
||||
) -> None:
|
||||
"""Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
|
||||
of (prompt_ids) is used to compute the hash of the new block."""
|
||||
# Look for the first complete block, starting from the last block in the sequence
|
||||
parent_hash = None
|
||||
incomplete_blocks: list[Block] = []
|
||||
for i, block_id in reverse_enumerate(allocated_blocks):
|
||||
block = self._id_to_block[block_id]
|
||||
if block.is_complete:
|
||||
parent_hash = block.hash
|
||||
break
|
||||
incomplete_blocks.append((i, block))
|
||||
|
||||
# Now go through the incomplete blocks and update them
|
||||
new_parent_id = None
|
||||
while incomplete_blocks:
|
||||
i, block = incomplete_blocks.pop()
|
||||
|
||||
# If the parent id has been updated, we apply the change
|
||||
if new_parent_id is not None:
|
||||
block.parent_id = new_parent_id
|
||||
new_parent_id = None
|
||||
|
||||
# If we have set the hash for all complete blocks, we can stop
|
||||
if num_complete_blocks == 0:
|
||||
break
|
||||
|
||||
# Otherwise, we compute the hash
|
||||
num_complete_blocks -= 1
|
||||
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
|
||||
block.hash = self.compute_hash(parent_hash, tokens)
|
||||
|
||||
existing_block_id = self._hash_to_id.get(block.hash)
|
||||
# If the block hash is already in the hash to id mapping, we reference the existing block instead
|
||||
if existing_block_id is not None:
|
||||
logger.debug(f"Found existing block {existing_block_id} for block {block.id}")
|
||||
allocated_blocks[i] = existing_block_id
|
||||
self._id_to_block[existing_block_id].ref_count += 1
|
||||
new_parent_id = existing_block_id
|
||||
self.free_blocks([block.id])
|
||||
|
||||
# Otherwise, we add the completed block to the hash table
|
||||
else:
|
||||
self._hash_to_id[block.hash] = block.id
|
||||
|
||||
# Update loop variables
|
||||
parent_hash = block.hash
|
||||
|
||||
def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
|
||||
"""Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
|
||||
parent, the parent hash is None."""
|
||||
return hash((parent_hash, tuple(tokens)))
|
||||
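# Illustrative sketch (not part of the original file) of why the chained hashes enable
# prefix sharing: two prompts that share their leading blocks produce identical hash
# chains, so they can point at the same physical KV blocks. The block size and token
# ids below are made up.
_manager = BlockManager(num_blocks=8, block_size=4, use_prefix_sharing=True)
_prompt_a = [1, 2, 3, 4, 5, 6, 7, 8]
_prompt_b = [1, 2, 3, 4, 9, 9, 9, 9]   # same first block, different second block
_h_a0 = _manager.compute_hash(None, _prompt_a[0:4])
_h_b0 = _manager.compute_hash(None, _prompt_b[0:4])
assert _h_a0 == _h_b0                  # first blocks hash identically -> shareable
print(_manager.compute_hash(_h_a0, _prompt_a[4:8]) == _manager.compute_hash(_h_b0, _prompt_b[4:8]))
# False (almost surely): the chains diverge once the prompts differ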
|
||||
|
||||
class CacheAllocator(ABC):
|
||||
"""Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
|
||||
when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""
|
||||
|
||||
_index: int
|
||||
_block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
|
||||
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
|
||||
|
||||
@abstractmethod
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
"""Allocates n_blocks for a given request_id. Returns the num of blocks allocated if successful and None
|
||||
otherwise."""
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocates (n_blocks) for a given (request_id) using the (block_manager). Returns the num of blocks allocated
|
||||
if successful and None otherwise."""
|
||||
|
||||
def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
|
||||
"""Frees all blocks associated with a request_id."""
|
||||
if request_id in self._block_table:
|
||||
blocks_to_free = self._block_table.pop(request_id)
|
||||
free_blocks.extend(blocks_to_free)
|
||||
def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
|
||||
"""Frees all blocks associated with a (request_id) using the (block_manager)."""
|
||||
if request_id in self.block_table:
|
||||
blocks_to_free = self.block_table.pop(request_id)
|
||||
block_manager.free_blocks(blocks_to_free)
|
||||
else:
|
||||
logger.warning(
|
||||
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
|
||||
@ -66,23 +248,30 @@ class FullAttentionCacheAllocator(CacheAllocator):
|
||||
"""
|
||||
self._index = index
|
||||
self.block_size = block_size
|
||||
self._block_table = {}
|
||||
self.block_table = {}
|
||||
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
|
||||
otherwise. For group of full attention layers, we always allocate the number of requested blocks."""
|
||||
if len(free_blocks) < n_blocks:
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
|
||||
allocated if successful and None otherwise. For a group of full attention layers, we always allocate the number of
|
||||
requested blocks."""
|
||||
# Make sure the request_id is in the block table and get the first block id
|
||||
if request_id not in self.block_table:
|
||||
self.block_table[request_id] = [] # TODO: check the impact of making this a deque
|
||||
last_block_id = None
|
||||
else:
|
||||
last_block_id = self.block_table[request_id][-1]
|
||||
# Actual allocation, return early if failed
|
||||
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
|
||||
if allocated_blocks is None:
|
||||
return None
|
||||
if request_id not in self._block_table:
|
||||
self._block_table[request_id] = []
|
||||
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
|
||||
self.block_table[request_id].extend(allocated_blocks)
|
||||
return n_blocks
|
||||
|
||||
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
"""Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
|
||||
first write the new cache to the cache tensor and then read the entire cache from the beginning to the end."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Compute the physical indices
|
||||
@ -97,7 +286,7 @@ class FullAttentionCacheAllocator(CacheAllocator):
|
||||
def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
"""Returns the physical indices for writing to the cache. For a group of full attention layers, we write the new
|
||||
cache as a continuation of the existing cache for the same request."""
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Compute the physical indices
|
||||
@ -129,25 +318,26 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
self.block_size = block_size
|
||||
self.sliding_window = sliding_window
|
||||
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
|
||||
self._block_table = {}
|
||||
self.block_table = {}
|
||||
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
|
||||
otherwise. For group of sliding window attention layers, we only allocate up to the point where we can fit an
|
||||
entire sliding window in the cache tensor."""
|
||||
if request_id not in self._block_table:
|
||||
self._block_table[request_id] = []
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
|
||||
allocated otherwise. For group of sliding window attention layers, we only allocate up to the point where we can
|
||||
fit an entire sliding window in the cache tensor."""
|
||||
if request_id not in self.block_table:
|
||||
self.block_table[request_id] = []
|
||||
# Early return if we are already at the max number of blocks per request
|
||||
already_allocated = len(self._block_table[request_id])
|
||||
already_allocated = len(self.block_table[request_id])
|
||||
if already_allocated == self._max_blocks_per_request:
|
||||
return 0
|
||||
# Compute actual number of blocks to allocate
|
||||
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
|
||||
actual_n_blocks = after_allocation - already_allocated
|
||||
# Classic allocation
|
||||
if len(free_blocks) < actual_n_blocks:
|
||||
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
|
||||
if allocated_blocks is None:
|
||||
return None
|
||||
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
|
||||
self.block_table[request_id].extend(allocated_blocks)
|
||||
return actual_n_blocks
|
||||
|
||||
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
@ -157,7 +347,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
sliding_window - 1 cache page and then manually add the new key / values states after. Hence the -1 indices
|
||||
which indicate where to store the new key or values indices."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Apply sliding window
|
||||
@ -178,7 +368,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
sliding window attention layers, we write the new cache in rolling-buffer kind of way: if we reach the end of
|
||||
the allocated physical cache, we start writing from the beginning of the physical cache again."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Apply sliding window
|
||||
@ -201,22 +391,3 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
|
||||
seqlens_k = query_length + min(past_length, self.sliding_window - 1)
|
||||
return "sliding_attention", seqlens_k
|
||||
|
||||
|
||||
# TODO: test the impact of this
|
||||
# def get_read_indices(self, request_id: str, past_length: int) -> list[int]:
|
||||
# # Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
# block_table = self._block_table.get(request_id)
|
||||
# if block_table is None:
|
||||
# raise ValueError(f"No block table found for request {request_id}")
|
||||
# # Compute the physical indices
|
||||
# physical_indices = []
|
||||
# n_left = past_length
|
||||
# for block_idx in block_table:
|
||||
# block_physical_index = block_idx * self.block_size
|
||||
# pages_used = min(self.block_size, n_left)
|
||||
# physical_indices.extend(block_physical_index + i for i in range(pages_used))
|
||||
# n_left -= pages_used
|
||||
# if n_left == 0:
|
||||
# return physical_indices
|
||||
# raise ValueError(f"Request {request_id} required too many indices: {past_length = } and {len(block_table) = }")
|
||||
|
||||
@@ -16,12 +16,13 @@
import queue
import threading
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from itertools import count
from math import ceil
from time import perf_counter
from typing import Optional, Union
from typing import Optional

import torch
from torch import nn
@@ -446,10 +447,7 @@ class ContinuousBatchProcessor:
cumulative_seqlens_q = [0]
logits_indices = []

if isinstance(self.cumulative_seqlens_k, dict):
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
else:
cumulative_seqlens_k = [0]
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}

read_index = [[] for _ in range(self.cache.num_groups)]
write_index = [[] for _ in range(self.cache.num_groups)]
@@ -498,10 +496,7 @@ class ContinuousBatchProcessor:
self.metrics.record_kv_cache_memory_metrics(self.cache)

if logger.isEnabledFor(logging.DEBUG):
if isinstance(self.cumulative_seqlens_k, dict):
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
else:
ck = cumulative_seqlens_k[-1]
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
logger.debug(
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
@@ -517,7 +512,7 @@ class ContinuousBatchProcessor:
read_index: list[list[int]],
write_index: list[list[int]],
cumulative_seqlens_q: list[int],
cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
cumulative_seqlens_k: dict[str, list[int]],
logits_indices: list[int],
) -> None:
"""Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
@@ -561,9 +556,7 @@ class ContinuousBatchProcessor:
@traced
def _maybe_send_output(self, state: RequestState) -> None:
"""Send output to the queue based on streaming mode and request state."""
if state.streaming:
self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED:
if state.streaming or state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output())

@traced
@@ -571,17 +564,27 @@ class ContinuousBatchProcessor:
"""Update request states based on generated tokens."""
out_tokens = self._sync()
for i, state in enumerate(self.requests_in_batch):
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
if len(state.remaining_prompt_ids) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
state.prompt_ids = [token]
if state.update_with_token(token):
# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
# We mark the completed blocks as such
self.cache.mark_blocks_as_complete(state)
if is_finished:
self.metrics.record_request_completion(state.created_time, state.request_id)
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
self._maybe_send_output(state)
# Otherwise, the request is still prefilling, but the prefill has been split
elif state.status == RequestStatus.PREFILLING_SPLIT:
self.cache.mark_blocks_as_complete(state)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
else:
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")

if self.cache.get_num_free_blocks() == 0:
raise ValueError("No more free blocks")

@@ -726,6 +729,7 @@ class ContinuousBatchingManager:
max_queue_size: int = 0,
num_q_cuda_graphs: int = 0,
num_kv_cuda_graphs: int = 0,
allow_prefix_sharing: bool = True,
) -> None:
"""Initialize the continuous batching manager.

@@ -735,6 +739,7 @@ class ContinuousBatchingManager:
max_queue_size: Maximum size of the request queue (0 = unlimited)
num_q_cuda_graphs: (optional) Number of CUDA graphs to use for the query dimension
num_kv_cuda_graphs: (optional) Number of CUDA graphs to use for the keys/values dimension
allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
"""
if "paged|" not in model.config._attn_implementation:
attn_implementation = f"paged|{model.config._attn_implementation}"
@@ -767,6 +772,8 @@ class ContinuousBatchingManager:
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None

self._allow_prefix_sharing = allow_prefix_sharing

# If a number of cuda graphs was specified for either Q or KV, we activate cuda graphs
if num_q_cuda_graphs > 0 or num_kv_cuda_graphs > 0:
self.use_cuda_graph = True
@@ -799,7 +806,6 @@ class ContinuousBatchingManager:
logger.warning("Manager thread is already running.")
return

self._result_queue = queue.Queue()
self._generation_thread = threading.Thread(target=self._run_generation_loop)
self._generation_thread.start()

@@ -814,6 +820,16 @@ class ContinuousBatchingManager:
block: Whether to wait for the thread to stop
timeout: Maximum time to wait for the thread to stop
"""
if self.batch_processor is None:
logger.warning("\nBatch processor was not initialized.")
else:
if self.batch_processor.cache.use_prefix_sharing:
logger.warning(
f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
)
else:
logger.warning("\nPrefix sharing was off.")

if self._generation_thread is None:
logger.warning("Manager not started.")
return
@@ -939,20 +955,6 @@ class ContinuousBatchingManager:
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)

@traced
def warmup(self, batch_processor: ContinuousBatchProcessor) -> None:
stream = torch.cuda.Stream(device=self.model.device)
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
# Warmup the model with a dummy forward pass
self._generation_step(batch_processor)
torch.cuda.current_stream().wait_stream(stream)

self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, stream=stream):
self._generation_step(batch_processor)

@traced
# @torch.compile
def _generation_step(self) -> None:
"""Perform a single generation step. This is cuda graphed"""
self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)
@@ -968,6 +970,7 @@ class ContinuousBatchingManager:
self.model.device,
self.model.dtype,
tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
allow_prefix_sharing=self._allow_prefix_sharing,
)
logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")

@@ -1059,6 +1062,15 @@ class ContinuousBatchingManager:
class ContinuousMixin:
"""Mixin class for models to add continuous batching capabilities."""

@contextmanager
def continuous_batching_context_manager(self, **kwargs) -> Generator[ContinuousBatchingManager]:
manager = self.init_continuous_batching(**kwargs)
manager.start()
try:
yield manager
finally:
manager.stop(block=True)

def init_continuous_batching(
self,
generation_config: Optional[GenerationConfig] = None,
@@ -1066,6 +1078,7 @@ class ContinuousMixin:
max_queue_size: int = 0,
num_q_cuda_graphs: int = 0,
num_kv_cuda_graphs: int = 0,
allow_prefix_sharing: bool = True,
) -> ContinuousBatchingManager:
"""Initialize a manager for continuous batching inference.

@@ -1098,6 +1111,7 @@ class ContinuousMixin:
max_queue_size=max_queue_size,
num_q_cuda_graphs=num_q_cuda_graphs,
num_kv_cuda_graphs=num_kv_cuda_graphs,
allow_prefix_sharing=allow_prefix_sharing,
)

# TODO: support streaming
@@ -1169,5 +1183,6 @@ class ContinuousMixin:
except Exception as e:
logger.error(f"Error during batch generation: {e}", exc_info=True)
finally:
logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show.
manager.stop(block=True, timeout=5.0)
return results

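Not part of the changeset above — a minimal, hedged sketch of how the context manager and the new `allow_prefix_sharing` flag introduced in this diff might be exercised. The checkpoint name is a placeholder and the request-submission API is deliberately elided, since only `init_continuous_batching`, `start`, `stop`, and `continuous_batching_context_manager` are visible in the hunks above.

```python
# Illustrative sketch only, assuming a causal LM that inherits ContinuousMixin and a CUDA device.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",  # hypothetical checkpoint
    dtype=torch.bfloat16,
    device_map="cuda",
)

# Equivalent to init_continuous_batching(...) + manager.start(), with manager.stop(block=True) on exit.
with model.continuous_batching_context_manager(
    max_queue_size=0,
    num_q_cuda_graphs=0,        # CUDA graphs stay off in this sketch
    num_kv_cuda_graphs=0,
    allow_prefix_sharing=True,  # new flag added by this changeset
) as manager:
    # submit requests / consume outputs via the manager here
    pass
```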
@@ -116,10 +116,10 @@ class RequestState:
error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
"""

# Required fields
# Required fields # TODO: come up with better names / not sure prompt_ids and such are not redundant
request_id: str
full_prompt_ids: Optional[list[int]] = None # Full initial prompt
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated)
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed
remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
static_outputs: list[int] = field(default_factory=list) # Generated tokens
allocated_blocks: int = 0 # Number of blocks allocated to the request
@@ -164,7 +164,7 @@ class RequestState:

# TODO: this logic seems one token off, check it out
@traced
def update_with_token(self, token_id: int) -> bool:
def update_and_check_completion(self, token_id: int) -> bool:
"""Update the request with a newly generated token and check for completion.

Args:

@@ -104,7 +104,7 @@ class Scheduler(ABC):
)

@traced
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
def _allocate_blocks_if_needed(self, state: RequestState) -> bool:
"""Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
accommodate the next tokens. It calculates how many blocks are needed based on the request's current
cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
@@ -113,10 +113,11 @@ class Scheduler(ABC):
# 1. we check that the occupancy is less than the requested length
# 2. we allocate enough blocks to cover the requested length
current_len = state.current_len()
len_next_tokens = len(state.prompt_ids)
occupancy = state.allocated_blocks * self.cache.block_size - current_len
if occupancy < len_next_tokens or state.allocated_blocks == 0:
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
allocated = self.cache.allocate_blocks(blocks_needed, state)
if allocated is None:
return False
state.allocated_blocks += allocated
@@ -125,11 +126,29 @@ class Scheduler(ABC):
@traced(span_name="prepare_request")
def _prepare_request_for_processing(
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
):
"""Prepares a request for processing in the current batch."""
request_tokens = (
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
)
) -> None:
"""Prepares a request for processing in the current batch. If prefix sharing is enabled, and the request was
pending, this is where we look for a prefix match and split the request if found."""
# If prefix sharing is enabled, we look for a prefix match and split the request if found
if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
prefill_length = self.cache.search_prefix_match(state.request_id, state.prompt_ids)
if prefill_length > 0:
self.active_requests[state.request_id] = state
request_ids_to_remove_from_waiting.add(state.request_id)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
# Even if we match the whole request, we keep at least 1 token to start decoding
prefill_length = min(prefill_length, len(state.prompt_ids) - 1)
state.remaining_prompt_ids = state.prompt_ids[prefill_length:]
state.prompt_ids = state.prompt_ids[prefill_length:]
state.position_offset += prefill_length

# If the request has a split prefill, the tokens to process are the remaining prompt ids
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
request_tokens = state.remaining_prompt_ids
# Otherwise, the tokens to process are the prompt ids, which are the full prompt or the last predicted tokens
else:
request_tokens = state.prompt_ids

if len(request_tokens) < token_budget:
# Can process the entire prompt/remainder
if state.status == RequestStatus.PENDING:
@@ -152,6 +171,7 @@ class Scheduler(ABC):
state.prompt_ids = request_tokens[:token_budget]


# TODO: further common-ize the two classes
@attach_tracer()
class FIFOScheduler(Scheduler):
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
@@ -195,30 +215,31 @@ class FIFOScheduler(Scheduler):

self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
break
continue

@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)

_add_to_scheduled_requests(state)
# Add the request to the scheduled requests
scheduled_requests.append(state)

# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
complete_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = complete_blocks

@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)

_remove_from_waiting_requests(state)
# Remove the request from the waiting queue and mark it as removed
req_id = state.request_id
was_waiting = self.waiting_requests.pop(req_id, None) is not None
if was_waiting:
request_ids_to_remove_from_waiting.add(req_id)

# Early exit of the loop if we have no token budget left
if token_budget == 0:
break

@@ -249,6 +270,7 @@ class PrefillFirstScheduler(Scheduler):
elif state.status == RequestStatus.DECODING:
second_priority_states.append(state)

# Add waiting requests to second priority
for req_id in self.waiting_requests_order:
second_priority_states.append(self.waiting_requests[req_id])

@@ -259,30 +281,31 @@ class PrefillFirstScheduler(Scheduler):
for state in candidates:
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
break
continue

@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)

_add_to_scheduled_requests(state)
# Add the request to the scheduled requests
scheduled_requests.append(state)

# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
complete_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = complete_blocks

@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)

_remove_from_waiting_requests(state)
# Remove the request from the waiting queue and mark it as removed
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)

# Early exit of the loop if we have no token budget left
if token_budget == 0:
break


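A standalone sketch (with assumed values) of the block-allocation arithmetic used by `_allocate_blocks_if_needed` above: extra blocks are requested only when the spare capacity of the already-allocated blocks cannot hold the tokens scheduled for this step.

```python
# Assumed toy values; block_size and the request state come from the cache/scheduler in the real code.
block_size = 32
allocated_blocks = 2
current_len = 50          # tokens already written into the cache for this request
len_next_tokens = 40      # tokens scheduled for this forward pass

occupancy = allocated_blocks * block_size - current_len   # 64 - 50 = 14 free slots
if occupancy < len_next_tokens or allocated_blocks == 0:
    blocks_needed = ((len_next_tokens - occupancy + 1) // block_size) + 1
    print(blocks_needed)  # ((40 - 14 + 1) // 32) + 1 == 1 extra block, giving 46 >= 40 free slots
```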
@@ -411,7 +411,7 @@ class GenerationMixin(ContinuousMixin):
"Generation config file not found, using a generation config created from the model config."
)
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(self, "load_custom_generate") and trust_remote_code:
if hasattr(self, "load_custom_generate"):
try:
custom_generate = self.load_custom_generate(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
@@ -1635,12 +1635,7 @@ class GenerationMixin(ContinuousMixin):

# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
for key, value in model_kwargs.items():
if (
value is not None
and key not in model_args
and key not in TransformersKwargs.__optional_keys__
and key != "debug_io"
):
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
unused_model_args.append(key)

if unused_model_args:

@@ -383,11 +383,10 @@ class BayesianDetectorModel(PreTrainedModel):
)
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Parameter):
module.weight.normal_(mean=0.0, std=0.02)
module.weight.data.normal_(mean=0.0, std=0.02)

def _compute_posterior(
self,

@@ -512,8 +512,10 @@ def accelerate_disk_offload(
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
):
disk_only_shard_files = []
if disk_offload_folder is not None:
@@ -532,13 +534,19 @@ def accelerate_disk_offload(
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
# Fix the weight map keys according to the key mapping
weight_map = {
key_renaming_mapping[k]: v
for k, v in sharded_metadata["weight_map"].items()
if k in key_renaming_mapping
}
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": name,
"weight_name": reverse_key_renaming_mapping[name],
"dtype": str_dtype,
}
for name, file in weight_map.items()

@@ -1,4 +1,5 @@
import inspect
from copy import deepcopy
from inspect import signature

from ..utils import (
@@ -23,6 +24,7 @@ if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import find_tied_parameters

logger = logging.get_logger(__name__)

@@ -149,6 +151,52 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
return model


def get_keys_to_not_convert(model):
r"""
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
int8.

Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model and tie the weights, then
# check if it contains tied weights
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model.tie_weights()

tied_params = find_tied_parameters(tied_model)
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0

# If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision
if not has_tied_params:
output_emb = model.get_output_embeddings()
if output_emb is not None:
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
return list_last_module

# otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
list_modules = list(model.named_parameters())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)

# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)

return filtered_module_names


# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
"""

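Not part of the diff — a hedged sketch of how the `get_keys_to_not_convert` helper reintroduced above is typically consulted when deciding which modules to keep out of int8 conversion. The model id is a placeholder and the import path is an assumption.

```python
# Illustrative usage sketch; the import location is assumed, not confirmed by the diff.
from transformers import AutoModelForCausalLM
from transformers.integrations.bitsandbytes import get_keys_to_not_convert

model = AutoModelForCausalLM.from_pretrained("some-org/some-causal-lm")  # hypothetical checkpoint
modules_to_not_convert = get_keys_to_not_convert(model)
# Expect something like ["lm_head"] for models without tied embeddings,
# or the tied-weight module names otherwise.
print(modules_to_not_convert)
```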
@ -13,11 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from ..core_model_loading import ConversionOps
|
||||
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
|
||||
|
||||
|
||||
@ -33,18 +30,6 @@ if is_accelerate_available():
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
try:
|
||||
_FP8_DTYPE = torch.float8_e4m3fn
|
||||
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
|
||||
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
|
||||
_FP8_IS_INT = False
|
||||
except AttributeError:
|
||||
_FP8_DTYPE = torch.int8
|
||||
_FP8_MIN, _FP8_MAX = -127, 127
|
||||
_FP8_IS_INT = True
|
||||
logger.warning_once(
|
||||
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
|
||||
)
|
||||
|
||||
|
||||
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
|
||||
@ -347,12 +332,6 @@ class FP8Linear(nn.Linear):
|
||||
if self.weight.element_size() > 1:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
else:
|
||||
if isinstance(self.weight, torch.distributed.tensor.DTensor):
|
||||
weight = self.weight._local_tensor.contiguous()
|
||||
scale_inv = self.weight_scale_inv._local_tensor.contiguous()
|
||||
else:
|
||||
weight = self.weight.contiguous()
|
||||
scale_inv = self.weight_scale_inv.contiguous()
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
@ -360,9 +339,9 @@ class FP8Linear(nn.Linear):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
weight,
|
||||
self.weight,
|
||||
scale,
|
||||
scale_inv,
|
||||
self.weight_scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
@ -371,124 +350,9 @@ class FP8Linear(nn.Linear):
|
||||
torch_accelerator_module.synchronize()
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
output = torch.nan_to_num(output, nan=0.0)
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
def _ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
class FP8Expert(nn.Module):
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
def __init__(self, config, block_size, device):
|
||||
super().__init__()
|
||||
|
||||
from ..activations import ACT2FN
|
||||
|
||||
self.block_size = block_size
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.intermediate_dim = config.intermediate_size
|
||||
|
||||
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
|
||||
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
|
||||
|
||||
self.gate_up_proj = nn.Parameter(
|
||||
torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
|
||||
)
|
||||
self.down_proj = nn.Parameter(
|
||||
torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
|
||||
)
|
||||
|
||||
# Create inverse scale tiles only when using 1-byte types (fp8)
|
||||
if self.gate_up_proj.element_size() == 1:
|
||||
bo, bi = self.block_size
|
||||
|
||||
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
|
||||
gu_scale_o = _ceil_div(Wg_out, bo)
|
||||
gu_scale_i = _ceil_div(Wg_in, bi)
|
||||
self.gate_up_proj_scales_inv = nn.Parameter(
|
||||
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
|
||||
)
|
||||
|
||||
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
|
||||
dp_scale_o = _ceil_div(Wd_out, bo)
|
||||
dp_scale_i = _ceil_div(Wd_in, bi)
|
||||
self.down_proj_scales_inv = nn.Parameter(
|
||||
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
|
||||
)
|
||||
else:
|
||||
# Match FP8Linear behavior when not using 1-byte weights
|
||||
self.register_parameter("gate_up_proj_scale_inv", None)
|
||||
self.register_parameter("down_proj_scale_inv", None)
|
||||
|
||||
# (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
|
||||
self.register_parameter("gate_up_bias", None)
|
||||
self.register_parameter("down_bias", None)
|
||||
|
||||
# Activation used in the MLP (same as your config / ACT2FN)
|
||||
# Keep a handle here; actual usage happens in forward of your MoE block
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
top_k_index: torch.Tensor,
|
||||
top_k_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
num_experts = top_k_weights.shape[1]
|
||||
with torch.no_grad():
|
||||
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
|
||||
expert_mask = expert_mask.permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
for expert_idx in expert_hit:
|
||||
expert_idx = expert_idx[0]
|
||||
if expert_idx == num_experts:
|
||||
continue
|
||||
_, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current_state = hidden_states.index_select(0, token_idx)
|
||||
gate, up = self.linear(
|
||||
current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx]
|
||||
).chunk(2, dim=-1)
|
||||
current_hidden_states = self.act_fn(gate) * up
|
||||
current_hidden_states = self.linear(
|
||||
current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx]
|
||||
)
|
||||
|
||||
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
|
||||
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
|
||||
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
|
||||
if weight.element_size() > 1:
|
||||
return F.linear(input, weight, None)
|
||||
else:
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
with torch_accelerator_module.device(input.device):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
weight,
|
||||
scale,
|
||||
weight_scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# preceding operations are ready before proceeding
|
||||
torch_accelerator_module.synchronize()
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
# TODO: we do need this.... but not recursive...
|
||||
def _replace_with_fp8_linear(
|
||||
model,
|
||||
tp_plan=None,
|
||||
@ -497,48 +361,40 @@ def _replace_with_fp8_linear(
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
iterator = list(model.named_parameters()).copy()
|
||||
for name, empty_tensor in iterator:
|
||||
current_key_name = name
|
||||
name = name.rsplit(".", 1)[0] if "." in name else name
|
||||
module = model.get_submodule(name)
|
||||
"""Replace Linear layers with FP8Linear."""
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
|
||||
current_key_name_str = re.sub(r"\d+", "*", current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
if (
|
||||
"gate_up_proj" in current_key_name
|
||||
or "down_proj" in current_key_name
|
||||
and "experts" in current_key_name
|
||||
): # Experts!
|
||||
in_features = empty_tensor.size(-2)
|
||||
out_features = empty_tensor.size(-1)
|
||||
model.set_submodule(
|
||||
name,
|
||||
FP8Expert(
|
||||
config=model.config,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
device=empty_tensor.device,
|
||||
),
|
||||
)
|
||||
for name, module in model.named_children():
|
||||
current_key_name.append(name)
|
||||
|
||||
elif isinstance(module, nn.Linear):
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
model.set_submodule(
|
||||
name,
|
||||
FP8Linear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
),
|
||||
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
model._modules[name] = FP8Linear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
)
|
||||
has_been_replaced = True
|
||||
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||
has_been_replaced = True
|
||||
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_fp8_linear(
|
||||
module,
|
||||
tp_plan,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
|
||||
current_key_name.pop(-1)
|
||||
|
||||
return model, has_been_replaced
|
||||
|
||||
@ -549,7 +405,7 @@ def replace_with_fp8_linear(
|
||||
quantization_config=None,
|
||||
):
|
||||
"""Helper function to replace model layers with FP8 versions."""
|
||||
modules_to_not_convert += ["lm_head"]
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
|
||||
if quantization_config.modules_to_not_convert is not None:
|
||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||
@ -568,133 +424,3 @@ def replace_with_fp8_linear(
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class QuantizationOp(ConversionOps):
|
||||
"""Base class for quantization operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Fp8Quantize(QuantizationOp):
|
||||
"""
|
||||
A quantization operation that creates two tensors, weight and scale out of a weight.
|
||||
"""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, block_size: Optional[tuple[int, int]] = None):
|
||||
self.block_size = block_size
|
||||
self.reverse_op = Fp8Dequantize
|
||||
|
||||
def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
# Unpack single key/value (value may be wrapped in a list)
|
||||
target_keys, value = tuple(input_dict.items())[0]
|
||||
value = value[0] if isinstance(value, list) else value
|
||||
|
||||
# Resolve block size (support dict-like or attr-like quant_config)
|
||||
block_size = None
|
||||
if quant_config is not None:
|
||||
if isinstance(quant_config, dict):
|
||||
block_size = quant_config.get("weight_block_size")
|
||||
else:
|
||||
block_size = getattr(quant_config, "weight_block_size", None)
|
||||
if block_size is None:
|
||||
block_size = (value.shape[-2], value.shape[-1])
|
||||
|
||||
block_m, block_n = block_size
|
||||
rows, cols = value.shape[-2], value.shape[-1]
|
||||
|
||||
# Enforce exact tiling like your original
|
||||
if rows % block_m != 0 or cols % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
|
||||
)
|
||||
|
||||
# Leading dims can be empty (2D) or include num_experts/... (3D+)
|
||||
leading_shape = value.shape[:-2]
|
||||
rows_tiles = rows // block_m
|
||||
cols_tiles = cols // block_n
|
||||
|
||||
original_shape = value.shape
|
||||
value_fp32 = value.to(torch.float32)
|
||||
|
||||
# Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
|
||||
reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
|
||||
|
||||
# Per-tile max-abs over the block dims
|
||||
# dims: block_m is at -3, block_n is at -1 after the reshape
|
||||
max_abs = reshaped.abs().amax(dim=(-3, -1))
|
||||
safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
|
||||
|
||||
# Tile scale (we store inverse scale like your Linear: weight_scale_inv)
|
||||
scales = _FP8_MAX / safe_max_abs
|
||||
scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
|
||||
|
||||
# Broadcast scales back over the block dims and quantize
|
||||
# max_abs/scales shape: (..., rows_tiles, cols_tiles)
|
||||
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
|
||||
scaled = reshaped * scales_broadcast
|
||||
|
||||
if _FP8_IS_INT:
|
||||
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
||||
else:
|
||||
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
quantized = quantized.reshape(original_shape)
|
||||
|
||||
inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
|
||||
if target_keys.endswith("weight"):
|
||||
scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
|
||||
else:
|
||||
scale_key = target_keys + "_scales_inv"
|
||||
|
||||
# Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
|
||||
return {
|
||||
target_keys: quantized,
|
||||
scale_key: inv_scales,
|
||||
}
|
||||
|
||||
|
||||
class Fp8Dequantize(QuantizationOp):
|
||||
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
|
||||
|
||||
def __init__(self, block_size: Optional[tuple[int, int]] = None):
|
||||
self.block_size = block_size
|
||||
self.reverse_op = Fp8Quantize
|
||||
|
||||
def convert(
|
||||
self,
|
||||
value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]],
|
||||
*,
|
||||
context: dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
if isinstance(value, dict):
|
||||
tensors = list(value.values())
|
||||
else:
|
||||
tensors = list(value) if isinstance(value, Sequence) else [value]
|
||||
if len(tensors) != 2:
|
||||
raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
|
||||
quantized, scales = tensors
|
||||
if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
|
||||
raise TypeError("Fp8Dequantize expects tensors as inputs.")
|
||||
|
||||
quantized_fp32 = quantized.to(torch.float32)
|
||||
rows, cols = quantized_fp32.shape[-2:]
|
||||
block_size = self.block_size
|
||||
if block_size is None:
|
||||
quant_config = context.get("quantization_config")
|
||||
block_size = getattr(quant_config, "weight_block_size", None)
|
||||
if block_size is None:
|
||||
block_size = (rows, cols)
|
||||
block_m, block_n = block_size
|
||||
if rows % block_m != 0 or cols % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
|
||||
)
|
||||
|
||||
reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
|
||||
expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
|
||||
expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
|
||||
dequantized = reshaped * expanded_scales
|
||||
return dequantized.reshape(quantized_fp32.shape)
|
||||
|
||||
@@ -236,7 +236,7 @@ class PeftAdapterMixin:
**adapter_kwargs,
)
peft_config.inference_mode = not is_trainable
# TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!

# Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)


@ -18,7 +18,6 @@ import operator
|
||||
import os
|
||||
import re
|
||||
from functools import partial, reduce
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -141,16 +140,6 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
|
||||
return [single_size] * blocks
|
||||
|
||||
|
||||
def replace_layer_number_by_wildcard(name: str) -> str:
|
||||
"""
|
||||
Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
|
||||
a dot (`.`) and the end of the string.
|
||||
This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
|
||||
numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
|
||||
"""
|
||||
return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
|
||||
|
||||
|
||||
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
|
||||
"""
|
||||
Get the TP style for a parameter from the TP plan.
|
||||
@ -161,11 +150,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
|
||||
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
|
||||
not parent classes for `post_init` calls
|
||||
"""
|
||||
generic_param_name = replace_layer_number_by_wildcard(parameter_name)
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
return tp_plan[generic_param_name]
|
||||
elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
|
||||
return tp_plan[module_name]
|
||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
|
||||
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||
return None
|
||||
|
||||
|
||||
@ -317,7 +306,7 @@ def repack_weights(
|
||||
return final_ordered_tensor
|
||||
|
||||
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None):
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
"""
|
||||
Generalized tensor sharding across a multi-dimensional device mesh.
|
||||
Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
|
||||
@ -369,57 +358,32 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Opt
|
||||
rank (int): Global rank of the current process/device.
|
||||
dim (int): Dimension along which to shard the tensor.
|
||||
"""
|
||||
param_dim = empty_param.ndim
|
||||
param_dim = empty_param.dim()
|
||||
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
# Flatten the mesh to get the total number of devices
|
||||
mesh_shape = device_mesh.shape
|
||||
world_size = reduce(operator.mul, mesh_shape)
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
|
||||
dim = 0
|
||||
elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
|
||||
dim = 0
|
||||
|
||||
shard_size = math.ceil(empty_param.size(dim) / world_size)
|
||||
start = rank * shard_size
|
||||
end = min(start + shard_size, empty_param.size(dim))
|
||||
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
if rank >= world_size:
|
||||
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
|
||||
|
||||
# we have the full tensor not 1 part of it.
|
||||
# in that case, we just assume that the weight was properly saved
|
||||
# and thus because we TP if the layer is colwise it should not use this. Layer should be packed_colwise
|
||||
# to inform that it needs to read form a packed tensor. It will also take care of the module list thingy.
|
||||
# here we take care of potential chunking / layer split / layer chunking.
|
||||
# The only "hard" case is? if we collect q,k,v -> merge it into qkv. In that case
|
||||
# actually we still shard dim=0 does not change
|
||||
# so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
|
||||
# tensor on a certain device (with the input tensor_index)
|
||||
dimensions = param.get_shape()
|
||||
shard_size = math.ceil(empty_param.shape[dim] / world_size)
|
||||
start = rank * shard_size
|
||||
|
||||
if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
|
||||
# special case we don't "shard" just send this entire tensor to the correct rank.
|
||||
if start <= tensor_idx < end:
|
||||
# this tensor does need to be materialized on this device:
|
||||
return param[:]
|
||||
else:
|
||||
return torch.empty([], dtype=torch.int64, device=rank)
|
||||
|
||||
slice_indices = [slice(None)] * len(param.get_shape())
|
||||
|
||||
if start < param.get_shape()[dim]:
|
||||
# Construct slicing index dynamically
|
||||
end = min(start + shard_size, empty_param.shape[dim])
|
||||
slice_indices = [slice(None)] * param_dim
|
||||
if start < empty_param.shape[dim]:
|
||||
slice_indices[dim] = slice(start, end)
|
||||
param = param[tuple(slice_indices)]
|
||||
if isinstance(param, list): # TODO handle the modulelist case!
|
||||
param = [p[:] for p in param]
|
||||
return param
|
||||
|
||||
return param[tuple(slice_indices)]
|
||||
dimensions = list(param.shape)
|
||||
dimensions[dim] = 0
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory....
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64)
|
||||
|
||||
|
||||
def distribute_module(
|
||||
@ -446,19 +410,6 @@ class TensorParallelLayer:
|
||||
"""
|
||||
|
||||
use_dtensor = True
|
||||
device_mesh = None
|
||||
rank = None
|
||||
|
||||
# Used to compare the shape of the original tensor
|
||||
empty_param = None
|
||||
|
||||
# Used to init the corresponding DTensor
|
||||
shard = None
|
||||
|
||||
def __init__(self, device_mesh=None, rank=None, empty_param=None):
|
||||
self.rank = rank
|
||||
self.device_mesh = device_mesh
|
||||
self.empty_param = empty_param
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
|
||||
@ -488,12 +439,12 @@ class GatherParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__()
|
||||
self.input_layouts = (input_layouts or Replicate(),)
|
||||
self.output_layouts = output_layouts
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -514,21 +465,6 @@ class GatherParallel(TensorParallelLayer):
|
||||
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
|
||||
return outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
shard = [Replicate()]
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
||||
distribute_module(
|
||||
module,
|
||||
@ -557,23 +493,6 @@ class IsolatedParallel(TensorParallelLayer):
|
||||
# TODO: figure out dynamo support for instance method and switch this to instance method
|
||||
return outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
mesh = device_mesh or self.device_mesh
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
if mesh is not None:
|
||||
parameter = parameter / mesh.size()
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
@ -596,8 +515,8 @@ class ReplicateParallel(TensorParallelLayer):
|
||||
This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
|
||||
"""
|
||||
|
||||
def __init__(self, use_dtensor=True, use_local_output=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, *, use_dtensor=True, use_local_output=True):
|
||||
super().__init__()
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -618,33 +537,12 @@ class ReplicateParallel(TensorParallelLayer):
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
shard = [Replicate()]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
parameter, shard = self.shard_tensor(
|
||||
param,
|
||||
param_type=param_type,
|
||||
param_casting_dtype=param_casting_dtype,
|
||||
to_contiguous=to_contiguous,
|
||||
rank=rank,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
if self.use_dtensor:
|
||||
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
||||
return parameter
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
param = param.contiguous()
|
||||
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
|
||||
return param
|
||||
|
||||
|
||||
class ColwiseParallel(TensorParallelLayer):
|
||||
@ -654,13 +552,13 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
use_dtensor=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__()
|
||||
self.input_layouts = (input_layouts or Replicate(),)
|
||||
self.output_layouts = (output_layouts or Shard(-1),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -680,34 +578,18 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
|
||||
return input_tensor
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = self.rank
|
||||
if param_type == "bias":
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
|
||||
shard = [Shard(-1)]
|
||||
else:
|
||||
shard = [Shard(-2)]
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
# weight would become Shard(1)
|
||||
parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh)
|
||||
if param_type == "bias":
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
|
||||
shard = [Shard(-1)]
|
||||
else:
|
||||
shard = [Shard(-2)]
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
|
||||
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
parameter = parameter.contiguous()
|
||||
if self.use_dtensor:
|
||||
@ -726,21 +608,6 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
|
||||
class PackedColwiseParallel(ColwiseParallel):
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
return get_packed_weights(param, empty_param, device_mesh, rank, -2).to(param_casting_dtype), [Shard(-2)]
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
@ -775,41 +642,18 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
use_dtensor=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__()
|
||||
self.input_layouts = (input_layouts or Shard(-1),)
|
||||
self.output_layouts = (output_layouts or Replicate(),)
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
if param_type == "bias":
|
||||
shard = [Replicate()]
|
||||
parameter = param[...]
|
||||
else:
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
shard = [Shard(-1)]
parameter = parameter.to(param_casting_dtype)
self.shard = shard
return parameter, shard

def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@@ -881,21 +725,6 @@ class RowwiseParallel(TensorParallelLayer):


class PackedRowwiseParallel(RowwiseParallel):
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]

def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@@ -954,8 +783,8 @@ class SequenceParallel(TensorParallelLayer):
to ensure that they are replicated.
"""

def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
super().__init__(**kwargs)
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
super().__init__()
self.input_layouts = (Replicate(),)
self.desired_input_layouts = (Shard(1),)
self.output_layouts = (Replicate(),)
@@ -964,21 +793,6 @@ class SequenceParallel(TensorParallelLayer):
self.sequence_sharding = (Shard(sequence_dim),)
self.use_local_output = use_local_output

def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...].to(param_casting_dtype)
shard = [Replicate()]
self.shard = shard
return parameter, shard

@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
@@ -1013,34 +827,10 @@ class GroupedGemmParallel(TensorParallelLayer):
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(self):
super().__init__()
self.use_dtensor = False

def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
empty_param = self.empty_param
ep_rank = self.rank
device_mesh = self.device_mesh

global_num_experts = empty_param.shape[0]
if global_num_experts % device_mesh.size() != 0:
raise ValueError(
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
)
local_num_experts = global_num_experts // device_mesh.size()
parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
self.shard = None
return parameter, None

def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
ep_rank = rank
global_num_experts = empty_param.shape[0]
@@ -1061,8 +851,8 @@ class RouterParallel(TensorParallelLayer):
"""

def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
self.args = args
self.kwargs = kwargs
self.use_dtensor = False

@staticmethod
@@ -1127,20 +917,6 @@ class RouterParallel(TensorParallelLayer):
) # masking class for one hot
return router_scores, router_indices

def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...].to(param_casting_dtype)
self.shard = None
return parameter, None

def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# TODO: i'd like for this to be the default
param = param[...].to(param_casting_dtype)
@@ -1283,9 +1059,6 @@ def shard_and_distribute_module(
if current_shard_plan is not None:
try:
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
tp_layer.empty_param = empty_param
tp_layer.device_mesh = device_mesh
tp_layer.rank = rank
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
@@ -1313,7 +1086,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
if tp_plan is None:
return

generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
unsharded_layers = set(generic_keys)
unused_rules = tp_plan

File diff suppressed because it is too large
@@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig

@@ -405,14 +406,13 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_flex_attn = True

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.fill_(math.log(1 / 0.07))
module.logit_scale.data.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
module.cls_token.normal_(mean=0.0, std=self.config.initializer_range)
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)


@auto_docstring(
@@ -445,11 +445,13 @@ class Aimv2VisionModel(Aimv2PreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embed

@deprecate_kwarg("attention_mask", version="v4.58.0")
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""

@@ -32,6 +32,7 @@ from ...utils import (
auto_docstring,
can_return_tuple,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
@@ -448,14 +449,13 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_flex_attn = True

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.fill_(math.log(1 / 0.07))
module.logit_scale.data.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
module.cls_token.normal_(mean=0.0, std=self.config.initializer_range)
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)


@auto_docstring(
@@ -488,11 +488,13 @@ class Aimv2VisionModel(Aimv2PreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embed

@deprecate_kwarg("attention_mask", version="v4.58.0")
@check_model_inputs(tie_last_hidden_states=False)
@auto_docstring
def forward(
self,
pixel_values,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPooling:
r"""

@@ -302,22 +302,21 @@ class AlbertPreTrainedModel(PreTrainedModel):
"attentions": AlbertAttention,
}

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight[module.padding_idx].zero_()
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, AlbertMLMHead):
module.bias.zero_()
module.bias.data.zero_()


@dataclass
@@ -426,10 +425,7 @@ class AlbertModel(AlbertPreTrainedModel):
"""
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_tied_weights_keys = {
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
"predictions.decoder.bias": "predictions.bias",
}
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

def __init__(self, config: AlbertConfig):
super().__init__(config)
@@ -529,6 +525,7 @@ class AlbertMLMHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
self.activation = ACT2FN[config.hidden_act]
self.decoder.bias = self.bias

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
@@ -540,6 +537,14 @@ class AlbertMLMHead(nn.Module):

return prediction_scores

def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias


class AlbertSOPHead(nn.Module):
def __init__(self, config: AlbertConfig):
@@ -556,10 +561,7 @@ class AlbertSOPHead(nn.Module):

@auto_docstring
class AlbertForMaskedLM(AlbertPreTrainedModel):
_tied_weights_keys = {
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
"predictions.decoder.bias": "predictions.bias",
}
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

def __init__(self, config):
super().__init__(config)

@@ -823,25 +823,24 @@ class AlignPreTrainedModel(PreTrainedModel):
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True

@torch.no_grad()
def _init_weights(self, module: nn.Module):
"""Initialize the weights"""
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, AlignModel):
nn.init.xavier_uniform_(module.text_projection.weight)
module.text_projection.bias.zero_()
module.temperature.fill_(self.config.temperature_init_value)
module.text_projection.bias.data.zero_()
module.temperature.data.fill_(self.config.temperature_init_value)
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight[module.padding_idx].zero_()
module.weight.data[module.padding_idx].zero_()
if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)


@auto_docstring(

@@ -770,7 +770,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_module = []

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor
@@ -798,21 +797,23 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
module.text_projection.weight,
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
)
module.text_projection._is_hf_initialized = True
nn.init.normal_(
module.visual_projection.weight,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
)
module.visual_projection._is_hf_initialized = True
elif isinstance(module, nn.LayerNorm):
module.bias.zero_()
module.weight.fill_(1.0)
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear):
module.weight.normal_(mean=0.0, std=self.config.initializer_factor)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
if module.bias is not None:
module.bias.zero_()
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.normal_(mean=0.0, std=self.config.initializer_factor)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
if module.padding_idx is not None:
module.weight[module.padding_idx].zero_()
module.weight.data[module.padding_idx].zero_()


class AltCLIPVisionTransformer(nn.Module):

@@ -17,6 +17,7 @@ Image/Text processor class for AltCLIP
"""

from ...processing_utils import ProcessorMixin
from ...utils.deprecation import deprecate_kwarg


class AltCLIPProcessor(ProcessorMixin):
@@ -34,6 +35,7 @@ class AltCLIPProcessor(ProcessorMixin):
The tokenizer is a required input.
"""

@deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)

@@ -106,6 +106,7 @@ class ApertusConfig(PreTrainedConfig):
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

@@ -429,7 +429,7 @@ class ApertusModel(ApertusPreTrainedModel):

@auto_docstring
class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}


@@ -123,6 +123,7 @@ class ApertusConfig(LlamaConfig):
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
}

def __init__(

@@ -434,7 +434,7 @@ class ArceeModel(ArceePreTrainedModel):

@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

@ -99,14 +99,15 @@ class AriaTextConfig(PreTrainedConfig):
|
||||
|
||||
model_type = "aria_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `AriaTextModel`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.shared_experts.gate_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.up_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.down_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -585,11 +585,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
"attentions": AriaTextAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, AriaGroupedExpertsGemm):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -609,7 +608,6 @@ class AriaPreTrainedModel(PreTrainedModel):
|
||||
"attentions": AriaTextAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, AriaProjector):
|
||||
@ -762,7 +760,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
@ -892,6 +890,8 @@ class AriaModelOutputWithPast(BaseModelOutputWithPast):
|
||||
"""
|
||||
)
|
||||
class AriaModel(AriaPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
|
||||
def __init__(self, config: AriaConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
@ -1048,12 +1048,12 @@ class AriaModel(AriaPreTrainedModel):
|
||||
)
|
||||
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
r"^language_model.model": "model.language_model",
|
||||
r"^vision_tower": "model.vision_tower",
|
||||
r"^multi_modal_projector": "model.multi_modal_projector",
|
||||
r"^language_model.lm_head": "lm_head",
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: AriaConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -169,15 +169,6 @@ class AriaTextConfig(LlamaConfig):
|
||||
|
||||
model_type = "aria_text"
|
||||
base_config_key = "text_config"
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.shared_experts.gate_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.up_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -1196,11 +1187,10 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
"attentions": AriaTextAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, AriaGroupedExpertsGemm):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
@ -1209,7 +1199,6 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
_can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
|
||||
_supports_attention_backend = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
PreTrainedModel._init_weights(self, module)
|
||||
if isinstance(module, AriaProjector):
|
||||
@ -1227,7 +1216,7 @@ class AriaTextModel(LlamaModel):
|
||||
|
||||
|
||||
class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: AriaTextConfig):
|
||||
super().__init__(config)
|
||||
@ -1366,8 +1355,6 @@ class AriaModel(LlavaModel):
|
||||
"""
|
||||
)
|
||||
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
|
||||
@ -300,26 +300,23 @@ class ASTPreTrainedModel(PreTrainedModel):
|
||||
"attentions": ASTSelfAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
||||
# `trunc_normal_cpu` not implemented in `half` issues
|
||||
module.weight.copy_(
|
||||
nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to(
|
||||
module.weight.dtype
|
||||
)
|
||||
)
|
||||
module.weight.data = nn.init.trunc_normal_(
|
||||
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
|
||||
).to(module.weight.dtype)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, ASTEmbeddings):
|
||||
module.cls_token.zero_()
|
||||
module.position_embeddings.zero_()
|
||||
module.distillation_token.zero_()
|
||||
module.cls_token.data.zero_()
|
||||
module.position_embeddings.data.zero_()
|
||||
module.distillation_token.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -264,7 +264,6 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of AudioFlamingo3 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed
|
||||
@ -275,16 +274,16 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
@ -436,9 +435,10 @@ class AudioFlamingo3MultiModalProjector(nn.Module):
|
||||
"""
|
||||
)
|
||||
class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
|
||||
_keep_in_fp32_modules_strict = None
|
||||
_tied_weights_keys = None
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
_keep_in_fp32_modules_strict = None
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -446,6 +446,9 @@ class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, Gene
|
||||
self.audio_tower = AutoModel.from_config(config.audio_config)
|
||||
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
|
||||
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
|
||||
# Similar to Qwen2Audio
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@ -136,12 +136,16 @@ class AudioFlamingo3MultiModalProjector(VoxtralMultiModalProjector):
|
||||
"""
|
||||
)
|
||||
class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration):
|
||||
_tied_weights_keys = None
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
_keep_in_fp32_modules_strict = None
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# Similar to Qwen2Audio
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
def get_audio_features(
|
||||
self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor
|
||||
|
||||
@ -442,15 +442,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
"GPT2TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"ministral",
|
||||
(
|
||||
"MistralCommonTokenizer"
|
||||
if is_mistral_common_available()
|
||||
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"mistral",
|
||||
(
|
||||
@ -460,15 +451,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"mistral3",
|
||||
(
|
||||
"MistralCommonTokenizer"
|
||||
if is_mistral_common_available()
|
||||
else ("LlamaTokenizer" if is_sentencepiece_available() else None),
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"mixtral",
|
||||
(
|
||||
|
||||
@ -826,22 +826,21 @@ class AutoformerPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "past_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, AutoformerSinusoidalPositionalEmbedding):
|
||||
module._init_weight()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
||||
def _update_full_mask(
|
||||
|
||||
@ -90,6 +90,7 @@ class AyaVisionMultiModalProjector(nn.Module):
|
||||
@auto_docstring
|
||||
class AyaVisionPreTrainedModel(PreTrainedModel):
|
||||
config: AyaVisionConfig
|
||||
base_model_prefix = ""
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
@ -162,6 +163,8 @@ class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
|
||||
"""
|
||||
)
|
||||
class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
|
||||
def __init__(self, config: AyaVisionConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
@ -330,12 +333,12 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
)
|
||||
class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
r"^language_model.model": "model.language_model",
|
||||
r"^vision_tower": "model.vision_tower",
|
||||
r"^multi_modal_projector": "model.multi_modal_projector",
|
||||
r"^language_model.lm_head": "lm_head",
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: AyaVisionConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1126,13 +1126,12 @@ class BambaPreTrainedModel(PreTrainedModel):
|
||||
# Note: only supports HybridMambaAttentionDynamicCache
|
||||
_is_stateful = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, BambaMixer):
|
||||
module.dt_bias.fill_(1.0)
|
||||
module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1)))
|
||||
module.D.fill_(1.0)
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -1384,7 +1383,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
|
||||
@ -800,13 +800,12 @@ class BambaPreTrainedModel(PreTrainedModel):
|
||||
# Note: only supports HybridMambaAttentionDynamicCache
|
||||
_is_stateful = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, BambaMixer):
|
||||
module.dt_bias.fill_(1.0)
|
||||
module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1)))
|
||||
module.D.fill_(1.0)
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -329,21 +329,19 @@ class BarkPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = False
|
||||
_supports_flash_attn = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear,)):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
if getattr(module, "bias", None) is not None:
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
@ -912,9 +910,6 @@ class BarkFineModel(BarkPreTrainedModel):
|
||||
# non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._tied_weights_keys = {}
|
||||
for i in range(self.config.n_codes_total - self.config.n_codes_given):
|
||||
self._tied_weights_keys[f"lm_heads.{i}.weight"] = f"input_embeds_layers.{i + 1}.weight"
|
||||
|
||||
# initialize a modified non causal GPT-like model
|
||||
# note that for there is one embedding layer and one lm_head for each codebook of Encodec
|
||||
@ -1030,6 +1025,25 @@ class BarkFineModel(BarkPreTrainedModel):
|
||||
|
||||
return model_embeds
|
||||
|
||||
def _tie_weights(self):
|
||||
if getattr(self.config, "tie_word_embeddings", True):
|
||||
self._tied_weights_keys = []
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
input_embeddings = self.get_input_embeddings()
|
||||
|
||||
for i in range(self.config.n_codes_total - self.config.n_codes_given):
|
||||
# self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
|
||||
self._tie_embedding_weights(output_embeddings[i], input_embeddings[i + 1])
|
||||
self._tied_weights_keys.append(f"lm_heads.{i}.weight")
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings list and the output embeddings list.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if hasattr(module, "_tie_weights"):
|
||||
module._tie_weights()
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1566,6 +1580,14 @@ class BarkModel(BarkPreTrainedModel, GenerationMixin):
|
||||
|
||||
return audio
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings list and the output embeddings list.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if hasattr(module, "_tie_weights"):
|
||||
module._tie_weights()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BarkFineModel",
|
||||
|
||||
@ -164,7 +164,7 @@ class BartConfig(PreTrainedConfig):
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.tie_encoder_decoder = True
|
||||
|
||||
# ensure backward compatibility for BART CNN models
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
|
||||
@ -476,20 +476,19 @@ class BartPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -528,7 +527,7 @@ class BartEncoder(BartPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.dropout = config.dropout
|
||||
@ -539,9 +538,12 @@ class BartEncoder(BartPreTrainedModel):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -672,7 +674,7 @@ class BartDecoder(BartPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
@ -680,9 +682,12 @@ class BartDecoder(BartPreTrainedModel):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -894,10 +899,7 @@ class BartDecoder(BartPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BartModel(BartPreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__(config)
|
||||
@ -906,12 +908,24 @@ class BartModel(BartPreTrainedModel):
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
|
||||
|
||||
self.encoder = BartEncoder(config)
|
||||
self.decoder = BartDecoder(config)
|
||||
self.encoder = BartEncoder(config, self.shared)
|
||||
self.decoder = BartDecoder(config, self.shared)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
# Some model checkpoints like "facebook/bart-large-cnn"'s embedding weight is in decoder.embed_tokens, need check here, see issue #36247
|
||||
if self.shared.weight.device == torch.device(
|
||||
"meta"
|
||||
) and self.decoder.embed_tokens.weight.device != torch.device("meta"):
|
||||
self._tie_embedding_weights(self.encoder.embed_tokens, self.decoder.embed_tokens)
|
||||
self._tie_embedding_weights(self.shared, self.decoder.embed_tokens)
|
||||
else:
|
||||
self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
|
||||
self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
@ -1038,9 +1052,7 @@ class BartModel(BartPreTrainedModel):
|
||||
)
|
||||
class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
@ -1074,6 +1086,11 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
|
||||
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
||||
self.register_buffer("final_logits_bias", new_bias)
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.model._tie_weights()
|
||||
self._tie_embedding_weights(self.lm_head, self.model.shared)
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1223,6 +1240,8 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
|
||||
"""
|
||||
)
|
||||
class BartForSequenceClassification(BartPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BartConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = BartModel(config)
|
||||
@ -1355,6 +1374,8 @@ class BartForSequenceClassification(BartPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BartForQuestionAnswering(BartPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1492,9 +1513,7 @@ class BartDecoderWrapper(BartPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.decoder.embed_tokens.weight",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config.is_decoder = True
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
@ -162,7 +163,14 @@ class BeitEmbeddings(nn.Module):
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.position_embeddings is not None and interpolate_pos_encoding is not None:
|
||||
warnings.warn(
|
||||
"`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
|
||||
"interpolated to the input image size. The argument will be removed in transformers v4.51.0."
|
||||
)
|
||||
|
||||
_, _, height, width = pixel_values.shape
|
||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
|
||||
batch_size, seq_len, _ = embeddings.size()
|
||||
@ -317,9 +325,19 @@ class BeitSdpaSelfAttention(BeitSelfAttention):
|
||||
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
||||
if output_attentions:
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} does not support `output_attentions=True`. The returned attention weights will "
|
||||
"be `None`. If you want to get attention weights, please set `attn_implementation='eager'` when loading the model."
|
||||
"`BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
|
||||
"support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
"but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
|
||||
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
relative_position_bias=relative_position_bias,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
resolution=resolution,
|
||||
)
|
||||
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
query_layer = (
|
||||
self.query(hidden_states)
|
||||
@ -674,32 +692,31 @@ class BeitPreTrainedModel(PreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
|
||||
_supports_sdpa = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, BeitEmbeddings):
|
||||
module.cls_token.zero_()
|
||||
module.cls_token.data.zero_()
|
||||
if module.mask_token is not None:
|
||||
module.mask_token.zero_()
|
||||
module.mask_token.data.zero_()
|
||||
if module.position_embeddings is not None:
|
||||
module.position_embeddings.zero_()
|
||||
module.position_embeddings.data.zero_()
|
||||
elif isinstance(module, BeitRelativePositionBias):
|
||||
module.relative_position_bias_table.zero_()
|
||||
module.relative_position_bias_table.data.zero_()
|
||||
elif isinstance(module, BeitLayer):
|
||||
if module.lambda_1 is not None:
|
||||
module.lambda_1.fill_(self.config.layer_scale_init_value)
|
||||
module.lambda_2.fill_(self.config.layer_scale_init_value)
|
||||
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
|
||||
module.lambda_2.data.fill_(self.config.layer_scale_init_value)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -506,9 +506,16 @@ class BertLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -562,22 +569,21 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
"cross_attentions": BertCrossAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, BertLMPredictionHead):
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -764,10 +770,7 @@ class BertModel(BertPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
}
|
||||
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -861,10 +864,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
}
|
||||
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -948,10 +948,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
|
||||
|
||||
@auto_docstring
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
}
|
||||
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -456,22 +456,21 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
|
||||
"cross_attentions": BertGenerationCrossAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, BertGenerationOnlyLMHead):
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
@ -630,11 +629,20 @@ class BertGenerationOnlyLMHead(nn.Module):
|
||||
super().__init__()
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
logits = self.decoder(hidden_states)
|
||||
return logits
|
||||
|
||||
def _tie_weights(self):
|
||||
# For accelerate compatibility and to not break backward compatibility
|
||||
if self.decoder.bias.device.type == "meta":
|
||||
self.decoder.bias = self.bias
|
||||
else:
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
@ -642,10 +650,7 @@ class BertGenerationOnlyLMHead(nn.Module):
|
||||
"""
|
||||
)
|
||||
class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {
|
||||
"lm_head.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
"lm_head.decoder.bias": "lm_head.bias",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1464,9 +1464,16 @@ class BigBirdLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -1514,22 +1521,21 @@ class BigBirdPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "bert"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, BigBirdLMPredictionHead):
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1893,10 +1899,7 @@ class BigBirdModel(BigBirdPreTrainedModel):
|
||||
|
||||
|
||||
class BigBirdForPreTraining(BigBirdPreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1996,10 +1999,7 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -2141,10 +2141,7 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1539,20 +1539,19 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -1575,7 +1574,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.attention_type = config.attention_type
|
||||
@ -1593,6 +1592,9 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
||||
self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_dim,
|
||||
@ -1847,7 +1849,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
@ -1859,6 +1861,9 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
||||
self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
@ -2070,10 +2075,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):

@auto_docstring
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
    }
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: BigBirdPegasusConfig):
        super().__init__(config)
@ -2084,8 +2086,8 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
            vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
        )

        self.encoder = BigBirdPegasusEncoder(config)
        self.decoder = BigBirdPegasusDecoder(config)
        self.encoder = BigBirdPegasusEncoder(config, self.shared)
        self.decoder = BigBirdPegasusDecoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()
@ -2098,6 +2100,11 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
            self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

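The two sides of this hunk reach the same end state in different ways: one passes `self.shared` into the encoder and decoder constructors, the other builds them independently and ties their embeddings afterwards (`_tie_weights` plus the `_tied_weights_keys` mapping). A minimal standalone sketch of both approaches, using hypothetical toy modules rather than the BigBirdPegasus classes:

```python
from typing import Optional

from torch import nn


class TinyEncoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__()
        # Option 1: share at construction time by passing the embedding in.
        self.embed_tokens = embed_tokens if embed_tokens is not None else nn.Embedding(vocab_size, d_model)


class TinySeq2Seq(nn.Module):
    def __init__(self, vocab_size: int = 100, d_model: int = 16):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, d_model)
        self.encoder = TinyEncoder(vocab_size, d_model)
        self.decoder = TinyEncoder(vocab_size, d_model)
        # Option 2: build independently, then tie afterwards.
        self.encoder.embed_tokens.weight = self.shared.weight
        self.decoder.embed_tokens.weight = self.shared.weight


model = TinySeq2Seq()
assert model.encoder.embed_tokens.weight is model.decoder.embed_tokens.weight is model.shared.weight
```

Either way, only one embedding matrix ends up in the state dict, which is what the tied-weights bookkeeping relies on.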
@ -2206,9 +2213,7 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
|
||||
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
@ -2242,6 +2247,11 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
|
||||
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
||||
self.register_buffer("final_logits_bias", new_bias)
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.model._tie_weights()
|
||||
self._tie_embedding_weights(self.lm_head, self.model.shared)
|
||||
|
||||
@auto_docstring
|
||||
# Ignore copy
|
||||
def forward(
|
||||
@ -2364,6 +2374,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
|
||||
"""
|
||||
)
|
||||
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = BigBirdPegasusModel(config)
|
||||
@ -2485,6 +2497,8 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -2607,6 +2621,8 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
|
||||
|
||||
|
||||
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config.is_decoder = True
|
||||
config.is_encoder_decoder = False
|
||||
|
||||
@ -510,7 +510,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["output_projection.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -332,7 +332,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["output_projection.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -628,7 +628,6 @@ class BitPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["BitEmbeddings"]
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||
|
||||
@ -433,7 +433,7 @@ class BitNetModel(BitNetPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
|
||||
|
||||
@ -114,7 +114,7 @@ class BitNetModel(LlamaModel):
|
||||
|
||||
|
||||
class BitNetForCausalLM(LlamaForCausalLM):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
|
||||
|
||||
@ -438,20 +438,19 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -475,7 +474,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.dropout = config.dropout
|
||||
@ -486,9 +485,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -621,7 +623,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
@ -629,9 +631,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -847,10 +852,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
super().__init__(config)
|
||||
@ -858,8 +860,8 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
|
||||
self.encoder = BlenderbotEncoder(config)
|
||||
self.decoder = BlenderbotDecoder(config)
|
||||
self.encoder = BlenderbotEncoder(config, self.shared)
|
||||
self.decoder = BlenderbotDecoder(config, self.shared)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -999,9 +1001,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
|
||||
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
super().__init__(config)
|
||||
@ -1184,9 +1184,7 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
|
||||
class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.decoder.embed_tokens.weight",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config.is_decoder = True
|
||||
|
||||
@ -431,20 +431,19 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -468,7 +467,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.dropout = config.dropout
|
||||
@ -479,7 +478,10 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
|
||||
self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -610,7 +612,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
@ -618,7 +620,10 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
|
||||
self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -833,10 +838,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
super().__init__(config)
|
||||
@ -844,8 +846,8 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
||||
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
||||
|
||||
self.encoder = BlenderbotSmallEncoder(config)
|
||||
self.decoder = BlenderbotSmallDecoder(config)
|
||||
self.encoder = BlenderbotSmallEncoder(config, self.shared)
|
||||
self.decoder = BlenderbotSmallDecoder(config, self.shared)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -972,9 +974,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.shared.weight",
|
||||
}
|
||||
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
super().__init__(config)
|
||||
@ -1144,9 +1144,7 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
|
||||
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.decoder.embed_tokens.weight",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config.is_decoder = True
|
||||
|
||||
@ -419,14 +419,13 @@ class BlipPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_range
|
||||
if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
|
||||
module.weight.normal_(mean=0.0, std=factor)
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
if isinstance(module, BlipVisionEmbeddings):
|
||||
if hasattr(self.config, "vision_config"):
|
||||
@ -444,10 +443,10 @@ class BlipPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class BlipEncoder(nn.Module):
|
||||
@ -798,11 +797,8 @@ class BlipModel(BlipPreTrainedModel):
|
||||
)
|
||||
class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
|
||||
config: BlipConfig
|
||||
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
|
||||
main_input_name = "pixel_values"
|
||||
_tied_weights_keys = {
|
||||
"text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
|
||||
"text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
|
||||
} # TODO @arthurzucker check why we need this when for other models, their subPreTrainedModel handle it themselves.
|
||||
|
||||
def __init__(self, config: BlipConfig):
|
||||
super().__init__(config)
|
||||
@ -967,10 +963,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
|
||||
config: BlipConfig
|
||||
_tied_weights_keys = {
|
||||
"text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
|
||||
"text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config: BlipConfig):
|
||||
super().__init__(config)
|
||||
@ -978,6 +971,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
|
||||
self.vision_model = BlipVisionModel(config.vision_config)
|
||||
|
||||
self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
|
||||
|
||||
self.text_decoder = BlipTextLMHeadModel(config.text_config)
|
||||
|
||||
self.decoder_pad_token_id = config.text_config.pad_token_id
|
||||
|
||||
@ -473,9 +473,16 @@ class BlipTextLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -504,16 +511,15 @@ class BlipTextPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "bert"
|
||||
_no_split_modules = []
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
# Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
|
||||
@ -738,10 +744,7 @@ class BlipTextModel(BlipTextPreTrainedModel):
|
||||
|
||||
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
|
||||
class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""PyTorch BLIP-2 model."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
@ -408,20 +409,19 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_range
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.normal_(mean=0.0, std=factor)
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=factor)
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Blip2VisionEmbeddings):
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
@ -435,7 +435,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
||||
Blip2ForImageTextRetrieval,
|
||||
),
|
||||
):
|
||||
module.query_tokens.zero_()
|
||||
module.query_tokens.data.zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
|
||||
@ -1049,6 +1049,10 @@ class Blip2Model(Blip2PreTrainedModel):
        else:
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

        # Update _tied_weights_keys using the base model used.
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]

        self.language_model = language_model

        # Initialize weights and apply final processing
@ -1072,6 +1076,11 @@ class Blip2Model(Blip2PreTrainedModel):
    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared
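Two small patterns are worth noting in this hunk. First, when a model is nested under the `language_model` attribute, its tied-weight keys have to be re-rooted with that prefix, which is what the list comprehension above does. A hypothetical, minimal illustration of the prefixing:

```python
# Hypothetical values, only to show the prefixing pattern used above.
inner_keys = ["lm_head.weight"]  # what the nested language model might declare
wrapper_keys = [f"language_model.{k}" for k in inner_keys]
assert wrapper_keys == ["language_model.lm_head.weight"]
```

Second, the `_tie_weights` override only re-points the encoder/decoder `embed_tokens` for seq2seq backbones; decoder-only language models handle their own tying.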
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
@auto_docstring
|
||||
def get_text_features(
|
||||
@ -1081,6 +1090,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
decoder_input_ids: Optional[torch.Tensor] = None,
|
||||
decoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
legacy_output: bool = True,
|
||||
) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -1099,10 +1109,12 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
legacy_output (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a model output object or a tensor of features.
|
||||
|
||||
Returns:
|
||||
text_outputs (``torch.FloatTensor`):
|
||||
The language model's last hidden states.
|
||||
text_outputs (`CausalLMOutputWithPast` or `torch.FloatTensor`):
|
||||
The language model outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
@ -1117,6 +1129,13 @@ class Blip2Model(Blip2PreTrainedModel):
        ... text_features = model.get_text_features(**inputs)
        ```"""

        if legacy_output:
            warnings.warn(
                "Deprecation notice: In Transformers v4.59, the default return value of `get_text_features` will change. "
                "Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. "
                "To opt in to the new behavior now, set `legacy_output=False`."
            )

        if self.config.use_decoder_only_language_model:
            text_outputs: CausalLMOutputWithPast = self.language_model(
                input_ids=input_ids,
@ -1134,7 +1153,7 @@ class Blip2Model(Blip2PreTrainedModel):
                return_dict=True,
            )

        return text_outputs.logits
        return text_outputs if legacy_output else text_outputs.logits

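A short usage sketch of the `legacy_output` flag added above (checkpoint name and prompt are illustrative): passing `legacy_output=False` opts into the future default, returns the logits tensor directly instead of a model output object, and avoids the deprecation warning.

```python
import torch
from transformers import AutoTokenizer, Blip2Model

# Illustrative checkpoint; any BLIP-2 checkpoint should behave the same way here.
model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")

inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
with torch.inference_mode():
    # Opt in to the new behavior: a plain logits tensor is returned.
    text_features = model.get_text_features(**inputs, legacy_output=False)
```

The same flag is threaded through `get_image_features` and `get_qformer_features` in the hunks that follow.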
@filter_out_non_signature_kwargs()
|
||||
@auto_docstring
|
||||
@ -1142,11 +1161,15 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
legacy_output: bool = True,
|
||||
) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
legacy_output (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a model output object or a tensor of features.
|
||||
|
||||
Returns:
|
||||
vision_outputs (`torch.FloatTensor`):
|
||||
The vision model's last layer pooled logits.
|
||||
vision_outputs (`BaseModelOutputWithPooling` or `torch.FloatTensor`):
|
||||
The vision model outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
@ -1164,13 +1187,20 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
>>> with torch.inference_mode():
|
||||
... image_outputs = model.get_image_features(**inputs)
|
||||
```"""
|
||||
if legacy_output:
|
||||
warnings.warn(
|
||||
"Deprecation notice: In Transformers v4.59, the default return value of `get_text_features` will change. "
|
||||
"Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. "
|
||||
"To opt in to the new behavior now, set `legacy_output=False`."
|
||||
)
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return vision_outputs.pooler_output
|
||||
return vision_outputs if legacy_output else vision_outputs.pooler_output
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
@auto_docstring
|
||||
@ -1178,11 +1208,15 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
legacy_output: bool = True,
|
||||
) -> Union[torch.FloatTensor, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
legacy_output (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a model output object or a tensor of features.
|
||||
|
||||
Returns:
|
||||
qformer_outputs (`torch.FloatTensor`):
|
||||
The Q-Former model's last layer hidden states.
|
||||
qformer_outputs (`BaseModelOutputWithPooling` or `torch.FloatTensor`):
|
||||
The Q-Former outputs. If `legacy_output=False`, the output is a `torch.FloatTensor`.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -1201,6 +1235,14 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
>>> with torch.inference_mode():
|
||||
... qformer_outputs = model.get_qformer_features(**inputs)
|
||||
```"""
|
||||
|
||||
if legacy_output:
|
||||
warnings.warn(
|
||||
"Deprecation notice: In Transformers v4.59, the default return value of `get_qformer_features` will change. "
|
||||
"Currently, this method returns a model output object, but starting in v4.59, it will return a tensor instead. "
|
||||
"To opt in to the new behavior now, set `legacy_output=False`."
|
||||
)
|
||||
|
||||
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
@ -1220,7 +1262,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return query_outputs.last_hidden_state
|
||||
return query_outputs if legacy_output else query_outputs.last_hidden_state
|
||||
|
||||
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
|
||||
"""
|
||||
@ -1570,6 +1612,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
else:
|
||||
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
|
||||
|
||||
# Update _tied_weights_keys using the base model used.
|
||||
if language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
||||
|
||||
self.language_model = language_model
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
@ -1593,6 +1639,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def _tie_weights(self):
|
||||
if not self.config.use_decoder_only_language_model:
|
||||
self.language_model.encoder.embed_tokens = self.language_model.shared
|
||||
self.language_model.decoder.embed_tokens = self.language_model.shared
|
||||
|
||||
def _preprocess_accelerate(self):
|
||||
r"""
|
||||
Some pre-processing hacks to make the model `accelerate` compatible. Check
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""PyTorch BLOOM model."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
@ -424,20 +425,19 @@ class BloomPreTrainedModel(PreTrainedModel):
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -484,6 +484,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
||||
@ -498,6 +499,16 @@ class BloomModel(BloomPreTrainedModel):
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
"""
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
||||
warnings.warn(
|
||||
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||
" passing `position_ids`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if len(deprecated_arguments) > 0:
|
||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -711,7 +722,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "transformer.word_embeddings.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
super().__init__(config)
|
||||
@ -806,7 +817,7 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
**deprecated_arguments,
|
||||
) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
||||
@ -825,6 +836,18 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
# Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
|
||||
num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
||||
warnings.warn(
|
||||
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||
" passing `position_ids`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if len(deprecated_arguments) > 0:
|
||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
@ -850,7 +873,7 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
||||
logits,
|
||||
labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
num_items_in_batch=kwargs.get("num_items_in_batch"),
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@ -902,6 +925,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
||||
@ -920,6 +944,16 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
||||
warnings.warn(
|
||||
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||
" passing `position_ids`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if len(deprecated_arguments) > 0:
|
||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
@ -1025,6 +1059,7 @@ class BloomForTokenClassification(BloomPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
||||
@ -1043,6 +1078,16 @@ class BloomForTokenClassification(BloomPreTrainedModel):
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
||||
warnings.warn(
|
||||
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||
" passing `position_ids`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if len(deprecated_arguments) > 0:
|
||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
@ -1097,6 +1142,7 @@ class BloomForQuestionAnswering(BloomPreTrainedModel):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
start_positions: Optional[torch.LongTensor] = None,
|
||||
end_positions: Optional[torch.LongTensor] = None,
|
||||
@ -1122,6 +1168,7 @@ class BloomForQuestionAnswering(BloomPreTrainedModel):
|
||||
outputs = self.transformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
||||
@ -447,6 +447,7 @@ class BltCrossAttention(nn.Module):
|
||||
@auto_docstring
|
||||
class BltPreTrainedModel(PreTrainedModel):
|
||||
config: BltConfig
|
||||
base_model_prefix = ""
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BltTransformerLayer"]
|
||||
@ -1230,7 +1231,7 @@ class BltForCausalLM(BltPreTrainedModel, GenerationMixin):
|
||||
config: BltConfig
|
||||
_can_compile_fullgraph = False
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = {"model.local_encoder.embed_tokens.weight": "lm_head.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: BltConfig):
|
||||
super().__init__(config.get_text_config())
|
||||
|
||||
@ -964,7 +964,7 @@ class BltForCausalLM(MllamaForCausalLM):
|
||||
config: BltConfig
|
||||
_can_compile_fullgraph = False
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = {"model.local_encoder.embed_tokens.weight": "lm_head.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: BltConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -175,6 +175,7 @@ class BridgeTowerTextConfig(PreTrainedConfig):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
@ -297,7 +298,7 @@ class BridgeTowerConfig(PreTrainedConfig):
|
||||
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["BridgeTowerConfig", "BridgeTowerTextConfig", "BridgeTowerVisionConfig"]
|
||||
|
||||
@ -192,6 +192,9 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
        do_pad: bool = True,
        **kwargs,
    ) -> None:
        if "pad_and_return_pixel_mask" in kwargs:
            do_pad = kwargs.pop("pad_and_return_pixel_mask")

        super().__init__(**kwargs)
        size = size if size is not None else {"shortest_edge": 288}
        size = get_size_dict(size, default_to_square=False)
@ -205,7 +208,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_pad = kwargs.pop("pad_and_return_pixel_mask", do_pad)
        self.do_pad = do_pad
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
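With the change above, the legacy `pad_and_return_pixel_mask` keyword is popped from `kwargs` and mapped onto `do_pad` before `super().__init__` runs, instead of being read back out of `kwargs` afterwards. A small, hypothetical usage check of the intended equivalence:

```python
from transformers import BridgeTowerImageProcessor

# Both constructions should end up with the same padding behavior,
# assuming the constructor works as in the hunk above.
legacy = BridgeTowerImageProcessor(pad_and_return_pixel_mask=False)
current = BridgeTowerImageProcessor(do_pad=False)
assert legacy.do_pad is False and current.do_pad is False
```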
|
||||
|
||||
|
||||
@ -919,7 +919,6 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
std = self.config.initializer_factor
|
||||
if isinstance(module, BridgeTowerVisionTransformer):
|
||||
@ -928,7 +927,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
|
||||
fc_std = (2 * self.config.hidden_size) ** -0.5
|
||||
for block in module.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std * std)
|
||||
block.attn.in_proj_bias.zero_()
|
||||
block.attn.in_proj_bias.data.zero_()
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std * std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * std)
|
||||
@ -936,15 +935,15 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
|
||||
nn.init.normal_(module.embeddings.class_embedding, std=attn_std * std)
|
||||
nn.init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
|
||||
module.weight.normal_(mean=0.0, std=0.05 * std)
|
||||
module.weight.data.normal_(mean=0.0, std=0.05 * std)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, BridgeTowerForContrastiveLearning):
|
||||
module.logit_scale.fill_(self.config.logit_scale_init_value)
|
||||
module.logit_scale.data.fill_(self.config.logit_scale_init_value)
|
||||
|
||||
if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
|
||||
@ -1498,7 +1497,7 @@ class BridgeTowerITMHead(nn.Module):
|
||||
"""
|
||||
)
|
||||
class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
|
||||
_tied_weights_keys = {"mlm_score.decoder.weight": "bridgetower.text_model.embeddings.word_embeddings.weight"}
|
||||
_tied_weights_keys = ["mlm_score.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -514,21 +514,20 @@ class BrosPreTrainedModel(PreTrainedModel):
|
||||
config: BrosConfig
|
||||
base_model_prefix = "bros"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights"""
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, BrosRelationExtractor):
|
||||
nn.init.normal_(module.dummy_node, std=std)
|
||||
|
||||
|
||||
@ -383,6 +383,7 @@ class CamembertLMHead(nn.Module):
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = self.dense(features)
|
||||
@ -394,6 +395,14 @@ class CamembertLMHead(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
# For accelerate compatibility and to not break backward compatibility
|
||||
if self.decoder.bias.device.type == "meta":
|
||||
self.decoder.bias = self.bias
|
||||
else:
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class CamembertPreTrainedModel(PreTrainedModel):
|
||||
@ -410,22 +419,21 @@ class CamembertPreTrainedModel(PreTrainedModel):
|
||||
"cross_attentions": CamembertCrossAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, CamembertLMHead):
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
class CamembertEmbeddings(nn.Module):
|
||||
@ -737,10 +745,7 @@ class CamembertModel(CamembertPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class CamembertForMaskedLM(CamembertPreTrainedModel):
|
||||
_tied_weights_keys = {
|
||||
"lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
|
||||
"lm_head.decoder.bias": "lm_head.bias",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1186,10 +1191,7 @@ class CamembertForQuestionAnswering(CamembertPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {
|
||||
"lm_head.decoder.weight": "camembert.embeddings.word_embeddings.weight",
|
||||
"lm_head.decoder.bias": "lm_head.bias",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -53,11 +53,6 @@ class CamembertModel(RobertaModel):
|
||||
|
||||
|
||||
class CamembertForMaskedLM(RobertaForMaskedLM):
|
||||
_tied_weights_keys = {
|
||||
"lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
|
||||
"lm_head.decoder.bias": "lm_head.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
del self.camembert
|
||||
|
||||
@ -688,11 +688,12 @@ class CanineLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor:
|
||||
hidden_states = self.transform(hidden_states)
|
||||
@ -719,20 +720,19 @@ class CaninePreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "canine"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight[module.padding_idx].zero_()
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -1009,7 +1009,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -562,7 +562,6 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
@ -577,7 +576,7 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
nn.init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range)
|
||||
for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]:
|
||||
if embedding.padding_idx is not None:
|
||||
embedding.weight[embedding.padding_idx].zero_()
|
||||
embedding.weight.data[embedding.padding_idx].zero_()
|
||||
elif isinstance(module, ChineseCLIPVisionAttention):
|
||||
factor = self.config.initializer_factor
|
||||
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
@ -603,12 +602,12 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP
|
||||
|
||||
@ -1308,29 +1308,28 @@ class ClapPreTrainedModel(PreTrainedModel):
|
||||
input_modalities = ["audio", "text"]
|
||||
supports_gradient_checkpointing = False
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
|
||||
if isinstance(module, ClapTextEmbeddings):
|
||||
module.position_embeddings.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.token_type_embeddings.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.token_type_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, ClapModel):
|
||||
module.logit_scale_a.fill_(math.log(self.config.logit_scale_init_value))
|
||||
module.logit_scale_t.fill_(math.log(self.config.logit_scale_init_value))
|
||||
module.logit_scale_a.data.fill_(math.log(self.config.logit_scale_init_value))
|
||||
module.logit_scale_t.data.fill_(math.log(self.config.logit_scale_init_value))
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, (nn.Conv2d, nn.Linear)):
|
||||
in_proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor
|
||||
nn.init.normal_(module.weight, std=in_proj_std)
|
||||
if module.bias is not None:
|
||||
module.bias.zero_()
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, ClapAudioSelfAttention):
|
||||
module.relative_position_bias_table.zero_()
|
||||
module.relative_position_bias_table.data.zero_()
|
||||
|
||||
|
||||
class ClapAudioModel(ClapPreTrainedModel):
|
||||
@ -1373,7 +1372,7 @@ class ClapAudioModel(ClapPreTrainedModel):
        >>> model = ClapAudioModel.from_pretrained("laion/clap-htsat-fused")
        >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-fused")

        >>> inputs = processor(audio=audio_sample, return_tensors="pt")
        >>> inputs = processor(audios=audio_sample, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
@ -1648,7 +1647,7 @@ class ClapModel(ClapPreTrainedModel):

        >>> input_text = ["Sound of a dog", "Sound of vacuum cleaner"]

        >>> inputs = processor(text=input_text, audio=audio_sample, return_tensors="pt", padding=True)
        >>> inputs = processor(text=input_text, audios=audio_sample, return_tensors="pt", padding=True)

        >>> outputs = model(**inputs)
        >>> logits_per_audio = outputs.logits_per_audio  # this is the audio-text similarity score
@ -1820,7 +1819,7 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
        >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> inputs = processor(audio=audio_sample, return_tensors="pt")
        >>> inputs = processor(audios=audio_sample, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> audio_embeds = outputs.audio_embeds
        ```"""
@ -16,8 +16,13 @@
Audio/Text processor class for CLAP
"""

from ...processing_utils import ProcessorMixin
from typing import Optional, Union

from ...audio_utils import AudioInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...utils.deprecation import deprecate_kwarg


logger = logging.get_logger(__name__)
@ -40,5 +45,28 @@ class ClapProcessor(ProcessorMixin):
    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)

    @deprecate_kwarg("audios", version="v4.59.0", new_name="audio")
    def __call__(
        self,
        text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
        audios: Optional[AudioInput] = None,
        audio: Optional[AudioInput] = None,
        **kwargs: Unpack[ProcessingKwargs],
    ):
        """
        Forwards the `audio` and `sampling_rate` arguments to [`~ClapFeatureExtractor.__call__`] and the `text`
        argument to [`~RobertaTokenizerFast.__call__`]. Please refer to the docstring of the above two methods for more
        information.
        """
        # The `deprecate_kwarg` will not work if the inputs are passed as arguments, so we check
        # again that the correct naming is used
        if audios is not None and audio is None:
            logger.warning(
                "Using `audios` keyword argument is deprecated when calling ClapProcessor, instead use `audio`."
            )
            audio = audios

        return super().__call__(text=text, audio=audio, **kwargs)


__all__ = ["ClapProcessor"]
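For the processor hunk above: on the side of the diff that carries the `@deprecate_kwarg` shim, `audios` still works but is forwarded to `audio` with a warning. A short usage sketch, reusing the checkpoint and example clip that already appear in the CLAP docstrings of this diff (assumed to be available):

```python
from datasets import load_dataset
from transformers import ClapProcessor

# Same example clip used in the CLAP docstrings of this diff.
dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
audio_sample = dataset["train"]["audio"][0]["array"]

processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")

# Preferred keyword after the rename: `audio`.
inputs = processor(text=["Sound of a dog"], audio=audio_sample, return_tensors="pt", padding=True)

# Still accepted, but logs a deprecation warning and is routed to `audio` internally.
inputs = processor(text=["Sound of a dog"], audios=audio_sample, return_tensors="pt", padding=True)
```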
@ -408,13 +408,12 @@ class CLIPPreTrainedModel(PreTrainedModel):
        "attentions": CLIPAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor
        if isinstance(module, CLIPTextEmbeddings):
            module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02)
            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, CLIPVisionEmbeddings):
            factor = self.config.initializer_factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
@ -460,10 +459,10 @@ class CLIPPreTrainedModel(PreTrainedModel):
            )

        if isinstance(module, nn.LayerNorm):
            module.bias.zero_()
            module.weight.fill_(1.0)
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.zero_()
            module.bias.data.zero_()


class CLIPEncoder(nn.Module):
@ -427,13 +427,12 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
    input_modalities = ["image", "text"]
    supports_gradient_checkpointing = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor
        if isinstance(module, CLIPSegTextEmbeddings):
            module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02)
            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, CLIPSegVisionEmbeddings):
            factor = self.config.initializer_factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
@ -464,10 +463,10 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
            )

        if isinstance(module, nn.LayerNorm):
            module.bias.zero_()
            module.weight.fill_(1.0)
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.zero_()
            module.bias.data.zero_()


# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg
@ -781,18 +781,17 @@ class ClvpPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """Initialize the weights"""
        factor = self.config.initializer_factor
        if isinstance(module, nn.Embedding):
            module.weight.normal_(mean=0.0, std=factor * 0.02)
            module.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, (nn.Linear, Conv1D, nn.Conv1d)):
            module.weight.normal_(mean=0.0, std=factor * 0.02)
            module.weight.data.normal_(mean=0.0, std=factor * 0.02)
            if module.bias is not None:
                module.bias.zero_()
                module.bias.data.zero_()
        elif isinstance(module, ClvpRMSNorm):
            module.weight.fill_(1.0)
            module.weight.data.fill_(1.0)
        elif isinstance(module, ClvpEncoderMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
@ -801,22 +800,22 @@ class ClvpPreTrainedModel(PreTrainedModel):
        elif isinstance(module, ClvpEncoder):
            config = self.config.get_text_config()
            factor = config.initializer_factor
            module.projection.weight.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
            module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
        elif isinstance(module, ClvpConditioningEncoder):
            module.mel_conv.weight.normal_(mean=0.0, std=factor)
            module.mel_conv.bias.zero_()
            module.mel_conv.weight.data.normal_(mean=0.0, std=factor)
            module.mel_conv.bias.data.zero_()
        elif isinstance(module, ClvpForCausalLM):
            for name, p in module.named_parameters():
                if name == "c_proj.weight":
                    p.normal_(
                    p.data.normal_(
                        mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers))
                    )
        elif isinstance(module, ClvpModelForConditionalGeneration):
            module.logit_scale.fill_(self.config.logit_scale_init_value)
            module.logit_scale.data.fill_(self.config.logit_scale_init_value)

        if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.zero_()
            module.weight.fill_(1.0)
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class ClvpEncoder(ClvpPreTrainedModel):
@ -283,20 +283,19 @@ class CodeGenPreTrainedModel(PreTrainedModel):
    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear,)):
            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.zero_()
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight[module.padding_idx].zero_()
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.zero_()
            module.weight.fill_(1.0)
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@auto_docstring
@ -561,7 +560,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
    """
)
class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
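Separately, the final hunk toggles the format of `_tied_weights_keys` between a mapping from tied parameter to its source (`"lm_head.weight": "transformer.wte.weight"`) and a plain list of tied parameter names. Either way it declares ordinary weight tying; a minimal, hypothetical sketch with toy sizes (not the transformers internals):

```python
import torch
from torch import nn

# Toy sizes, purely illustrative.
vocab_size, hidden_size = 100, 16

# Input embedding and LM head, analogous to transformer.wte / lm_head in CodeGenForCausalLM.
wte = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tying: both modules now reference the same Parameter, so only one matrix is stored.
lm_head.weight = wte.weight
assert lm_head.weight.data_ptr() == wte.weight.data_ptr()
```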
Some files were not shown because too many files have changed in this diff.