mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-22 10:19:00 +08:00
Compare commits
105 Commits
trigger_vi
...
fix_sam_sa
Author | SHA1 | Date | |
---|---|---|---|
6b1a8c9446 | |||
9cde2f5d42 | |||
856f034f45 | |||
bb3c6426d8 | |||
2ad152f84c | |||
de70c8426e | |||
8ea61c4530 | |||
d34e21e7dd | |||
183fb3637c | |||
f022bf9322 | |||
0a52bd2403 | |||
555715f418 | |||
7a611f0afd | |||
3bd1c20149 | |||
dbc4b91db4 | |||
46a4b7c909 | |||
9ecee14378 | |||
f524439cc5 | |||
6e738411e1 | |||
9c500015c5 | |||
6f9da7649f | |||
7c9b0ca08c | |||
04282a9ef5 | |||
aef12349b6 | |||
9644acb7cb | |||
7d93f93f83 | |||
47f8578d96 | |||
6c6302817d | |||
003deb16f1 | |||
dbb9813dff | |||
656e2eab3f | |||
6bb6821d93 | |||
40a493c7ed | |||
ea29f61ed9 | |||
a4389494c7 | |||
0ba95564b7 | |||
d69945e5fc | |||
7b5e327c6e | |||
120935234f | |||
91f6fa00f4 | |||
5036ec8872 | |||
7f28da2850 | |||
01ad9f4b49 | |||
3ab47b6ce3 | |||
1e921a3a9c | |||
57a79f51b2 | |||
44fa04ae8d | |||
34c1e29cdd | |||
0f77ca72ca | |||
27ef46e846 | |||
fe9426f12d | |||
7caa57e85e | |||
b11b28cc4e | |||
0e0e5c1044 | |||
955e61b0da | |||
0173a99e73 | |||
e5a48785d9 | |||
4005e30c80 | |||
aa27fa75cd | |||
e021bf6bf8 | |||
ef27b2bc22 | |||
4a2decd192 | |||
935bbbc711 | |||
1b00966395 | |||
fe918d13b9 | |||
aaf224d570 | |||
9b5ce556aa | |||
b311a3f506 | |||
b499a14b17 | |||
e0f225cb10 | |||
342961f669 | |||
8771766a70 | |||
582d5e0e11 | |||
a5cc7a67d7 | |||
67b3d45eb6 | |||
07feaad8fb | |||
e40f301f1f | |||
e27d230ddd | |||
ab65ba47ad | |||
8fb60bf6be | |||
3ad35d0bca | |||
e3b70b0d1c | |||
4143f94d51 | |||
a63cb7578e | |||
e387821a96 | |||
f0e975c6cf | |||
31791b16a1 | |||
8ea72d12a2 | |||
5c85018072 | |||
7eaa90b87b | |||
4220039b29 | |||
8efe3a9d77 | |||
a5c6172c81 | |||
a31fa218ad | |||
716819b830 | |||
8f08318769 | |||
87e971e14d | |||
aaed2f5577 | |||
7f1a97bae3 | |||
9f9020fed3 | |||
23d79cea75 | |||
774dc274ac | |||
0010b41524 | |||
d498528800 | |||
66e696ee15 |
@ -43,8 +43,6 @@ jobs:
|
||||
parallelism: 1
|
||||
steps:
|
||||
- checkout
|
||||
- run: git branch
|
||||
- run: git log -n 1
|
||||
- run: python3 utils/extract_pr_number_from_circleci.py > pr_number.txt
|
||||
- run: echo $(cat pr_number.txt)
|
||||
- run: if [[ "$(cat pr_number.txt)" == "" && "$CIRCLE_BRANCH" != "main" && "$CIRCLE_BRANCH" != *-release ]]; then echo "Not a PR, not the main branch and not a release branch, skip test!"; circleci-agent step halt; fi
|
||||
|
@ -110,6 +110,7 @@ class CircleCIJob:
|
||||
print(f"Using {self.docker_image} docker image")
|
||||
if self.install_steps is None:
|
||||
self.install_steps = ["uv venv && uv pip install ."]
|
||||
self.install_steps.append("uv venv && uv pip install git+https://github.com/ydshieh/pytest.git@8.3.5-ydshieh git+https://github.com/ydshieh/pluggy.git@1.5.0-ydshieh")
|
||||
if self.pytest_options is None:
|
||||
self.pytest_options = {}
|
||||
if isinstance(self.tests_to_run, str):
|
||||
|
25
.github/workflows/change_pr_to_draft.yml
vendored
25
.github/workflows/change_pr_to_draft.yml
vendored
@ -1,25 +0,0 @@
|
||||
name: Change PR to draft
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, reopened]
|
||||
|
||||
jobs:
|
||||
convert_pr_to_draft:
|
||||
runs-on: ubuntu-22.04
|
||||
name: Convert PR to draft
|
||||
permissions:
|
||||
pull-requests: write
|
||||
contents: write
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- name: Convert PR to draft
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.number }}
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
REPO: ${{ github.repository }}
|
||||
run: |
|
||||
echo $PR_NUMBER
|
||||
gh pr ready $PR_NUMBER --repo $REPO --undo
|
||||
gh pr comment $PR_NUMBER --repo $REPO --body "Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the \`Ready for review\` button (at the bottom of the PR page). This will assign reviewers and trigger CI."
|
19
.github/workflows/pr-style-bot.yml
vendored
Normal file
19
.github/workflows/pr-style-bot.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
# To run this bot, comment "@bot /style" on a PR
|
||||
name: Style Bot
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
style:
|
||||
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
|
||||
with:
|
||||
python_quality_dependencies: "[quality]"
|
||||
style_command_type: "default"
|
||||
secrets:
|
||||
bot_token: ${{ secrets.GITHUB_TOKEN }}
|
2
.github/workflows/self-comment-ci.yml
vendored
2
.github/workflows/self-comment-ci.yml
vendored
@ -29,7 +29,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
name: Get PR number
|
||||
# For security: only allow team members to run
|
||||
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
|
||||
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber", "manueldeprada"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
|
||||
outputs:
|
||||
PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
|
||||
steps:
|
||||
|
16
.github/workflows/trigger_circleci.yml
vendored
16
.github/workflows/trigger_circleci.yml
vendored
@ -1,16 +0,0 @@
|
||||
name: Trigger CircleCI
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [ready_for_review]
|
||||
|
||||
jobs:
|
||||
trigger-circleci:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: trigger CircleCI pipeline via GitHub Actions
|
||||
uses: CircleCI-Public/trigger-circleci-pipeline-action@v1.0.5
|
||||
with:
|
||||
GHA_Meta: "Trigger via GitHub Actions"
|
||||
env:
|
||||
CCI_TOKEN: ${{ secrets.CIRCLECI_PAT }}
|
@ -98,7 +98,12 @@ Install Transformers from source if you want the latest changes in the library o
|
||||
```shell
|
||||
git clone https://github.com/huggingface/transformers.git
|
||||
cd transformers
|
||||
|
||||
# pip
|
||||
pip install .[torch]
|
||||
|
||||
# uv
|
||||
uv pip install .[torch]
|
||||
```
|
||||
|
||||
## Quickstart
|
||||
@ -120,7 +125,7 @@ To chat with a model, the usage pattern is the same. The only difference is you
|
||||
> [!TIP]
|
||||
> You can also chat with a model directly from the command line.
|
||||
> ```shell
|
||||
> transformers chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
|
||||
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
> ```
|
||||
|
||||
```py
|
||||
|
@ -71,6 +71,9 @@ RUN python3 -m pip install --no-cache-dir g2p-en
|
||||
# For Some bitsandbytes tests
|
||||
RUN python3 -m pip install --no-cache-dir einops
|
||||
|
||||
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
|
||||
RUN python3 -m pip uninstall -y kernels
|
||||
|
||||
# When installing in editable mode, `transformers` is not recognized as a package.
|
||||
# this line must be added in order for python to be aware of transformers.
|
||||
RUN cd transformers && python3 setup.py develop
|
||||
|
@ -45,6 +45,9 @@ RUN python3 -m pip uninstall -y deepspeed
|
||||
# TODO: Find out why test fail.
|
||||
RUN DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1
|
||||
|
||||
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
|
||||
RUN python3 -m pip uninstall -y kernels
|
||||
|
||||
# When installing in editable mode, `transformers` is not recognized as a package.
|
||||
# this line must be added in order for python to be aware of transformers.
|
||||
RUN cd transformers && python3 setup.py develop
|
||||
|
@ -57,6 +57,9 @@ RUN python3 -m pip uninstall -y deepspeed
|
||||
#RUN git clone https://github.com/pytorch/TensorRT.git
|
||||
#RUN cd TensorRT/py && python3 setup.py install --fx-only
|
||||
|
||||
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
|
||||
RUN python3 -m pip uninstall -y kernels
|
||||
|
||||
# When installing in editable mode, `transformers` is not recognized as a package.
|
||||
# this line must be added in order for python to be aware of transformers.
|
||||
RUN cd transformers && python3 setup.py develop
|
||||
|
@ -28,6 +28,9 @@ RUN python3 -m pip uninstall -y tensorflow flax
|
||||
RUN python3 -m pip install --no-cache-dir git+https://github.com/facebookresearch/detectron2.git pytesseract
|
||||
RUN python3 -m pip install -U "itsdangerous<2.1.0"
|
||||
|
||||
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
|
||||
RUN python3 -m pip uninstall -y kernels
|
||||
|
||||
# When installing in editable mode, `transformers` is not recognized as a package.
|
||||
# this line must be added in order for python to be aware of transformers.
|
||||
RUN cd transformers && python3 setup.py develop
|
||||
|
@ -90,6 +90,9 @@ RUN python3 -m pip install --no-cache-dir "auto-round>=0.5.0"
|
||||
# Add transformers in editable mode
|
||||
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch]
|
||||
|
||||
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
|
||||
RUN python3 -m pip uninstall -y kernels
|
||||
|
||||
# When installing in editable mode, `transformers` is not recognized as a package.
|
||||
# this line must be added in order for python to be aware of transformers.
|
||||
RUN cd transformers && python3 setup.py develop
|
||||
|
@ -39,6 +39,8 @@
|
||||
title: Tokenizers
|
||||
- local: image_processors
|
||||
title: Image processors
|
||||
- local: video_processors
|
||||
title: Video processors
|
||||
- local: backbones
|
||||
title: Backbones
|
||||
- local: feature_extractors
|
||||
@ -362,7 +364,9 @@
|
||||
title: Feature Extractor
|
||||
- local: main_classes/image_processor
|
||||
title: Image Processor
|
||||
title: Main classes
|
||||
- local: main_classes/video_processor
|
||||
title: Video Processor
|
||||
title: Main Classes
|
||||
- sections:
|
||||
- sections:
|
||||
- local: model_doc/albert
|
||||
|
@ -27,7 +27,7 @@ This guide shows you how to quickly start chatting with Transformers from the co
|
||||
|
||||
## transformers CLI
|
||||
|
||||
Chat with a model directly from the command line as shown below. It launches an interactive session with a model. Enter `clear` to reset the conversation, `exit` to terminate the session, and `help` to display all the command options.
|
||||
After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.
|
||||
|
||||
```bash
|
||||
transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
@ -37,6 +37,12 @@ transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers-chat-cli.png"/>
|
||||
</div>
|
||||
|
||||
You can launch the CLI with arbitrary `generate` flags, with the format `arg_1=value_1 arg_2=value_2 ...`
|
||||
|
||||
```bash
|
||||
transformers chat Qwen/Qwen2.5-0.5B-Instruct do_sample=False max_new_tokens=10
|
||||
```
|
||||
|
||||
For a full list of options, run the command below.
|
||||
|
||||
```bash
|
||||
|
@ -20,11 +20,15 @@ A decoding strategy informs how a model should select the next generated token.
|
||||
|
||||
This guide will help you understand the different decoding strategies available in Transformers and how and when to use them.
|
||||
|
||||
## Greedy search
|
||||
## Basic decoding methods
|
||||
|
||||
Greedy search is the default decoding strategy. It selects the next most likely token at each step. Unless specified in [`GenerationConfig`], this strategy generates a maximum of 20 tokens.
|
||||
These are well established decoding methods, and should be your starting point for text generation tasks.
|
||||
|
||||
Greedy search works well for tasks with relatively short outputs. However, it breaks down when generating longer sequences because it begins to repeat itself.
|
||||
### Greedy search
|
||||
|
||||
Greedy search is the default decoding strategy. It selects the next most likely token at each step. Unless specified in [`GenerationConfig`], this strategy generates a maximum of 20 new tokens.
|
||||
|
||||
Greedy search works well for tasks with relatively short outputs where creativity is not a priority. However, it breaks down when generating longer sequences because it begins to repeat itself.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@ -40,11 +44,11 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
'Hugging Face is an open-source company that provides a suite of tools and services for building, deploying, and maintaining natural language processing'
|
||||
```
|
||||
|
||||
## Contrastive search
|
||||
### Sampling
|
||||
|
||||
[Contrastive search](https://huggingface.co/papers/2202.06417) is a decoding strategy that aims to reduce repetition even while generating longer sequences. This strategy compares how similar a generated token is against previous tokens, and if they're more similar, a penalty is applied.
|
||||
Sampling, or multinomial sampling, randomly selects a token based on the probability distribution over the entire model's vocabulary (as opposed to the most likely token, as in greedy search). This means every token with a non-zero probability has a chance to be selected. Sampling strategies reduce repetition and can generate more creative and diverse outputs.
|
||||
|
||||
Enable contrastive search with the `penalty_alpha` and `top_k` parameters. The `penalty_alpha` manages the penalty applied and `top_k` is the number of most likely tokens to return.
|
||||
Enable multinomial sampling with `do_sample=True` and `num_beams=1`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@ -55,14 +59,14 @@ inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
|
||||
# explicitly set to 100 because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4)
|
||||
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=1)
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
'Hugging Face is an open-source company that provides a platform for building and deploying AI models.\nHugging Face is an open-source company that provides a platform for building and deploying AI models. The platform allows developers to build and deploy AI models, as well as collaborate with other developers.\nHugging Face was founded in 2019 by Thibault Wittemberg and Clément Delangue. The company is based in Paris, France.\nHugging Face has'
|
||||
'Hugging Face is an open-source company 🤗\nWe are open-source and believe that open-source is the best way to build technology. Our mission is to make AI accessible to everyone, and we believe that open-source is the best way to achieve that.'
|
||||
```
|
||||
|
||||
## Beam search
|
||||
### Beam search
|
||||
|
||||
Beam search keeps track of several generated sequences (beams) at each time step. After a certain number of steps, it selects the sequence with the highest *overall* probability. Unlike greedy search, this strategy can "look ahead" and pick a sequence with a higher probability overall even if the initial tokens have a lower probability.
|
||||
Beam search keeps track of several generated sequences (beams) at each time step. After a certain number of steps, it selects the sequence with the highest *overall* probability. Unlike greedy search, this strategy can "look ahead" and pick a sequence with a higher probability overall even if the initial tokens have a lower probability. It is best suited for input-grounded tasks, like describing an image or speech recognition. You can also use `do_sample=True` with beam search to sample at each step, but beam search will still greedily prune out low probability sequences between steps.
|
||||
|
||||
> [!TIP]
|
||||
> Check out the [beam search visualizer](https://huggingface.co/spaces/m-ric/beam_search_visualizer) to see how beam search works.
|
||||
@ -83,66 +87,11 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
"['Hugging Face is an open-source company that develops and maintains the Hugging Face platform, which is a collection of tools and libraries for building and deploying natural language processing (NLP) models. Hugging Face was founded in 2018 by Thomas Wolf']"
|
||||
```
|
||||
|
||||
## Diverse beam search
|
||||
## Advanced decoding methods
|
||||
|
||||
[Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups.
|
||||
Advanced decoding methods aim at either tackling specific generation quality issues (e.g. repetition) or at improving the generation throughput in certain situations. These techniques are more complex, and may not work correctly with all models.
|
||||
|
||||
Enable diverse beam search with the `num_beams`, `num_beam_groups` and `diversity_penalty` parameters (the `num_beams` parameter should be divisible by `num_beam_groups`).
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
|
||||
# explicitly set to 100 because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=6, num_beam_groups=3, diversity_penalty=1.0, do_sample=False)
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
'Hugging Face is an open-source company 🤗\nWe are an open-source company. Our mission is to democratize AI and make it accessible to everyone. We believe that AI should be used for the benefit of humanity, not for the benefit of a'
|
||||
```
|
||||
|
||||
## Multinomial sampling
|
||||
|
||||
Search methods selects the most likely tokens. Sampling, or multinomial sampling, randomly selects a token based on the probability distribution over the entire models vocabulary. This means every token with a non-zero probability has a chance to be selected. Sampling strategies reduce repetition and can generate more creative and diverse outputs.
|
||||
|
||||
Enable multinomial sampling with `do_sample=True` and `num_beams=1`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
|
||||
# explicitly set to 100 because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=1)
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
'Hugging Face is an open-source company 🤗\nWe are open-source and believe that open-source is the best way to build technology. Our mission is to make AI accessible to everyone, and we believe that open-source is the best way to achieve that.'
|
||||
```
|
||||
|
||||
## Beam search multinomial sampling
|
||||
|
||||
This decoding strategy is a combination of beam search and multinomial sampling. It generates multiple beams and uses a sampling strategy for each beam.
|
||||
|
||||
Enable beam search multinomial sampling by setting `num_beams` to a value greater than 1 and `do_sample=True`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
|
||||
# explicitly set to 100 because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=4)
|
||||
'Hugging Face is an open-source company 100% dedicated to making AI more accessible. We believe that AI should be available to everyone, and we’re working hard to make that a reality.\nWe’re a team of passionate engineers, designers,'
|
||||
```
|
||||
|
||||
## Speculative decoding
|
||||
### Speculative decoding
|
||||
|
||||
[Speculative](https://hf.co/papers/2211.17192) or assistive decoding isn't a search or sampling strategy. Instead, speculative decoding adds a second smaller model to generate candidate tokens. The main model verifies the candidate tokens in a single `forward` pass, which speeds up the decoding process overall. This method is especially useful for LLMs where it can be more costly and slower to generate tokens. Refer to the [speculative decoding](./llm_optims#speculative-decoding) guide to learn more.
|
||||
|
||||
@ -203,7 +152,7 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Prompt lookup decoding
|
||||
#### Prompt lookup decoding
|
||||
|
||||
[Prompt lookup decoding](./llm_optims#prompt-lookup-decoding) is a variant of speculative decoding that uses overlapping n-grams as the candidate tokens. It works well for input-grounded tasks such as summarization. Refer to the [prompt lookup decoding](./llm_optims#prompt-lookup-decoding) guide to learn more.
|
||||
|
||||
@ -245,7 +194,7 @@ outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
```
|
||||
|
||||
### Universal assisted decoding
|
||||
#### Universal assisted decoding
|
||||
|
||||
Universal assisted decoding (UAD) enables the main and assistant models to use different tokenizers. The main models input tokens are re-encoded into assistant model tokens. Candidate tokens are generated in the assistant encoding which are re-encoded into the main model candidate tokens. The candidate tokens are verified as explained in [speculative decoding](#speculative-decoding).
|
||||
|
||||
@ -269,7 +218,27 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||
```
|
||||
|
||||
## DoLa
|
||||
### Contrastive search
|
||||
|
||||
[Contrastive search](https://huggingface.co/papers/2202.06417) is a decoding strategy that aims to reduce repetition even while generating longer sequences. This strategy compares how similar a generated token is against previous tokens, and if they're more similar, a penalty is applied.
|
||||
|
||||
Enable contrastive search with the `penalty_alpha` and `top_k` parameters. The `penalty_alpha` manages the penalty applied and `top_k` is the number of most likely tokens to return.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
|
||||
# explicitly set to 100 because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4)
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
'Hugging Face is an open-source company that provides a platform for building and deploying AI models.\nHugging Face is an open-source company that provides a platform for building and deploying AI models. The platform allows developers to build and deploy AI models, as well as collaborate with other developers.\nHugging Face was founded in 2019 by Thibault Wittemberg and Clément Delangue. The company is based in Paris, France.\nHugging Face has'
|
||||
```
|
||||
|
||||
### DoLa
|
||||
|
||||
[Decoding by Contrasting Layers (DoLa)](https://hf.co/papers/2309.03883) is a contrastive decoding strategy for improving factuality and reducing hallucination. This strategy works by contrasting the logit differences between the final and early layers. As a result, factual knowledge localized to particular layers are amplified. DoLa is not recommended for smaller models like GPT-2.
|
||||
|
||||
@ -325,6 +294,210 @@ tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:], skip_special_tok
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Diverse beam search
|
||||
|
||||
[Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups.
|
||||
|
||||
Enable diverse beam search with the `num_beams`, `num_beam_groups` and `diversity_penalty` parameters (the `num_beams` parameter should be divisible by `num_beam_groups`).
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
|
||||
# explicitly set to 100 because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=6, num_beam_groups=3, diversity_penalty=1.0, do_sample=False)
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
'Hugging Face is an open-source company 🤗\nWe are an open-source company. Our mission is to democratize AI and make it accessible to everyone. We believe that AI should be used for the benefit of humanity, not for the benefit of a'
|
||||
```
|
||||
|
||||
|
||||
## Custom decoding methods
|
||||
|
||||
Custom decoding methods enable specialized generation behavior such as the following:
|
||||
- have the model continue thinking if it is uncertain;
|
||||
- roll back generation if the model gets stuck;
|
||||
- handle special tokens with custom logic;
|
||||
- enhanced input preparation for advanced models;
|
||||
|
||||
We enable custom decoding methods through model repositories, assuming a specific model tag and file structure (see subsection below). This feature is an extension of [custom modeling code](./models.md#custom-models) and, like such, requires setting `trust_remote_code=True`.
|
||||
|
||||
If a model repository holds a custom decoding method, the easiest way to try it out is to load the model and generate with it:
|
||||
|
||||
<!-- TODO before merging: 1) better repo name (use a `generate-community` org?) 2) prettify the repo -->
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# `transformers-community/custom_generate_example` holds a copy of `Qwen/Qwen2.5-0.5B-Instruct`, but
|
||||
# with custom generation code -> calling `generate` uses the custom decoding method!
|
||||
tokenizer = AutoTokenizer.from_pretrained("transformers-community/custom_generate_example")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"transformers-community/custom_generate_example", device_map="auto", trust_remote_code=True
|
||||
)
|
||||
|
||||
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
|
||||
# The custom decoding method is a minimal greedy decoding implementation. It also prints a custom message at run time.
|
||||
gen_out = model.generate(**inputs)
|
||||
# you should now see its custom message, "✨ using a custom generation method ✨"
|
||||
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
|
||||
'The quick brown fox jumps over a lazy dog, and the dog is a type of animal. Is'
|
||||
```
|
||||
|
||||
Model repositories with custom decoding methods have a special property: their decoding method can be loaded from **any** model through [`~GenerationMixin.generate`]'s `custom_generate` argument. This means anyone can create and share their custom generation method to potentially work with any Transformers model, without requiring users to install additional Python packages.
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")
|
||||
|
||||
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
|
||||
# `custom_generate` replaces the original `generate` by the custom decoding method defined in
|
||||
# `transformers-community/custom_generate_example`
|
||||
gen_out = model.generate(**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True)
|
||||
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
|
||||
'The quick brown fox jumps over a lazy dog, and the dog is a type of animal. Is'
|
||||
```
|
||||
|
||||
You should read the `README.md` file of the repository containing the custom generation strategy to see what the new arguments and output type differences are, if they exist. Otherwise, you can assume it works like the base [`~GenerationMixin.generate`] method.
|
||||
|
||||
> [!TIP]
|
||||
> You can find all custom decoding methods by [searching for their custom tag.](https://huggingface.co/models?other=custom_generate), `custom_generate`
|
||||
|
||||
Consider the Hub repository [transformers-community/custom_generate_example](https://huggingface.co/transformers-community/custom_generate_example) as an example. The `README.md` states that it has an additional input argument, `left_padding`, which adds a number of padding tokens before the prompt.
|
||||
|
||||
```py
|
||||
gen_out = model.generate(
|
||||
**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True, left_padding=5
|
||||
)
|
||||
print(tokenizer.batch_decode(gen_out)[0])
|
||||
'<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>The quick brown fox jumps over the lazy dog.\n\nThe sentence "The quick'
|
||||
```
|
||||
|
||||
If the custom method has pinned Python requirements that your environment doesn't meet, you'll get an exception about missing requirements. For instance, [transformers-community/custom_generate_bad_requirements](https://huggingface.co/transformers-community/custom_generate_bad_requirements) has an impossible set of requirements defined in its `custom_generate/requirements.txt` file, and you'll see the error message below if you try to run it.
|
||||
|
||||
```
|
||||
ImportError: Missing requirements in your local environment for `transformers-community/custom_generate_bad_requirements`:
|
||||
foo (installed: None)
|
||||
bar==0.0.0 (installed: None)
|
||||
torch>=99.0 (installed: 2.6.0)
|
||||
```
|
||||
|
||||
Updating your Python requirements accordingly will remove this error message.
|
||||
|
||||
### Creating a custom decoding method
|
||||
|
||||
To create a new decoding method, you need to create a new [**Model**](https://huggingface.co/new) repository and push a few files into it.
|
||||
1. The model you've designed your decoding method with.
|
||||
2. `custom_generate/generate.py`, which contains all the logic for your custom decoding method.
|
||||
3. `custom_generate/requirements.txt`, used to optionally add new Python requirements and/or lock specific versions to correctly use your method.
|
||||
4. `README.md`, where you should add the `custom_generate` tag and document any new arguments or output type differences of your custom method here.
|
||||
|
||||
After you've added all required files, your repository should look like this
|
||||
|
||||
```
|
||||
your_repo/
|
||||
├── README.md # include the 'custom_generate' tag
|
||||
├── config.json
|
||||
├── ...
|
||||
└── custom_generate/
|
||||
├── generate.py
|
||||
└── requirements.txt
|
||||
```
|
||||
|
||||
#### Adding the base model
|
||||
|
||||
The starting point for your custom decoding method is a model repository just like any other. The model to add to this repository should be the model you've designed your method with, and it is meant to be part of a working self-contained model-generate pair. When the model in this repository is loaded, your custom decoding method will override `generate`. Don't worry -- your decoding method can still be loaded with any other Transformers model, as explained in the section above.
|
||||
|
||||
If you simply want to copy an existing model, you can do
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("source/model_repo")
|
||||
model = AutoModelForCausalLM.from_pretrained("source/model_repo")
|
||||
tokenizer.save_pretrained("your/decoding_method", push_to_hub=True)
|
||||
model.save_pretrained("your/decoding_method", push_to_hub=True)
|
||||
```
|
||||
|
||||
#### generate.py
|
||||
|
||||
This is the core of your decoding method. It *must* contain a method named `generate`, and this method *must* contain a `model` argument as its first argument. `model` is the model instance, which means you have access to all attributes and methods in the model, including the ones defined in [`GenerationMixin`] (like the base `generate` method).
|
||||
|
||||
> [!WARNING]
|
||||
> `generate.py` must be placed in a folder named `custom_generate`, and not at the root level of the repository. The file paths for this feature are hardcoded.
|
||||
|
||||
Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method.
|
||||
|
||||
This means your `generate` can have a mix of original and custom arguments (as well as a different output type) as shown below.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
def generate(model, input_ids, generation_config=None, left_padding=None, **kwargs):
|
||||
generation_config = generation_config or model.generation_config # default to the model generation config
|
||||
cur_length = input_ids.shape[1]
|
||||
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
||||
|
||||
# Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
|
||||
if left_padding is not None:
|
||||
if not isinstance(left_padding, int) or left_padding < 0:
|
||||
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
|
||||
|
||||
pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
|
||||
if pad_token is None:
|
||||
raise ValueError("pad_token is not defined")
|
||||
batch_size = input_ids.shape[0]
|
||||
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
||||
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
||||
cur_length = input_ids.shape[1]
|
||||
|
||||
# Simple greedy decoding loop
|
||||
while cur_length < max_length:
|
||||
logits = model(input_ids).logits
|
||||
next_token_logits = logits[:, -1, :]
|
||||
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
||||
input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
|
||||
cur_length += 1
|
||||
|
||||
return input_ids
|
||||
```
|
||||
|
||||
Follow the recommended practices below to ensure your custom decoding method works as expected.
|
||||
- Feel free to reuse the logic for validation and input preparation in the original [`~GenerationMixin.generate`].
|
||||
- Pin the `transformers` version in the requirements if you use any private method/attribute in `model`.
|
||||
- You can add other files in the `custom_generate` folder, and use relative imports.
|
||||
- Consider adding model validation, input validation, or even a separate test file to help users sanity-check your code in their environment.
|
||||
|
||||
#### requirements.txt
|
||||
|
||||
You can optionally specify additional Python requirements in a `requirements.txt` file inside the `custom_generate` folder. These are checked at runtime and an exception will be thrown if they're missing, nudging users to update their environment accordingly.
|
||||
|
||||
#### README.md
|
||||
|
||||
The root level `README.md` in the model repository usually describes the model therein. However, since the focus of the repository is the custom decoding method, we highly recommend to shift its focus towards describing the custom decoding method. In addition to a description of the method, we recommend documenting any input and/or output differences to the original [`~GenerationMixin.generate`]. This way, users can focus on what's new, and rely on Transformers docs for generic implementation details.
|
||||
|
||||
For discoverability, we highly recommend you to add the `custom_generate` tag to your repository. To do so, the top of your `README.md` file should look like the example below. After you push the file, you should see the tag in your repository!
|
||||
|
||||
```
|
||||
---
|
||||
library_name: transformers
|
||||
tags:
|
||||
- custom_generate
|
||||
---
|
||||
|
||||
(your markdown content here)
|
||||
```
|
||||
|
||||
Recommended practices:
|
||||
- Document input and output differences in [`~GenerationMixin.generate`].
|
||||
- Add self-contained examples to enable quick experimentation.
|
||||
- Describe soft-requirements such as if the method only works well with a certain family of models.
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
Read the [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate) blog post for an explanation of how common decoding strategies work.
|
||||
|
@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
|
||||
|
||||
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
attn_weights = self.add_decomposed_rel_pos(
|
||||
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
||||
)
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
|
||||
@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`].
|
||||
|
||||
```py
|
||||
from transformers import SamModel
|
||||
from transformers.models.sam import modeling_sam
|
||||
|
||||
# replace the attention class in the modeling_sam module
|
||||
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
|
||||
|
||||
# load the pretrained SAM model
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
|
||||
# replace the attention class in the vision_encoder module
|
||||
for layer in model.vision_encoder.layers:
|
||||
if hasattr(layer, "attn"):
|
||||
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
|
||||
```
|
||||
|
||||
## LoRA
|
||||
@ -138,7 +134,7 @@ config = LoraConfig(
|
||||
# apply LoRA to q and v
|
||||
target_modules=["q", "v"],
|
||||
lora_dropout=0.1,
|
||||
task_type="mask-generation"
|
||||
task_type="FEATURE_EXTRACTION"
|
||||
)
|
||||
```
|
||||
|
||||
@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer
|
||||
|
||||
```py
|
||||
model.print_trainable_parameters()
|
||||
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
|
||||
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
|
||||
```
|
@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Image processors
|
||||
|
||||
Image processors converts images into pixel values, tensors that represent image colors and size. The pixel values are inputs to a vision or video model. To ensure a pretrained model receives the correct input, an image processor can perform the following operations to make sure an image is exactly like the images a model was pretrained on.
|
||||
Image processors converts images into pixel values, tensors that represent image colors and size. The pixel values are inputs to a vision model. To ensure a pretrained model receives the correct input, an image processor can perform the following operations to make sure an image is exactly like the images a model was pretrained on.
|
||||
|
||||
- [`~BaseImageProcessor.center_crop`] to resize an image
|
||||
- [`~BaseImageProcessor.normalize`] or [`~BaseImageProcessor.rescale`] pixel values
|
||||
|
@ -84,6 +84,19 @@ class Trainer:
|
||||
|
||||
Backends that can be added here are all the backends that are available in the `import_utils.py` module.
|
||||
|
||||
Additionally, specific versions can be specified in each backend. For example, this is how you would specify
|
||||
a requirement on torch>=2.6 on the `Trainer` class:
|
||||
|
||||
```python
|
||||
from .utils.import_utils import requires
|
||||
|
||||
@requires(backends=("torch>=2.6", "accelerate"))
|
||||
class Trainer:
|
||||
...
|
||||
```
|
||||
|
||||
You can specify the following operators: `==`, `>`, `>=`, `<`, `<=`, `!=`.
|
||||
|
||||
## Methods
|
||||
|
||||
[[autodoc]] utils.import_utils.define_import_structure
|
||||
|
@ -20,9 +20,13 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
Text generation is the most popular application for large language models (LLMs). A LLM is trained to generate the next word (token) given some initial text (prompt) along with its own generated outputs up to a predefined length or when it reaches an end-of-sequence (`EOS`) token.
|
||||
|
||||
In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities.
|
||||
In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities. This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid.
|
||||
|
||||
This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid.
|
||||
> [!TIP]
|
||||
> You can also chat with a model directly from the command line. ([reference](./conversations.md#transformers-cli))
|
||||
> ```shell
|
||||
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
> ```
|
||||
|
||||
## Default generate
|
||||
|
||||
@ -134,6 +138,20 @@ outputs = model.generate(**inputs, generation_config=generation_config)
|
||||
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Common Options
|
||||
|
||||
[`~GenerationMixin.generate`] is a powerful tool that can be heavily customized. This can be daunting for a new users. This section contains a list of popular generation options that you can define in most text generation tools in Transformers: [`~GenerationMixin.generate`], [`GenerationConfig`], `pipelines`, the `chat` CLI, ...
|
||||
|
||||
| Option name | Type | Simplified description |
|
||||
|---|---|---|
|
||||
| `max_new_tokens` | `int` | Controls the maximum generation length. Be sure to define it, as it usually defaults to a small value. |
|
||||
| `do_sample` | `bool` | Defines whether generation will sample the next token (`True`), or is greedy instead (`False`). Most use cases should set this flag to `True`. Check [this guide](./generation_strategies.md) for more information. |
|
||||
| `temperature` | `float` | How unpredictable the next selected token will be. High values (`>0.8`) are good for creative tasks, low values (e.g. `<0.4`) for tasks that require "thinking". Requires `do_sample=True`. |
|
||||
| `num_beams` | `int` | When set to `>1`, activates the beam search algorithm. Beam search is good on input-grounded tasks. Check [this guide](./generation_strategies.md) for more information. |
|
||||
| `repetition_penalty` | `float` | Set it to `>1.0` if you're seeing the model repeat itself often. Larger values apply a larger penalty. |
|
||||
| `eos_token_id` | `List[int]` | The token(s) that will cause generation to stop. The default value is usually good, but you can specify a different token. |
|
||||
|
||||
|
||||
## Pitfalls
|
||||
|
||||
The section below covers some common issues you may encounter during text generation and how to solve them.
|
||||
@ -286,4 +304,4 @@ Take a look below for some more specific and specialized text generation librari
|
||||
- [SynCode](https://github.com/uiuc-focal-lab/syncode): a library for context-free grammar guided generation (JSON, SQL, Python).
|
||||
- [Text Generation Inference](https://github.com/huggingface/text-generation-inference): a production-ready server for LLMs.
|
||||
- [Text generation web UI](https://github.com/oobabooga/text-generation-webui): a Gradio web UI for text generation.
|
||||
- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation.
|
||||
- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation.
|
||||
|
55
docs/source/en/main_classes/video_processor.md
Normal file
55
docs/source/en/main_classes/video_processor.md
Normal file
@ -0,0 +1,55 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
|
||||
# Video Processor
|
||||
|
||||
A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch.
|
||||
|
||||
The video processor extends the functionality of image processors by allowing Vision Large Language Models (VLMs) to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM.
|
||||
|
||||
When adding a new VLM or updating an existing one to enable distinct video preprocessing, saving and reloading the processor configuration will store the video related arguments in a dedicated file named `video_preprocessing_config.json`. Don't worry if you haven't upadted your VLM, the processor will try to load video related configurations from a file named `preprocessing_config.json`.
|
||||
|
||||
|
||||
### Usage Example
|
||||
Here's an example of how to load a video processor with [`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) model:
|
||||
|
||||
```python
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
```
|
||||
|
||||
Currently, if using base image processor for videos, it processes video data by treating each frame as an individual image and applying transformations frame-by-frame. While functional, this approach is not highly efficient. Using `AutoVideoProcessor` allows us to take advantage of **fast video processors**, leveraging the [torchvision](https://pytorch.org/vision/stable/index.html) library. Fast processors handle the whole batch of videos at once, without iterating over each video or frame. These updates introduce GPU acceleration and significantly enhance processing speed, especially for tasks requiring high throughput.
|
||||
|
||||
Fast video processors are available for all models and are loaded by default when an `AutoVideoProcessor` is initialized. When using a fast video processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise. For even more speed improvement, we can compile the processor when using 'cuda' as device.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers.video_utils import load_video
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
video = load_video("video.mp4")
|
||||
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda")
|
||||
processor = torch.compile(processor)
|
||||
processed_video = processor(video, return_tensors="pt")
|
||||
```
|
||||
|
||||
|
||||
## BaseVideoProcessor
|
||||
|
||||
[[autodoc]] video_processing_utils.BaseVideoProcessor
|
||||
|
@ -57,6 +57,7 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). This
|
||||
- Embedding size E is different from hidden size H justified because the embeddings are context independent (one embedding vector represents one token), whereas hidden states are context dependent (one hidden state represents a sequence of tokens) so it's more logical to have H >> E. Also, the embedding matrix is large since it's V x E (V being the vocab size). If E < H, it has less parameters.
|
||||
- Layers are split in groups that share parameters (to save memory).
|
||||
Next sentence prediction is replaced by a sentence ordering prediction: in the inputs, we have two sentences A and B (that are consecutive) and we either feed A followed by B or B followed by A. The model must predict if they have been swapped or not.
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
|
@ -74,6 +74,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
|
||||
|
||||
[[autodoc]] AutoImageProcessor
|
||||
|
||||
## AutoVideoProcessor
|
||||
|
||||
[[autodoc]] AutoVideoProcessor
|
||||
|
||||
## AutoProcessor
|
||||
|
||||
[[autodoc]] AutoProcessor
|
||||
|
@ -55,6 +55,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
|
||||
* mask a span of k tokens with a single mask token (a span of 0 tokens is an insertion of a mask token)
|
||||
* permute sentences
|
||||
* rotate the document to make it start at a specific token
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
|
@ -150,6 +150,7 @@ If you're interested in submitting a resource to be included here, please feel f
|
||||
[[autodoc]] BeitImageProcessor
|
||||
- preprocess
|
||||
- post_process_semantic_segmentation
|
||||
|
||||
## BeitImageProcessorFast
|
||||
|
||||
[[autodoc]] BeitImageProcessorFast
|
||||
|
@ -36,6 +36,7 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
|
||||
- BioGPT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left.
|
||||
- BioGPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next token in a sequence. Leveraging this feature allows BioGPT to generate syntactically coherent text as it can be observed in the run_generation.py example script.
|
||||
- The model can take the `past_key_values` (for PyTorch) as input, which is the previously computed key/value attention pairs. Using this (past_key_values or past) value prevents the model from re-computing pre-computed values in the context of text generation. For PyTorch, see past_key_values argument of the BioGptForCausalLM.forward() method for more information on its usage.
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
|
@ -53,6 +53,7 @@ The original code for vision can be found [here](https://github.com/facebookrese
|
||||
- For Data2VecAudio, preprocessing is identical to [`Wav2Vec2Model`], including feature extraction
|
||||
- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization.
|
||||
- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
|
@ -46,8 +46,12 @@ The main differences compared to GPT2.
|
||||
- Merge the key and value caches into one (this changes the format of layer_past/ present, does it risk creating problems?)
|
||||
- Use the memory layout (self.num_heads, 3, self.head_dim) instead of `(3, self.num_heads, self.head_dim)` for the QKV tensor with MHA. (prevents an overhead with the merged key and values, but makes the checkpoints incompatible with the original openai-community/gpt2 model).
|
||||
|
||||
|
||||
You can read more about the optimizations in the [original pull request](https://github.com/huggingface/transformers/pull/22575)
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Combining Starcoder and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||
|
@ -50,7 +50,7 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
|
||||
- Hubert is a speech model that accepts a float array corresponding to the raw waveform of the speech signal.
|
||||
- Hubert model was fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded
|
||||
using [`Wav2Vec2CTCTokenizer`].
|
||||
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Using Flash Attention 2
|
||||
|
||||
|
@ -58,6 +58,12 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
|
||||
|
||||
[[autodoc]] InstructBlipVideoProcessor
|
||||
|
||||
|
||||
## InstructBlipVideoVideoProcessor
|
||||
|
||||
[[autodoc]] InstructBlipVideoVideoProcessor
|
||||
- preprocess
|
||||
|
||||
## InstructBlipVideoImageProcessor
|
||||
|
||||
[[autodoc]] InstructBlipVideoImageProcessor
|
||||
|
@ -353,3 +353,7 @@ This example showcases how to handle a batch of chat conversations with interlea
|
||||
## InternVLProcessor
|
||||
|
||||
[[autodoc]] InternVLProcessor
|
||||
|
||||
## InternVLVideoProcessor
|
||||
|
||||
[[autodoc]] InternVLVideoProcessor
|
||||
|
@ -262,6 +262,10 @@ model = LlavaNextVideoForConditionalGeneration.from_pretrained(
|
||||
|
||||
[[autodoc]] LlavaNextVideoImageProcessor
|
||||
|
||||
## LlavaNextVideoVideoProcessor
|
||||
|
||||
[[autodoc]] LlavaNextVideoVideoProcessor
|
||||
|
||||
## LlavaNextVideoModel
|
||||
|
||||
[[autodoc]] LlavaNextVideoModel
|
||||
|
@ -303,6 +303,7 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
## LlavaOnevisionImageProcessor
|
||||
|
||||
[[autodoc]] LlavaOnevisionImageProcessor
|
||||
- preprocess
|
||||
|
||||
## LlavaOnevisionImageProcessorFast
|
||||
|
||||
@ -313,6 +314,10 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
|
||||
[[autodoc]] LlavaOnevisionVideoProcessor
|
||||
|
||||
## LlavaOnevisionVideoProcessor
|
||||
|
||||
[[autodoc]] LlavaOnevisionVideoProcessor
|
||||
|
||||
## LlavaOnevisionModel
|
||||
|
||||
[[autodoc]] LlavaOnevisionModel
|
||||
|
@ -51,6 +51,9 @@ multilingual it expects the sequences in a certain format: A special language id
|
||||
source and target text. The source text format is `[lang_code] X [eos]`, where `lang_code` is source language
|
||||
id for source text and target language id for target text, with `X` being the source or target text.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
The [`M2M100Tokenizer`] depends on `sentencepiece` so be sure to install it before running the
|
||||
examples. To install `sentencepiece` run `pip install sentencepiece`.
|
||||
|
||||
|
@ -35,6 +35,9 @@ You can find all the original mBART checkpoints under the [AI at Meta](https://h
|
||||
> [!TIP]
|
||||
> Click on the mBART models in the right sidebar for more examples of applying mBART to different language tasks.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
The example below demonstrates how to translate text with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
<hfoptions id="usage">
|
||||
|
@ -62,6 +62,9 @@ python src/transformers/models/musicgen/convert_musicgen_transformers.py \
|
||||
--checkpoint small --pytorch_dump_folder /output/path --safe_serialization
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Generation
|
||||
|
||||
MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly
|
||||
|
@ -44,6 +44,9 @@ There are two key differences with MusicGen:
|
||||
1. The audio prompt is used here as a conditional signal for the generated audio sample, whereas it's used for audio continuation in [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen).
|
||||
2. Conditional text and audio signals are concatenated to the decoder's hidden states instead of being used as a cross-attention signal, as in MusicGen.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Generation
|
||||
|
||||
MusicGen Melody is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default, and can be explicitly specified by setting `do_sample=True` in the call to [`MusicgenMelodyForConditionalGeneration.generate`], or by overriding the model's generation config (see below).
|
||||
|
@ -41,6 +41,9 @@ Tips:
|
||||
- OPT has the same architecture as [`BartDecoder`].
|
||||
- Contrary to GPT2, OPT adds the EOS token `</s>` to the beginning of every prompt.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with OPT. If you're
|
||||
|
@ -40,6 +40,9 @@ The abstract from the paper is the following:
|
||||
|
||||
`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
### Inference
|
||||
|
||||
```python
|
||||
|
@ -287,6 +287,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
[[autodoc]] Qwen2VLImageProcessor
|
||||
- preprocess
|
||||
|
||||
## Qwen2VLVideoProcessor
|
||||
|
||||
[[autodoc]] Qwen2VLVideoProcessor
|
||||
- preprocess
|
||||
|
||||
## Qwen2VLImageProcessorFast
|
||||
|
||||
[[autodoc]] Qwen2VLImageProcessorFast
|
||||
|
@ -23,6 +23,7 @@ rendered properly in your Markdown viewer.
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
The RoBERTa model was proposed in [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, [Myle Ott](https://huggingface.co/myleott), Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer
|
||||
|
@ -43,8 +43,8 @@ import requests
|
||||
from transformers import SamHQModel, SamHQProcessor
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
|
||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
||||
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
|
||||
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||
|
||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
@ -69,8 +69,8 @@ import requests
|
||||
from transformers import SamHQModel, SamHQProcessor
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
|
||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
||||
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
|
||||
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||
|
||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
|
@ -46,6 +46,9 @@ This model was contributed by [anton-l](https://huggingface.co/anton-l).
|
||||
- SEWForCTC is fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded using
|
||||
[`Wav2Vec2CTCTokenizer`].
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Resources
|
||||
|
||||
- [Audio classification task guide](../tasks/audio_classification)
|
||||
|
@ -197,6 +197,9 @@ print(generated_texts[0])
|
||||
[[autodoc]] SmolVLMImageProcessor
|
||||
- preprocess
|
||||
|
||||
## SmolVLMVideoProcessor
|
||||
[[autodoc]] SmolVLMVideoProcessor
|
||||
- preprocess
|
||||
|
||||
## SmolVLMProcessor
|
||||
[[autodoc]] SmolVLMProcessor
|
||||
|
@ -54,6 +54,9 @@ found [here](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT).
|
||||
decoded using [`Wav2Vec2CTCTokenizer`].
|
||||
- UniSpeechSat performs especially well on speaker verification, speaker identification, and speaker diarization tasks.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Resources
|
||||
|
||||
- [Audio classification task guide](../tasks/audio_classification)
|
||||
|
@ -49,6 +49,9 @@ found [here](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech).
|
||||
- UniSpeech model can be fine-tuned using connectionist temporal classification (CTC) so the model output has to be
|
||||
decoded using [`Wav2Vec2CTCTokenizer`].
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Resources
|
||||
|
||||
- [Audio classification task guide](../tasks/audio_classification)
|
||||
|
@ -211,6 +211,11 @@ model = VideoLlavaForConditionalGeneration.from_pretrained(
|
||||
|
||||
[[autodoc]] VideoLlavaImageProcessor
|
||||
|
||||
|
||||
## VideoLlavaVideoProcessor
|
||||
|
||||
[[autodoc]] VideoLlavaVideoProcessor
|
||||
|
||||
## VideoLlavaProcessor
|
||||
|
||||
[[autodoc]] VideoLlavaProcessor
|
||||
|
@ -72,6 +72,11 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
|
||||
[[autodoc]] ViltImageProcessor
|
||||
- preprocess
|
||||
|
||||
## ViltImageProcessorFast
|
||||
|
||||
[[autodoc]] ViltImageProcessorFast
|
||||
- preprocess
|
||||
|
||||
## ViltProcessor
|
||||
|
||||
[[autodoc]] ViltProcessor
|
||||
|
@ -50,6 +50,9 @@ Note: Meta (FAIR) released a new version of [Wav2Vec2-BERT 2.0](https://huggingf
|
||||
- Wav2Vec2 model was trained using connectionist temporal classification (CTC) so the model output has to be decoded
|
||||
using [`Wav2Vec2CTCTokenizer`].
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Using Flash Attention 2
|
||||
|
||||
Flash Attention 2 is an faster, optimized version of the model.
|
||||
|
@ -32,6 +32,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
You can find all the original Whisper checkpoints under the [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) collection.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Whisper models in the right sidebar for more examples of how to apply Whisper to different audio tasks.
|
||||
|
||||
|
@ -54,8 +54,8 @@ For each model type, there is a separate class for each machine learning framewo
|
||||
from transformers import AutoModelForCausalLM, MistralForCausalLM
|
||||
|
||||
# load with AutoClass or model-specific class
|
||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
|
||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
|
||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
|
||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
@ -272,6 +272,7 @@ Explicitly set the [torch_dtype](https://pytorch.org/docs/stable/tensor_attribut
|
||||
<hfoption id="specific dtype">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)
|
||||
|
@ -13,9 +13,15 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Distributed GPU inference
|
||||
# Tensor parallelism in transformers
|
||||
|
||||
[Tensor parallelism](./perf_train_gpu_many#tensor-parallelism) shards a model onto multiple GPUs and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each GPU can process a tensor slice.
|
||||
This document assumes that you are already familiar with the basics of tensor parallelism. If you are not, please refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) section on tensor parallelism.
|
||||
|
||||
> [!TIP]
|
||||
> Tensor parallelism is very communication intensive, therefore it is reccomended to use it on a single machine with multiple GPUs, utilizing fast intra-node communication. For multi-node training, methods as pipeline or data parallelism are more efficient (depending on your use case).
|
||||
|
||||
Tensor parallelism requires slight changes to the model parameters, therefore in transformers, we support some of the popular models out of the box.
|
||||
|
||||
> [!TIP]
|
||||
> Expand the list below to see which models support tensor parallelism. Open a GitHub issue or pull request to add support for a model not currently below.
|
||||
@ -37,9 +43,218 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
</details>
|
||||
|
||||
Set `tp_plan="auto"` in [`~AutoModel.from_pretrained`] to enable tensor parallelism for inference.
|
||||
## Using 🤗 transformers
|
||||
|
||||
```py
|
||||
Transformers provides a simple interface to use for tensor parallelism. We provide multiple classes implementing different partitioning
|
||||
strategies and a simple entrypoint to parallelize `nn.Module` instance. You won't have to interact with this interface directly, everything is done in `PretrainedModel.from_pretrained` method for you. This section will first talk about the partitioning strategies
|
||||
we support, then the user interface you will be interacting with, and finally it will teach you how to extend it with your own partitioning
|
||||
strategies.
|
||||
|
||||
### Partitioning strategies
|
||||
|
||||
In transformers, partitioning strategies reside in a class `ParallelInterface` which works like a mapping from string to the strategy implementation.
|
||||
|
||||
|
||||
```python
|
||||
class ParallelInterface(MutableMapping):
|
||||
"""
|
||||
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
|
||||
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
|
||||
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
|
||||
"""
|
||||
_global_mapping = {
|
||||
"colwise": ColwiseParallel(),
|
||||
"rowwise": RowwiseParallel(),
|
||||
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
|
||||
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
|
||||
"local_colwise": ColwiseParallel(use_dtensor=False),
|
||||
"local_rowwise": RowwiseParallel(use_dtensor=False),
|
||||
"local": IsolatedParallel(),
|
||||
"gather": GatherParallel(),
|
||||
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
|
||||
"sequence_parallel": SequenceParallel(),
|
||||
"replicate": ReplicateParallel(),
|
||||
}
|
||||
```
|
||||
|
||||
We support the following strategies:
|
||||
|
||||
- `ColwiseParallel` - A simple column-wise partitioning, being able to handle both weights and biases, does exactly what we've discussed before.
|
||||
- `RowwiseParallel` - Again, row-wise partitioning as dicussed before, supports weights and biases, on top of that it also supports `nn.Embedding` modules.
|
||||
- `SequenceParallel` - Sequence parallel implementation, for support of `LayerNorm` and `Dropout` layers. Also supports Python implementation of `RMSNorm` (see [this](https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34))
|
||||
- `PackedColwiseParallel` - A variant of column-wise partitioning, however it works on packed weights (i.e. `up_proj` and `gate_proj` being packed together). For more details, see [this comment](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py#L79-#L108)
|
||||
- `PackedRowwiseParallel` - A variant of row-wise partitioning, works on packed weights, for more details check the comment linked above.
|
||||
- `GatherParallel` - A very simple class, that only makes the outputs of the module to be gathered across devices.
|
||||
- `IsolatedParallel` - This is a special case, where we want to *isolate* the module from the rest of the devices (world). This is used for Experts in MoE layers, basically creating Expert parallelism of sorts.
|
||||
- `ReplicateParallel` - Many `torch.distributed` APIs break if model is partially sharded, so this class is used to replicate the module across all devices.
|
||||
|
||||
### Sharding a model
|
||||
|
||||
We provide two ways to shard a model, first one is to use `auto` tensor parallelism plan, which will automatically shard the model based on our predefined configuration. This requires the model to have predefined tensor parallel plan in transformers.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # better for smaller number of GPUs
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" # better to visualize all the possible strategies
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")
|
||||
|
||||
print(model._tp_plan)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> For a list of models that support tensor parallelism, see the [Supported models](#supported-models) section above.
|
||||
|
||||
The second way is to manually specify your own partitioning plan.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
tp_plan = {
|
||||
"model.layers.*.self_attn.q_proj": "colwise",
|
||||
"model.layers.*.self_attn.k_proj": "colwise",
|
||||
"model.layers.*.self_attn.v_proj": "colwise",
|
||||
"model.layers.*.self_attn.o_proj": "rowwise",
|
||||
...
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)
|
||||
|
||||
print(model._tp_plan)
|
||||
```
|
||||
|
||||
You might have noticed that there are some special cases in the `ParallelInterface` mapping, let's now talk about them. This will help you understand their purpose and help with extending to other strategies.
|
||||
|
||||
### PackedRowwiseParallel
|
||||
This class is a special case of `RowwiseParallel`, it's used to shard packed weights. Weight packing is a common technique used in models. It's a technique where we pack multiple linear layers into a single, bigger one.
|
||||
|
||||
For example in `Llama4` model, we pack `up_proj` and `gate_proj` into a single `gate_up_proj` module.
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
Then in forward, we can use batch matrix multiplication to compute the output of the `gate_up_proj` module.
|
||||
|
||||
```python
|
||||
def forward(self, hidden_states):
|
||||
...
|
||||
gate_up = torch.bmm(hidden_states, self.gate_up_proj) # Compute the output of the gate_up_proj module
|
||||
gate, up = gate_up.chunk(2, dim=-1) # Split the output into gate and up
|
||||
```
|
||||
|
||||
In this case, we need to use the `PackedRowwiseParallel` strategy to shard the `gate_up_proj` module, as using a simple `RowwiseParallel` will shard the layers wrongly.
|
||||
|
||||
> [!TIP]
|
||||
> If this is a bit difficult to wrap your head around, check out [this comment](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py#L79-#L108) for an amazing visual representation of why `Packed*` needs to be used.
|
||||
|
||||
|
||||
### `local*` strategies
|
||||
|
||||
You could have noticed that there are `local*` strategies, which use the same layers as `*` strategy, but don't use `DTensor` at all.
|
||||
This is because `DTensor` is not supported for some of the operations: such as `torch.chunk`. Therefore, sometimes we need to use the `local*` strategies, which use vanilla `torch.Tensor` and do some of the distributed logic manually.
|
||||
|
||||
<!---
|
||||
Readd this when I get the exact error message
|
||||
> [!TIP]
|
||||
> If you are using a custom partitioning strategy, and it's not working with `... is not supported` error, try using the `local*` strategies to see if they work better.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> Manually specifying your own partitiong plan requires a good understanding of the model architecture and how the partitioning strategies interact together. If you are not sure about this, the resulting model can be very slow, even failing or incorrect. Again, refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) which can teach you everything required.
|
||||
|
||||
### Extending the interface with your own partitioning strategies
|
||||
|
||||
This is a very advanced topic, which requires a good understanding of distributed collectives and the model architecture.
|
||||
Your custom partitioning strategy should inherit from `TensorParallelLayer` defined in [integrations/tensor_parallel.py](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py) and implement: `partition_tensor`, `_prepare_input_fn` and `_prepare_output_fn`. Then it should be registered in the `ParallelInterface` mapping, so our dispatching logic can find it when specified in the `tp_plan`.
|
||||
|
||||
Let's go through this workflow step by step, on an already existing example: `ColwiseParallel`.
|
||||
|
||||
1. Inherit from `TensorParallelLayer` and initialization
|
||||
|
||||
```python
|
||||
class ColwiseParallel(TensorParallelLayer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Optional[Placement] = None, # The input layout coming from the previous layer
|
||||
output_layouts: Optional[Placement] = None, # The output layout we want to achieve
|
||||
use_local_output: bool = True, # Whether to use local output or not
|
||||
use_dtensor=True, # Whether to use DTensor or not
|
||||
):
|
||||
self.input_layouts = (input_layouts or Replicate(),) # The input sharding coming from the previous layer
|
||||
self.output_layouts = (output_layouts or Shard(-1),) # Desired output sharding
|
||||
self.desired_input_layouts = (Replicate(),) # Desired input sharding, inputs should be replicated across GPUs
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
```
|
||||
|
||||
In the `__init__` method, we define these attributes, where `input_layouts` and `output_layouts` describing, how the input and output tensors should be placed on the devices. `desired_input_layouts` is used to specify, how the input *SHOULD* be placed on the devices.
|
||||
|
||||
2a. Implement `partition_tensor` method
|
||||
|
||||
```python
|
||||
def partition_tensor(
|
||||
self,
|
||||
param, # Full tensor of the parameter
|
||||
empty_param, # Empty tensor of the parameter, will be filled with the partitioned tensor
|
||||
param_type, # Type of the parameter, `bias` or `weight`
|
||||
param_casting_dtype, # The type to cast the parameter to
|
||||
to_contiguous, # Whether to convert the tensor to a contiguous memory layout
|
||||
rank, # The rank of the current device
|
||||
device_mesh, # The device mesh
|
||||
) -> nn.Parameter: # Return the partitioned parameter
|
||||
...
|
||||
```
|
||||
|
||||
This method is used to partition the tensor, and fill the `empty_param` with the partitioned tensor.
|
||||
We provide some utility functions to help you with this, such as `get_tensor_shard` which will get you the correct shard of the original parameter for this rank or `get_packed_weights` to help with packed weights.
|
||||
|
||||
2b. Implement `_prepare_input_fn` and `_prepare_output_fn` methods
|
||||
|
||||
These methods are used as [`pre-forward`](https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_pre_hook.html) and [`forward`](https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html) hooks respectively. Their purpose is to re-distribute the inputs and outputs to the desired layout, passed in the `__init__` method.
|
||||
|
||||
```python
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
...
|
||||
# Do some custom logic, cast to DTensor etc.
|
||||
...
|
||||
return inputs.redistribute(placements=desired_input_layouts, device_mesh=device_mesh)
|
||||
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
...
|
||||
# Do some custom logic, cast to DTensor etc.
|
||||
...
|
||||
return outputs.redistribute(placements=output_layouts, device_mesh=device_mesh)
|
||||
```
|
||||
|
||||
3. Register the strategy
|
||||
Congratulations! You've implemented your own partitioning strategy. Now, to use it with your own `tp_plan`, you need to register it in the `ParallelInterface` mapping.
|
||||
|
||||
```python
|
||||
from transformers.integrations.tensor_parallel import ParallelInterface
|
||||
|
||||
ParallelInterface.register_strategy("colwise_custom", ColwiseParallel)
|
||||
```
|
||||
|
||||
And now you can use it in your `tp_plan` as such:
|
||||
|
||||
```python
|
||||
tp_plan = {
|
||||
"model.layers.*.self_attn.q_proj": "colwise_custom",
|
||||
...
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)
|
||||
```
|
||||
|
||||
|
||||
## Full example
|
||||
|
||||
Let's go through a full example of inference with tensor parallelism.
|
||||
```python
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
@ -66,17 +281,49 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
|
||||
torchrun --nproc-per-node 4 demo.py
|
||||
```
|
||||
|
||||
For CPU, please binding different socket on each rank. For example, if you are using Intel 4th Gen Xeon:
|
||||
```bash
|
||||
export OMP_NUM_THREADS=56
|
||||
numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait
|
||||
```
|
||||
The CPU benchmark data will be released soon.
|
||||
|
||||
You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.
|
||||
|
||||
For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct%2C%20seqlen%20%3D%20512%2C%20python%2C%20w_%20compile.png">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct%2C%20seqlen%20%3D%20512%2C%20python%2C%20w_%20compile.png">
|
||||
</div>
|
||||
|
||||
## Tensor parallelism in-depth
|
||||
Our implementation of tensor parallelism is framework-agnostic in design, but the specific implementations we've developed rely on the torch.distributed package. We heavily utilize abstractions such as `DeviceMesh` or `DTensor` to provide a simple and extensible interface to the user.
|
||||
|
||||
### DeviceMesh
|
||||
Imagine `DeviceMesh` as a multi-dimensional grid of devices that communicate together. Different parallelization strategies require different types of communication patterns, therefore we can create a `DeviceMesh` with multiple submeshes:
|
||||
```python
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
# Create a 1D mesh of 4 GPUs
|
||||
device_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=["tp"])
|
||||
```
|
||||
Then, most of the `torch.distributed` defined parallelization strategies can be applied to a mesh itself, or its submesh, automatically handling the communication patterns.
|
||||
|
||||
### DTensor
|
||||
|
||||
Abbreviation for Distributed Tensor, `DTensor` is a tensor subclass that handles the distributed logic on-top of the usual tensor operations. Most of the model weights in case of tensor parallelism are stored as `DTensor`s (with some exceptions, more on that later).
|
||||
The most important part of DTensor, that is crucial to understand, is the `placement` attribute. It's an attribute that tells PyTorch how is the tensor placed on the devices of the `DeviceMesh`.
|
||||
|
||||
It can have the following values:
|
||||
|
||||
- `Shard(dimension)` - Annotates that this `DTensor` is sharded across a given dimension, over the `DeviceMesh` it was constructed under. For example, if we would like to shard weights for column-wise partitioning, we would do:
|
||||
```python
|
||||
weight = ...
|
||||
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)]) # Shard across the 1st (column-wise) dimension
|
||||
bias = ...
|
||||
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Shard(-1)]) # Shard across the ONLY dimension
|
||||
```
|
||||
|
||||
To give another example, for row-wise partitioning, we would do:
|
||||
```python
|
||||
weight = ...
|
||||
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(1)]) # Shard across the 2nd (row-wise) dimension
|
||||
bias = ...
|
||||
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
|
||||
```
|
||||
|
||||
- `Replicate()` - Annotates that this `DTensor` is replicated across the `DeviceMesh`. Very straight-forward, only creates a full copy of the tensor on each device.
|
||||
- `Partial()` - This placement is mostly of no interest to us, it's used to annotate that this tensor is pending a reduction operation.
|
||||
|
@ -106,6 +106,8 @@ dataset[0]["text"]
|
||||
Remember to resample the sampling rate to match the pretrained models required sampling rate.
|
||||
|
||||
```py
|
||||
from datasets import Audio
|
||||
|
||||
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
|
||||
```
|
||||
|
||||
|
@ -372,14 +372,14 @@ accelerate launch \
|
||||
|
||||
### torch.compile
|
||||
|
||||
[torch.compile](./perf_torch_compile) can significantly speed up training and reduce computational overhead. Configure your torch.compile settings in [`TrainingArguments`]. Set `torch.compile` to `True`, and select a backend and compile mode.
|
||||
[torch.compile](./perf_torch_compile) can significantly speed up training and reduce computational overhead. Configure your torch.compile settings in [`TrainingArguments`]. Set `torch_compile` to `True`, and select a backend and compile mode.
|
||||
|
||||
```py
|
||||
from transformers import TrainingArguments
|
||||
|
||||
training_args = TrainingArguments(
|
||||
torch.compile=True,
|
||||
torch.compile_backend="inductor",
|
||||
torch_compile=True,
|
||||
torch_compile_backend="inductor",
|
||||
torch_compile_mode="default",
|
||||
...,
|
||||
)
|
||||
|
49
docs/source/en/video_processors.md
Normal file
49
docs/source/en/video_processors.md
Normal file
@ -0,0 +1,49 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
|
||||
# Video Processor
|
||||
|
||||
A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch.
|
||||
|
||||
The video processor extends the functionality of image processors by allowing the models to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM.
|
||||
|
||||
Use [`~BaseVideoProcessor.from_pretrained`] to load a video processors configuration (image size, whether to normalize and rescale, etc.) from a video model on the Hugging Face [Hub](https://hf.co) or local directory. The configuration for each pretrained model should be saved in a [video_preprocessor_config.json] file but older models might have the config saved in [preprocessor_config.json](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf/blob/main/preprocessor_config.json) file. Note that the latter is less preferred and will be removed in the future.
|
||||
|
||||
|
||||
### Usage Example
|
||||
Here's an example of how to load a video processor with [`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) model:
|
||||
|
||||
```python
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
```
|
||||
|
||||
Currently, if using base image processor for videos, it processes video data by treating each frame as an individual image and applying transformations frame-by-frame. While functional, this approach is not highly efficient. Using `AutoVideoProcessor` allows us to take advantage of **fast video processors**, leveraging the [torchvision](https://pytorch.org/vision/stable/index.html) library. Fast processors handle the whole batch of videos at once, without iterating over each video or frame. These updates introduce GPU acceleration and significantly enhance processing speed, especially for tasks requiring high throughput.
|
||||
|
||||
Fast video processors are available for all models and are loaded by default when an `AutoVideoProcessor` is initialized. When using a fast video processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise. For even more speed improvement, we can compile the processor when using 'cuda' as device.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers.video_utils import load_video
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
video = load_video("video.mp4")
|
||||
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda")
|
||||
processor = torch.compile(processor)
|
||||
processed_video = processor(video, return_tensors="pt")
|
||||
```
|
@ -105,6 +105,7 @@ BEiT の使用を開始するのに役立つ公式 Hugging Face およびコミ
|
||||
|
||||
[[autodoc]] BeitImageProcessor
|
||||
- preprocess
|
||||
- post_process_semantic_segmentation
|
||||
|
||||
## BeitImageProcessorFast
|
||||
|
||||
|
@ -157,5 +157,8 @@
|
||||
title: 通用工具
|
||||
- local: internal/time_series_utils
|
||||
title: 时序数据工具
|
||||
- sections:
|
||||
- local: model_doc/bert
|
||||
title: BERT
|
||||
title: 内部辅助工具
|
||||
title: 应用程序接口 (API)
|
||||
title: 应用程序接口 (API)
|
258
docs/source/zh/model_doc/bert.md
Normal file
258
docs/source/zh/model_doc/bert.md
Normal file
@ -0,0 +1,258 @@
|
||||
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# BERT
|
||||
|
||||
[BERT](https://huggingface.co/papers/1810.04805) 是一个在无标签的文本数据上预训练的双向 transformer,用于预测句子中被掩码的(masked) token,以及预测一个句子是否跟随在另一个句子之后。其主要思想是,在预训练过程中,通过随机掩码一些 token,让模型利用左右上下文的信息预测它们,从而获得更全面深入的理解。此外,BERT 具有很强的通用性,其学习到的语言表示可以通过额外的层或头进行微调,从而适配其他下游 NLP 任务。
|
||||
|
||||
你可以在 [BERT](https://huggingface.co/collections/google/bert-release-64ff5e7a4be99045d1896dbc) 集合下找到 BERT 的所有原始 checkpoint。
|
||||
|
||||
> [!TIP]
|
||||
> 点击右侧边栏中的 BERT 模型,以查看将 BERT 应用于不同语言任务的更多示例。
|
||||
|
||||
下面的示例演示了如何使用 [`Pipeline`], [`AutoModel`] 和命令行预测 `[MASK]` token。
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="fill-mask",
|
||||
model="google-bert/bert-base-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create [MASK] through a process known as photosynthesis.")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google-bert/bert-base-uncased",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"google-bert/bert-base-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits
|
||||
|
||||
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
|
||||
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
|
||||
predicted_token = tokenizer.decode(predicted_token_id)
|
||||
|
||||
print(f"The predicted token is: {predicted_token}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "Plants create [MASK] through a process known as photosynthesis." | transformers-cli run --task fill-mask --model google-bert/bert-base-uncased --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## 注意
|
||||
|
||||
- 输入内容应在右侧进行填充,因为 BERT 使用绝对位置嵌入。
|
||||
## BertConfig
|
||||
|
||||
[[autodoc]] BertConfig
|
||||
- all
|
||||
|
||||
## BertTokenizer
|
||||
|
||||
[[autodoc]] BertTokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
- get_special_tokens_mask
|
||||
- create_token_type_ids_from_sequences
|
||||
- save_vocabulary
|
||||
|
||||
## BertTokenizerFast
|
||||
|
||||
[[autodoc]] BertTokenizerFast
|
||||
|
||||
## BertModel
|
||||
|
||||
[[autodoc]] BertModel
|
||||
- forward
|
||||
|
||||
## BertForPreTraining
|
||||
|
||||
[[autodoc]] BertForPreTraining
|
||||
- forward
|
||||
|
||||
## BertLMHeadModel
|
||||
|
||||
[[autodoc]] BertLMHeadModel
|
||||
- forward
|
||||
|
||||
## BertForMaskedLM
|
||||
|
||||
[[autodoc]] BertForMaskedLM
|
||||
- forward
|
||||
|
||||
## BertForNextSentencePrediction
|
||||
|
||||
[[autodoc]] BertForNextSentencePrediction
|
||||
- forward
|
||||
|
||||
## BertForSequenceClassification
|
||||
|
||||
[[autodoc]] BertForSequenceClassification
|
||||
- forward
|
||||
|
||||
## BertForMultipleChoice
|
||||
|
||||
[[autodoc]] BertForMultipleChoice
|
||||
- forward
|
||||
|
||||
## BertForTokenClassification
|
||||
|
||||
[[autodoc]] BertForTokenClassification
|
||||
- forward
|
||||
|
||||
## BertForQuestionAnswering
|
||||
|
||||
[[autodoc]] BertForQuestionAnswering
|
||||
- forward
|
||||
|
||||
## TFBertTokenizer
|
||||
|
||||
[[autodoc]] TFBertTokenizer
|
||||
|
||||
## TFBertModel
|
||||
|
||||
[[autodoc]] TFBertModel
|
||||
- call
|
||||
|
||||
## TFBertForPreTraining
|
||||
|
||||
[[autodoc]] TFBertForPreTraining
|
||||
- call
|
||||
|
||||
## TFBertModelLMHeadModel
|
||||
|
||||
[[autodoc]] TFBertLMHeadModel
|
||||
- call
|
||||
|
||||
## TFBertForMaskedLM
|
||||
|
||||
[[autodoc]] TFBertForMaskedLM
|
||||
- call
|
||||
|
||||
## TFBertForNextSentencePrediction
|
||||
|
||||
[[autodoc]] TFBertForNextSentencePrediction
|
||||
- call
|
||||
|
||||
## TFBertForSequenceClassification
|
||||
|
||||
[[autodoc]] TFBertForSequenceClassification
|
||||
- call
|
||||
|
||||
## TFBertForMultipleChoice
|
||||
|
||||
[[autodoc]] TFBertForMultipleChoice
|
||||
- call
|
||||
|
||||
## TFBertForTokenClassification
|
||||
|
||||
[[autodoc]] TFBertForTokenClassification
|
||||
- call
|
||||
|
||||
## TFBertForQuestionAnswering
|
||||
|
||||
[[autodoc]] TFBertForQuestionAnswering
|
||||
- call
|
||||
|
||||
## FlaxBertModel
|
||||
|
||||
[[autodoc]] FlaxBertModel
|
||||
- __call__
|
||||
|
||||
## FlaxBertForPreTraining
|
||||
|
||||
[[autodoc]] FlaxBertForPreTraining
|
||||
- __call__
|
||||
|
||||
## FlaxBertForCausalLM
|
||||
|
||||
[[autodoc]] FlaxBertForCausalLM
|
||||
- __call__
|
||||
|
||||
## FlaxBertForMaskedLM
|
||||
|
||||
[[autodoc]] FlaxBertForMaskedLM
|
||||
- __call__
|
||||
|
||||
## FlaxBertForNextSentencePrediction
|
||||
|
||||
[[autodoc]] FlaxBertForNextSentencePrediction
|
||||
- __call__
|
||||
|
||||
## FlaxBertForSequenceClassification
|
||||
|
||||
[[autodoc]] FlaxBertForSequenceClassification
|
||||
- __call__
|
||||
|
||||
## FlaxBertForMultipleChoice
|
||||
|
||||
[[autodoc]] FlaxBertForMultipleChoice
|
||||
- __call__
|
||||
|
||||
## FlaxBertForTokenClassification
|
||||
|
||||
[[autodoc]] FlaxBertForTokenClassification
|
||||
- __call__
|
||||
|
||||
## FlaxBertForQuestionAnswering
|
||||
|
||||
[[autodoc]] FlaxBertForQuestionAnswering
|
||||
- __call__
|
||||
|
||||
## Bert specific outputs
|
||||
|
||||
[[autodoc]] models.bert.modeling_bert.BertForPreTrainingOutput
|
||||
|
||||
[[autodoc]] models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
|
||||
|
||||
[[autodoc]] models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput
|
@ -21,6 +21,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import datasets
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
@ -454,6 +455,7 @@ def main():
|
||||
split=train_split_name,
|
||||
cache_dir=args.cache_dir,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
storage_options={"client_kwargs": {"timeout": aiohttp.ClientTimeout(total=60 * 60)}},
|
||||
)
|
||||
datasets_splits.append(dataset_split)
|
||||
|
||||
|
@ -34,7 +34,6 @@ from transformers import (
|
||||
GPT2Tokenizer,
|
||||
GPTJForCausalLM,
|
||||
LlamaForCausalLM,
|
||||
LlamaTokenizer,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OpenAIGPTTokenizer,
|
||||
OPTForCausalLM,
|
||||
@ -63,7 +62,7 @@ MODEL_CLASSES = {
|
||||
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
|
||||
"gptj": (GPTJForCausalLM, AutoTokenizer),
|
||||
"bloom": (BloomForCausalLM, BloomTokenizerFast),
|
||||
"llama": (LlamaForCausalLM, LlamaTokenizer),
|
||||
"llama": (LlamaForCausalLM, AutoTokenizer),
|
||||
"opt": (OPTForCausalLM, GPT2Tokenizer),
|
||||
}
|
||||
|
||||
|
@ -276,6 +276,7 @@ _import_structure = {
|
||||
"TorchAoConfig",
|
||||
"VptqConfig",
|
||||
],
|
||||
"video_utils": [],
|
||||
}
|
||||
|
||||
# tokenizers-backed objects
|
||||
@ -334,6 +335,7 @@ except OptionalDependencyNotAvailable:
|
||||
]
|
||||
else:
|
||||
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
|
||||
_import_structure["video_processing_utils"] = ["BaseVideoProcessor"]
|
||||
|
||||
# PyTorch-backed objects
|
||||
try:
|
||||
@ -809,6 +811,7 @@ if TYPE_CHECKING:
|
||||
from .utils.dummy_torchvision_objects import *
|
||||
else:
|
||||
from .image_processing_utils_fast import BaseImageProcessorFast
|
||||
from .video_processing_utils import BaseVideoProcessor
|
||||
|
||||
try:
|
||||
if not (is_torchvision_available() and is_timm_available()):
|
||||
|
@ -21,6 +21,104 @@ if is_hqq_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Utility functions for static/sliding cache update logic
|
||||
def _static_cache_update(
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
cache_position: Optional[torch.LongTensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the static cache tensors in place.
|
||||
|
||||
Args:
|
||||
k_cache (`torch.Tensor`): The key cache tensor to update.
|
||||
v_cache (`torch.Tensor`): The value cache tensor to update.
|
||||
key_states (`torch.Tensor`): The new key states to add.
|
||||
value_states (`torch.Tensor`): The new value states to add.
|
||||
cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
|
||||
If None, the entire cache is overwritten (prefill).
|
||||
|
||||
Returns:
|
||||
Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
|
||||
"""
|
||||
if cache_position is None:
|
||||
# Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
|
||||
k_cache.copy_(key_states)
|
||||
v_cache.copy_(value_states)
|
||||
else:
|
||||
# Generation phase. Update specific positions.
|
||||
# Use index_copy_ for in-place update (compile-friendly).
|
||||
try:
|
||||
k_cache.index_copy_(2, cache_position, key_states)
|
||||
v_cache.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# Fallback for devices like MPS where index_copy_ might not be supported.
|
||||
k_cache[:, :, cache_position] = key_states
|
||||
v_cache[:, :, cache_position] = value_states
|
||||
return k_cache, v_cache
|
||||
|
||||
|
||||
def _sliding_cache_update(
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
cache_position: torch.LongTensor,
|
||||
max_cache_len: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the sliding window cache tensors, returning the potentially modified tensors.
|
||||
|
||||
Args:
|
||||
k_cache (`torch.Tensor`): The key cache tensor to update.
|
||||
v_cache (`torch.Tensor`): The value cache tensor to update.
|
||||
key_states (`torch.Tensor`): The new key states to add.
|
||||
value_states (`torch.Tensor`): The new value states to add.
|
||||
cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
|
||||
max_cache_len (`int`): The maximum length of the sliding window cache.
|
||||
|
||||
Returns:
|
||||
Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update.
|
||||
For prefill > window, these are the full input states.
|
||||
Otherwise, they are the updated cache tensors.
|
||||
"""
|
||||
# Handle prefill phase when prompt length > sliding_window_size
|
||||
if cache_position.shape[0] > max_cache_len:
|
||||
new_k = key_states[:, :, -max_cache_len:, :]
|
||||
new_v = value_states[:, :, -max_cache_len:, :]
|
||||
k_cache.copy_(new_k)
|
||||
v_cache.copy_(new_v)
|
||||
return key_states, value_states
|
||||
|
||||
# Sliding window logic for generation phase or prefill < window
|
||||
slicing = torch.arange(max_cache_len, device=value_states.device)
|
||||
current_seq_len = cache_position[-1] + 1 # Use last position to determine current length
|
||||
to_shift = current_seq_len > max_cache_len
|
||||
indices = (slicing + to_shift.sum()) % max_cache_len
|
||||
|
||||
k_out_shifted = k_cache[:, :, indices]
|
||||
v_out_shifted = v_cache[:, :, indices]
|
||||
|
||||
# Clamp cache_position to determine the *target index* within the shifted cache view
|
||||
update_position = cache_position.clamp(min=0, max=max_cache_len - 1)
|
||||
|
||||
try:
|
||||
k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
|
||||
v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
|
||||
except NotImplementedError:
|
||||
# Fallback for MPS: clone and modify the clone
|
||||
k_out_updated = k_out_shifted.clone()
|
||||
v_out_updated = v_out_shifted.clone()
|
||||
k_out_updated[:, :, update_position] = key_states
|
||||
v_out_updated[:, :, update_position] = value_states
|
||||
|
||||
k_cache.copy_(k_out_updated)
|
||||
v_cache.copy_(v_out_updated)
|
||||
return k_out_updated, v_out_updated
|
||||
|
||||
|
||||
class Cache:
|
||||
"""
|
||||
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
||||
@ -464,7 +562,7 @@ class DynamicCache(Cache):
|
||||
"""Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
|
||||
return None
|
||||
|
||||
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
||||
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
|
||||
backward compatibility."""
|
||||
legacy_cache = ()
|
||||
@ -473,7 +571,9 @@ class DynamicCache(Cache):
|
||||
return legacy_cache
|
||||
|
||||
@classmethod
|
||||
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
|
||||
def from_legacy_cache(
|
||||
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None
|
||||
) -> "DynamicCache":
|
||||
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
|
||||
backward compatibility."""
|
||||
cache = cls()
|
||||
@ -1262,28 +1362,16 @@ class StaticCache(Cache):
|
||||
"""
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
k_out = self.key_cache[layer_idx]
|
||||
v_out = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
if cache_position is None:
|
||||
k_out.copy_(key_states)
|
||||
v_out.copy_(value_states)
|
||||
else:
|
||||
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
|
||||
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
|
||||
# operation, that avoids copies and uses less memory.
|
||||
try:
|
||||
k_out.index_copy_(2, cache_position, key_states)
|
||||
v_out.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
return k_out, v_out
|
||||
key_states = key_states.to(self.key_cache[layer_idx].dtype)
|
||||
value_states = value_states.to(self.value_cache[layer_idx].dtype)
|
||||
return _static_cache_update(
|
||||
self.key_cache[layer_idx],
|
||||
self.value_cache[layer_idx],
|
||||
key_states,
|
||||
value_states,
|
||||
cache_kwargs.get("cache_position"),
|
||||
)
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""Returns the sequence length of the cached states that were seen by the model."""
|
||||
@ -1312,7 +1400,7 @@ class SlidingWindowCache(StaticCache):
|
||||
|
||||
The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
|
||||
|
||||
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
|
||||
indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window
|
||||
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
|
||||
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
|
||||
@ -1396,46 +1484,21 @@ class SlidingWindowCache(StaticCache):
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
k_out = self.key_cache[layer_idx]
|
||||
v_out = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
|
||||
if cache_position.shape[0] > self.max_cache_len:
|
||||
k_out = key_states[:, :, -self.max_cache_len :, :]
|
||||
v_out = value_states[:, :, -self.max_cache_len :, :]
|
||||
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
# we should return the whole states instead of k_out, v_out to take the whole prompt
|
||||
# into consideration when building kv cache instead of just throwing away tokens outside of the window
|
||||
return key_states, value_states
|
||||
if cache_position is None:
|
||||
raise ValueError("`cache_position` must be provided for SlidingWindowCache.")
|
||||
|
||||
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
||||
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
|
||||
to_shift = cache_position > self.max_cache_len - 1
|
||||
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
|
||||
key_states = key_states.to(self.key_cache[layer_idx].dtype)
|
||||
value_states = value_states.to(self.value_cache[layer_idx].dtype)
|
||||
|
||||
k_out = k_out[:, :, indices]
|
||||
v_out = v_out[:, :, indices]
|
||||
|
||||
try:
|
||||
k_out.index_copy_(2, cache_position, key_states)
|
||||
v_out.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
|
||||
return k_out, v_out
|
||||
return _sliding_cache_update(
|
||||
self.key_cache[layer_idx],
|
||||
self.value_cache[layer_idx],
|
||||
key_states,
|
||||
value_states,
|
||||
cache_position,
|
||||
self.max_cache_len,
|
||||
)
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
return self.max_cache_len
|
||||
@ -1505,8 +1568,8 @@ class EncoderDecoderCache(Cache):
|
||||
"""
|
||||
return len(self.self_attention_cache)
|
||||
|
||||
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
||||
"""Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
|
||||
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor]]:
|
||||
"""Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
|
||||
legacy_cache = ()
|
||||
if len(self.cross_attention_cache) > 0:
|
||||
for self_attn, cross_attn in zip(
|
||||
@ -1678,12 +1741,13 @@ class HybridCache(Cache):
|
||||
super().__init__()
|
||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||
raise ValueError(
|
||||
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
||||
"Setting `cache_implementation` to 'hybrid' requires the model config supporting "
|
||||
"sliding window attention, please check if there is a `sliding_window` field in the model "
|
||||
"config and it's not set to None."
|
||||
)
|
||||
self.max_cache_len = max_cache_len
|
||||
self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
|
||||
self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
|
||||
# Sliding layers can't be larger than the overall max cache len
|
||||
self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
|
||||
self.max_batch_size = max_batch_size
|
||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
self.head_dim = (
|
||||
@ -1692,22 +1756,17 @@ class HybridCache(Cache):
|
||||
|
||||
self._dtype = dtype
|
||||
self.num_key_value_heads = (
|
||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
|
||||
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
|
||||
self.is_sliding = torch.tensor(
|
||||
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
|
||||
)
|
||||
self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
|
||||
sliding_cache_shape = (
|
||||
self.max_batch_size,
|
||||
self.num_key_value_heads,
|
||||
self._sliding_window_max_len,
|
||||
self.head_dim,
|
||||
)
|
||||
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
|
||||
device = torch.device(device) if device is not None else None
|
||||
for i in range(config.num_hidden_layers):
|
||||
if layer_device_map is not None:
|
||||
@ -1716,7 +1775,7 @@ class HybridCache(Cache):
|
||||
layer_device = device
|
||||
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
||||
# breaks when updating the cache.
|
||||
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
||||
cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
|
||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
@ -1724,42 +1783,6 @@ class HybridCache(Cache):
|
||||
self.key_cache.append(new_layer_key_cache)
|
||||
self.value_cache.append(new_layer_value_cache)
|
||||
|
||||
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||
if cache_position.shape[0] > max_cache_len:
|
||||
k_out = key_states[:, :, -max_cache_len:, :]
|
||||
v_out = value_states[:, :, -max_cache_len:, :]
|
||||
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
# we should return the whole states instead of k_out, v_out to take the whole prompt
|
||||
# into consideration when building kv cache instead of just throwing away tokens outside of the window
|
||||
return key_states, value_states
|
||||
|
||||
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
||||
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
||||
to_shift = cache_position > max_cache_len - 1
|
||||
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
|
||||
k_out = k_out[:, :, indices]
|
||||
v_out = v_out[:, :, indices]
|
||||
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
return k_out, v_out
|
||||
|
||||
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
self.key_cache[layer_idx] = k_out
|
||||
self.value_cache[layer_idx] = v_out
|
||||
return k_out, v_out
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
@ -1770,7 +1793,10 @@ class HybridCache(Cache):
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
sliding_window = cache_kwargs.get("sliding_window")
|
||||
if cache_position is None:
|
||||
raise ValueError("`cache_position` must be provided for HybridCache.")
|
||||
|
||||
is_sliding_layer = self.is_sliding_list[layer_idx]
|
||||
|
||||
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
|
||||
# when the cache is initialized in the forward pass (e.g. Gemma2)
|
||||
@ -1779,25 +1805,22 @@ class HybridCache(Cache):
|
||||
if self.value_cache[layer_idx].device != value_states.device:
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
|
||||
|
||||
k_out = self.key_cache[layer_idx]
|
||||
v_out = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
k_cache = self.key_cache[layer_idx]
|
||||
v_cache = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_cache.dtype)
|
||||
value_states = value_states.to(v_cache.dtype)
|
||||
|
||||
if sliding_window:
|
||||
update_fn = self._sliding_update
|
||||
if is_sliding_layer:
|
||||
return _sliding_cache_update(
|
||||
k_cache,
|
||||
v_cache,
|
||||
key_states,
|
||||
value_states,
|
||||
cache_position,
|
||||
k_cache.shape[2], # Use actual cache dim as max cache len
|
||||
)
|
||||
else:
|
||||
update_fn = self._static_update
|
||||
|
||||
return update_fn(
|
||||
cache_position,
|
||||
layer_idx,
|
||||
key_states,
|
||||
value_states,
|
||||
k_out,
|
||||
v_out,
|
||||
k_out.shape[2],
|
||||
)
|
||||
return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position)
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
return self.max_cache_len
|
||||
@ -2031,7 +2054,7 @@ class OffloadedHybridCache(HybridChunkedCache):
|
||||
|
||||
# TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
|
||||
# track of the original device of each layer
|
||||
unique_devices = set(layer_device_map.values())
|
||||
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
|
||||
if len(unique_devices) > 1:
|
||||
raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
|
||||
|
||||
@ -2290,7 +2313,7 @@ class OffloadedStaticCache(StaticCache):
|
||||
|
||||
# TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
|
||||
# track of the original device of each layer
|
||||
unique_devices = set(layer_device_map.values())
|
||||
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
|
||||
if len(unique_devices) > 1:
|
||||
raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")
|
||||
|
||||
@ -2367,6 +2390,9 @@ class OffloadedStaticCache(StaticCache):
|
||||
A tuple containing the updated key and value states.
|
||||
"""
|
||||
|
||||
key_states = key_states.to(self.key_cache[layer_idx].dtype)
|
||||
value_states = value_states.to(self.value_cache[layer_idx].dtype)
|
||||
|
||||
if layer_idx == 0:
|
||||
# Update seen tokens.
|
||||
# TODO(gante): Remove this.
|
||||
|
@ -13,12 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import string
|
||||
import time
|
||||
import warnings
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
@ -42,7 +42,13 @@ if is_rich_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
GenerationConfig,
|
||||
TextIteratorStreamer,
|
||||
)
|
||||
|
||||
|
||||
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
|
||||
@ -64,25 +70,16 @@ DEFAULT_EXAMPLES = {
|
||||
"socks": {"text": "Why is it important to eat socks after meditating?"},
|
||||
}
|
||||
|
||||
SUPPORTED_GENERATION_KWARGS = [
|
||||
"max_new_tokens",
|
||||
"do_sample",
|
||||
"num_beams",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"repetition_penalty",
|
||||
]
|
||||
|
||||
# Printed at the start of a chat session
|
||||
HELP_STRING_MINIMAL = """
|
||||
|
||||
**TRANSFORMERS CHAT INTERFACE**
|
||||
|
||||
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
|
||||
- **help**: shows all available commands
|
||||
- **clear**: clears the current conversation and starts a new one
|
||||
- **exit**: closes the interface
|
||||
- **!help**: shows all available commands
|
||||
- **!status**: shows the current status of the model and generation settings
|
||||
- **!clear**: clears the current conversation and starts a new one
|
||||
- **!exit**: closes the interface
|
||||
"""
|
||||
|
||||
|
||||
@ -92,18 +89,32 @@ HELP_STRING = f"""
|
||||
**TRANSFORMERS CHAT INTERFACE HELP**
|
||||
|
||||
Full command list:
|
||||
- **help**: shows this help message
|
||||
- **clear**: clears the current conversation and starts a new one
|
||||
- **example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. Available example
|
||||
names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
|
||||
- **set {{SETTING_NAME}}={{SETTING_VALUE}};**: changes the system prompt or generation settings (multiple settings are
|
||||
separated by a ';'). Available settings: `{"`, `".join(SUPPORTED_GENERATION_KWARGS)}`
|
||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||
- **save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
|
||||
- **!help**: shows this help message
|
||||
- **!clear**: clears the current conversation and starts a new one
|
||||
- **!status**: shows the current status of the model and generation settings
|
||||
- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input.
|
||||
Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
|
||||
- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple
|
||||
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
|
||||
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
|
||||
- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
|
||||
`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
|
||||
- **exit**: closes the interface
|
||||
- **!exit**: closes the interface
|
||||
"""
|
||||
|
||||
# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag)
|
||||
_DEPRECATION_MAP = [
|
||||
("max_new_tokens", 256, "max_new_tokens"),
|
||||
("do_sample", True, "do_sample"),
|
||||
("num_beams", 1, "num_beams"),
|
||||
("temperature", 1.0, "temperature"),
|
||||
("top_k", 50, "top_k"),
|
||||
("top_p", 1.0, "top_p"),
|
||||
("repetition_penalty", 1.0, "repetition_penalty"),
|
||||
("eos_tokens", None, "eos_token_id"),
|
||||
("eos_token_ids", None, "eos_token_id"),
|
||||
]
|
||||
|
||||
|
||||
class RichInterface:
|
||||
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
|
||||
@ -181,6 +192,14 @@ class RichInterface:
|
||||
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
|
||||
self._console.print()
|
||||
|
||||
def print_status(self, model_name: str, generation_config: GenerationConfig, model_kwargs: dict):
|
||||
"""Prints the status of the model and generation settings to the console."""
|
||||
self._console.print(f"[bold blue]Model: {model_name}\n")
|
||||
if model_kwargs:
|
||||
self._console.print(f"[bold blue]Model kwargs: {model_kwargs}")
|
||||
self._console.print(f"[bold blue]{generation_config}")
|
||||
self._console.print()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatArguments:
|
||||
@ -207,6 +226,17 @@ class ChatArguments:
|
||||
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
|
||||
|
||||
# Generation settings
|
||||
generation_config: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Path to a local generation config file or to a HuggingFace repo containing a "
|
||||
"`generation_config.json` file. Other generation settings passed as CLI arguments will be applied on "
|
||||
"top of this generation config."
|
||||
),
|
||||
},
|
||||
)
|
||||
# Deprecated CLI args start here
|
||||
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
|
||||
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
|
||||
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
|
||||
@ -222,6 +252,7 @@ class ChatArguments:
|
||||
default=None,
|
||||
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
|
||||
)
|
||||
# Deprecated CLI args end here
|
||||
|
||||
# Model loading
|
||||
model_revision: str = field(
|
||||
@ -280,23 +311,66 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
group = chat_parser.add_argument_group("Positional arguments")
|
||||
group.add_argument(
|
||||
"model_name_or_path_positional", type=str, nargs="?", default=None, help="Name of the pre-trained model."
|
||||
"model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
|
||||
)
|
||||
group.add_argument(
|
||||
"generate_flags",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
|
||||
"and lists of integers, more advanced parameterization should be set through --generation-config. "
|
||||
"Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
|
||||
"If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options"
|
||||
),
|
||||
nargs="*",
|
||||
)
|
||||
|
||||
chat_parser.set_defaults(func=chat_command_factory)
|
||||
|
||||
def __init__(self, args):
|
||||
args.model_name_or_path = args.model_name_or_path_positional or args.model_name_or_path
|
||||
args = self._handle_deprecated_args(args)
|
||||
self.args = args
|
||||
|
||||
if args.model_name_or_path is None:
|
||||
def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
|
||||
"""
|
||||
Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
|
||||
args.
|
||||
"""
|
||||
has_warnings = False
|
||||
|
||||
# 1. Model as a positional argument
|
||||
args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
|
||||
if args.model_name_or_path_positional is None:
|
||||
raise ValueError(
|
||||
"One of the following must be provided:"
|
||||
"\n- The positional argument containing the model repo;"
|
||||
"\n- the optional --model_name_or_path argument, containing the model repo"
|
||||
"\ne.g. transformers chat <model_repo> or transformers chat --model_name_or_path <model_repo>"
|
||||
"\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
|
||||
"\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
|
||||
)
|
||||
elif args.model_name_or_path is not None:
|
||||
has_warnings = True
|
||||
warnings.warn(
|
||||
"The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
|
||||
"argument instead, e.g. `transformers chat <model_repo>`.",
|
||||
FutureWarning,
|
||||
)
|
||||
# 2. Named generate option args
|
||||
for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
|
||||
value = getattr(args, deprecated_arg)
|
||||
if value != default_value:
|
||||
has_warnings = True
|
||||
warnings.warn(
|
||||
f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two "
|
||||
"alternative solutions to specify this generation option: \n"
|
||||
"1. Pass `--generation-config <path_to_file/Hub repo>` to specify a generation config.\n"
|
||||
"2. Pass `generate` flags through positional arguments, e.g. `transformers chat <model_repo> "
|
||||
f"{new_arg}={value}`",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
self.args = args
|
||||
if has_warnings:
|
||||
print("\n(Press enter to continue)")
|
||||
input()
|
||||
return args
|
||||
|
||||
# -----------------------------------------------------------------------------------------------------------------
|
||||
# Chat session methods
|
||||
@ -319,7 +393,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
if filename is None:
|
||||
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
filename = f"{args.model_name_or_path}/chat_{time_str}.json"
|
||||
filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
|
||||
filename = os.path.join(folder, filename)
|
||||
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
@ -338,50 +412,95 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
# -----------------------------------------------------------------------------------------------------------------
|
||||
# Input parsing methods
|
||||
@staticmethod
|
||||
def parse_settings(
|
||||
user_input: str, current_args: ChatArguments, interface: RichInterface
|
||||
) -> tuple[ChatArguments, bool]:
|
||||
"""Parses the settings from the user input into the CLI arguments."""
|
||||
settings = user_input[4:].strip().split(";")
|
||||
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
|
||||
settings = dict(settings)
|
||||
error = False
|
||||
def parse_generate_flags(self, generate_flags: list[str]) -> dict:
|
||||
"""Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
|
||||
if len(generate_flags) == 0:
|
||||
return {}
|
||||
|
||||
for name in settings:
|
||||
if hasattr(current_args, name):
|
||||
try:
|
||||
if isinstance(getattr(current_args, name), bool):
|
||||
if settings[name] == "True":
|
||||
settings[name] = True
|
||||
elif settings[name] == "False":
|
||||
settings[name] = False
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
settings[name] = type(getattr(current_args, name))(settings[name])
|
||||
except ValueError:
|
||||
error = True
|
||||
interface.print_color(
|
||||
text=f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.",
|
||||
color="red",
|
||||
)
|
||||
else:
|
||||
interface.print_color(text=f"There is no '{name}' setting.", color="red")
|
||||
# Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed
|
||||
# into a json string if we:
|
||||
# 1. Add quotes around each flag name
|
||||
generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags}
|
||||
|
||||
if error:
|
||||
interface.print_color(
|
||||
text="There was an issue parsing the settings. No settings have been changed.",
|
||||
color="red",
|
||||
# 2. Handle types:
|
||||
# 2. a. booleans should be lowercase, None should be null
|
||||
generate_flags_as_dict = {
|
||||
k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items()
|
||||
}
|
||||
generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()}
|
||||
|
||||
# 2. b. strings should be quoted
|
||||
def is_number(s: str) -> bool:
|
||||
return s.replace(".", "", 1).isdigit()
|
||||
|
||||
generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
|
||||
# 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :)
|
||||
# We also mention in the help message that we only accept lists of ints for now.
|
||||
|
||||
# 3. Join the the result into a comma separated string
|
||||
generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()])
|
||||
|
||||
# 4. Add the opening/closing brackets
|
||||
generate_flags_string = "{" + generate_flags_string + "}"
|
||||
|
||||
# 5. Remove quotes around boolean/null and around lists
|
||||
generate_flags_string = generate_flags_string.replace('"null"', "null")
|
||||
generate_flags_string = generate_flags_string.replace('"true"', "true")
|
||||
generate_flags_string = generate_flags_string.replace('"false"', "false")
|
||||
generate_flags_string = generate_flags_string.replace('"[', "[")
|
||||
generate_flags_string = generate_flags_string.replace(']"', "]")
|
||||
|
||||
# 6. Replace the `=` with `:`
|
||||
generate_flags_string = generate_flags_string.replace("=", ":")
|
||||
|
||||
try:
|
||||
processed_generate_flags = json.loads(generate_flags_string)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(
|
||||
"Failed to convert `generate_flags` into a valid JSON object."
|
||||
"\n`generate_flags` = {generate_flags}"
|
||||
"\nConverted JSON string = {generate_flags_string}"
|
||||
)
|
||||
return processed_generate_flags
|
||||
|
||||
def get_generation_parameterization(
|
||||
self, args: ChatArguments, tokenizer: AutoTokenizer
|
||||
) -> tuple[GenerationConfig, dict]:
|
||||
"""
|
||||
Returns a GenerationConfig object holding the generation parameters for the CLI command.
|
||||
"""
|
||||
# No generation config arg provided -> use base generation config, apply CLI defaults
|
||||
if args.generation_config is None:
|
||||
generation_config = GenerationConfig()
|
||||
# Apply deprecated CLI args on top of the default generation config
|
||||
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
|
||||
deprecated_kwargs = {
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"do_sample": args.do_sample,
|
||||
"num_beams": args.num_beams,
|
||||
"temperature": args.temperature,
|
||||
"top_k": args.top_k,
|
||||
"top_p": args.top_p,
|
||||
"repetition_penalty": args.repetition_penalty,
|
||||
"pad_token_id": pad_token_id,
|
||||
"eos_token_id": eos_token_ids,
|
||||
}
|
||||
generation_config.update(**deprecated_kwargs)
|
||||
# generation config arg provided -> use it as the base parameterization
|
||||
else:
|
||||
for name in settings:
|
||||
setattr(current_args, name, settings[name])
|
||||
interface.print_color(text=f"Set {name} to {settings[name]}.", color="green")
|
||||
if ".json" in args.generation_config: # is a local file
|
||||
dirname = os.path.dirname(args.generation_config)
|
||||
filename = os.path.basename(args.generation_config)
|
||||
generation_config = GenerationConfig.from_pretrained(dirname, filename)
|
||||
else:
|
||||
generation_config = GenerationConfig.from_pretrained(args.generation_config)
|
||||
|
||||
time.sleep(1.5) # so the user has time to read the changes
|
||||
|
||||
return current_args, not error
|
||||
# Finally: parse and apply `generate_flags`
|
||||
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
|
||||
model_kwargs = generation_config.update(**parsed_generate_flags)
|
||||
# `model_kwargs` contain non-generation flags in `parsed_generate_flags` that should be passed directly to
|
||||
# `generate`
|
||||
return generation_config, model_kwargs
|
||||
|
||||
@staticmethod
|
||||
def parse_eos_tokens(
|
||||
@ -406,36 +525,6 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
return pad_token_id, all_eos_token_ids
|
||||
|
||||
@staticmethod
|
||||
def is_valid_setting_command(s: str) -> bool:
|
||||
# First check the basic structure
|
||||
if not s.startswith("set ") or "=" not in s:
|
||||
return False
|
||||
|
||||
# Split into individual assignments
|
||||
assignments = [a.strip() for a in s[4:].split(";") if a.strip()]
|
||||
|
||||
for assignment in assignments:
|
||||
# Each assignment should have exactly one '='
|
||||
if assignment.count("=") != 1:
|
||||
return False
|
||||
|
||||
key, value = assignment.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if not key or not value:
|
||||
return False
|
||||
|
||||
# Keys can only have alphabetic characters, spaces and underscores
|
||||
if not set(key).issubset(ALLOWED_KEY_CHARS):
|
||||
return False
|
||||
|
||||
# Values can have just about anything that isn't a semicolon
|
||||
if not set(value).issubset(ALLOWED_VALUE_CHARS):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# -----------------------------------------------------------------------------------------------------------------
|
||||
# Model loading and performance automation methods
|
||||
@staticmethod
|
||||
@ -460,7 +549,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
args.model_name_or_path_positional,
|
||||
revision=args.model_revision,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
@ -475,7 +564,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
"quantization_config": quantization_config,
|
||||
}
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
|
||||
args.model_name_or_path_positional, trust_remote_code=args.trust_remote_code, **model_kwargs
|
||||
)
|
||||
|
||||
if getattr(model, "hf_device_map", None) is None:
|
||||
@ -483,6 +572,88 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
# -----------------------------------------------------------------------------------------------------------------
|
||||
# User commands
|
||||
def handle_non_exit_user_commands(
|
||||
self,
|
||||
user_input: str,
|
||||
args: ChatArguments,
|
||||
interface: RichInterface,
|
||||
examples: dict[str, dict[str, str]],
|
||||
generation_config: GenerationConfig,
|
||||
model_kwargs: dict,
|
||||
chat: list[dict],
|
||||
) -> tuple[list[dict], GenerationConfig, dict]:
|
||||
"""
|
||||
Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
|
||||
generation config (e.g. set a new flag).
|
||||
"""
|
||||
|
||||
if user_input == "!clear":
|
||||
chat = self.clear_chat_history(args.system_prompt)
|
||||
interface.clear()
|
||||
|
||||
elif user_input == "!help":
|
||||
interface.print_help()
|
||||
|
||||
elif user_input.startswith("!save") and len(user_input.split()) < 2:
|
||||
split_input = user_input.split()
|
||||
|
||||
if len(split_input) == 2:
|
||||
filename = split_input[1]
|
||||
else:
|
||||
filename = None
|
||||
filename = self.save_chat(chat, args, filename)
|
||||
interface.print_color(text=f"Chat saved in {filename}!", color="green")
|
||||
|
||||
elif user_input.startswith("!set"):
|
||||
# splits the new args into a list of strings, each string being a `flag=value` pair (same format as
|
||||
# `generate_flags`)
|
||||
new_generate_flags = user_input[4:].strip()
|
||||
new_generate_flags = new_generate_flags.split()
|
||||
# sanity check: each member in the list must have an =
|
||||
for flag in new_generate_flags:
|
||||
if "=" not in flag:
|
||||
interface.print_color(
|
||||
text=(
|
||||
f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
|
||||
"`arg_1=value_1 arg_2=value_2 ...`."
|
||||
),
|
||||
color="red",
|
||||
)
|
||||
break
|
||||
else:
|
||||
# parses the new args into a dictionary of `generate` kwargs, and updates the corresponding variables
|
||||
parsed_new_generate_flags = self.parse_generate_flags(new_generate_flags)
|
||||
new_model_kwargs = generation_config.update(**parsed_new_generate_flags)
|
||||
model_kwargs.update(**new_model_kwargs)
|
||||
|
||||
elif user_input.startswith("!example") and len(user_input.split()) == 2:
|
||||
example_name = user_input.split()[1]
|
||||
if example_name in examples:
|
||||
interface.clear()
|
||||
chat = []
|
||||
interface.print_user_message(examples[example_name]["text"])
|
||||
chat.append({"role": "user", "content": examples[example_name]["text"]})
|
||||
else:
|
||||
example_error = (
|
||||
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
|
||||
)
|
||||
interface.print_color(text=example_error, color="red")
|
||||
|
||||
elif user_input == "!status":
|
||||
interface.print_status(
|
||||
model_name=args.model_name_or_path_positional,
|
||||
generation_config=generation_config,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
|
||||
interface.print_help()
|
||||
|
||||
return chat, generation_config, model_kwargs
|
||||
|
||||
# -----------------------------------------------------------------------------------------------------------------
|
||||
# Main logic
|
||||
def run(self):
|
||||
@ -498,8 +669,6 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
with open(args.examples_path) as f:
|
||||
examples = yaml.safe_load(f)
|
||||
|
||||
current_args = copy.deepcopy(args)
|
||||
|
||||
if args.user is None:
|
||||
user = self.get_username()
|
||||
else:
|
||||
@ -507,12 +676,11 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
model, tokenizer = self.load_model_and_tokenizer(args)
|
||||
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
|
||||
|
||||
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
|
||||
|
||||
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
|
||||
interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
|
||||
interface.clear()
|
||||
chat = self.clear_chat_history(current_args.system_prompt)
|
||||
chat = self.clear_chat_history(args.system_prompt)
|
||||
|
||||
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
|
||||
interface.print_help(minimal=True)
|
||||
@ -520,57 +688,26 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
try:
|
||||
user_input = interface.input()
|
||||
|
||||
if user_input == "clear":
|
||||
chat = self.clear_chat_history(current_args.system_prompt)
|
||||
interface.clear()
|
||||
continue
|
||||
|
||||
if user_input == "help":
|
||||
interface.print_help()
|
||||
continue
|
||||
|
||||
if user_input == "exit":
|
||||
break
|
||||
|
||||
if user_input == "reset":
|
||||
interface.clear()
|
||||
current_args = copy.deepcopy(args)
|
||||
chat = self.clear_chat_history(current_args.system_prompt)
|
||||
continue
|
||||
|
||||
if user_input.startswith("save") and len(user_input.split()) < 2:
|
||||
split_input = user_input.split()
|
||||
|
||||
if len(split_input) == 2:
|
||||
filename = split_input[1]
|
||||
# User commands
|
||||
if user_input.startswith("!"):
|
||||
# `!exit` is special, it breaks the loop
|
||||
if user_input == "!exit":
|
||||
break
|
||||
else:
|
||||
filename = None
|
||||
filename = self.save_chat(chat, current_args, filename)
|
||||
interface.print_color(text=f"Chat saved in {filename}!", color="green")
|
||||
continue
|
||||
|
||||
if self.is_valid_setting_command(user_input):
|
||||
current_args, success = self.parse_settings(user_input, current_args, interface)
|
||||
if success:
|
||||
chat = []
|
||||
interface.clear()
|
||||
continue
|
||||
|
||||
if user_input.startswith("example") and len(user_input.split()) == 2:
|
||||
example_name = user_input.split()[1]
|
||||
if example_name in examples:
|
||||
interface.clear()
|
||||
chat = []
|
||||
interface.print_user_message(examples[example_name]["text"])
|
||||
user_input = examples[example_name]["text"]
|
||||
else:
|
||||
example_error = (
|
||||
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
|
||||
chat, generation_config, model_kwargs = self.handle_non_exit_user_commands(
|
||||
user_input=user_input,
|
||||
args=args,
|
||||
interface=interface,
|
||||
examples=examples,
|
||||
generation_config=generation_config,
|
||||
model_kwargs=model_kwargs,
|
||||
chat=chat,
|
||||
)
|
||||
interface.print_color(text=example_error, color="red")
|
||||
# `!example` sends a user message to the model
|
||||
if not user_input.startswith("!example"):
|
||||
continue
|
||||
|
||||
chat.append({"role": "user", "content": user_input})
|
||||
else:
|
||||
chat.append({"role": "user", "content": user_input})
|
||||
|
||||
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||
model.device
|
||||
@ -580,15 +717,8 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
"inputs": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"streamer": generation_streamer,
|
||||
"max_new_tokens": current_args.max_new_tokens,
|
||||
"do_sample": current_args.do_sample,
|
||||
"num_beams": current_args.num_beams,
|
||||
"temperature": current_args.temperature,
|
||||
"top_k": current_args.top_k,
|
||||
"top_p": current_args.top_p,
|
||||
"repetition_penalty": current_args.repetition_penalty,
|
||||
"pad_token_id": pad_token_id,
|
||||
"eos_token_id": eos_token_ids,
|
||||
"generation_config": generation_config,
|
||||
**model_kwargs,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
|
@ -408,6 +408,10 @@ class PretrainedConfig(PushToHubMixin):
|
||||
repo_id = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
# This attribute is important to know on load, but should not be serialized on save.
|
||||
if "transformers_weights" in self:
|
||||
delattr(self, "transformers_weights")
|
||||
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
if self._auto_class is not None:
|
||||
|
@ -1584,7 +1584,9 @@ class TikTokenConverter:
|
||||
self.pattern = pattern
|
||||
self.add_prefix_space = add_prefix_space
|
||||
self.additional_special_tokens = (
|
||||
additional_special_tokens.keys() if type(additional_special_tokens) is dict else additional_special_tokens
|
||||
additional_special_tokens.keys()
|
||||
if isinstance(additional_special_tokens, dict)
|
||||
else additional_special_tokens
|
||||
)
|
||||
|
||||
def extract_vocab_merges_from_model(self, tiktoken_url: str):
|
||||
|
@ -17,6 +17,7 @@ import ast
|
||||
import filecmp
|
||||
import hashlib
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
@ -30,6 +31,7 @@ from types import ModuleType
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from huggingface_hub import try_to_load_from_cache
|
||||
from packaging import version
|
||||
|
||||
from .utils import (
|
||||
HF_MODULES_CACHE,
|
||||
@ -39,6 +41,7 @@ from .utils import (
|
||||
is_offline_mode,
|
||||
logging,
|
||||
)
|
||||
from .utils.import_utils import VersionComparison, split_package_version
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@ -383,7 +386,7 @@ def get_cached_module_file(
|
||||
new_files.append(module_file)
|
||||
|
||||
except OSError:
|
||||
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
||||
logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
||||
raise
|
||||
|
||||
# Check we have all the requirements in our environment
|
||||
@ -417,7 +420,8 @@ def get_cached_module_file(
|
||||
# benefit of versioning.
|
||||
submodule_path = submodule_path / commit_hash
|
||||
full_submodule = full_submodule + os.path.sep + commit_hash
|
||||
create_dynamic_module(full_submodule)
|
||||
full_submodule_module_file_path = os.path.join(full_submodule, module_file)
|
||||
create_dynamic_module(Path(full_submodule_module_file_path).parent)
|
||||
|
||||
if not (submodule_path / module_file).exists():
|
||||
shutil.copy(resolved_module_file, submodule_path / module_file)
|
||||
@ -663,7 +667,33 @@ def _raise_timeout_error(signum, frame):
|
||||
TIME_OUT_REMOTE_CODE = 15
|
||||
|
||||
|
||||
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
|
||||
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None):
|
||||
"""
|
||||
Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading
|
||||
it.
|
||||
|
||||
Args:
|
||||
trust_remote_code (`bool` or `None`):
|
||||
User-defined `trust_remote_code` value.
|
||||
model_name (`str`):
|
||||
The name of the model repository in huggingface.co.
|
||||
has_local_code (`bool`):
|
||||
Whether the model has local code.
|
||||
has_remote_code (`bool`):
|
||||
Whether the model has remote code.
|
||||
error_message (`str`, *optional*):
|
||||
Custom error message to display if there is remote code to load and the user didn't opt-in. If unset, the error
|
||||
message will be regarding loading a model with custom code.
|
||||
|
||||
Returns:
|
||||
The resolved `trust_remote_code` value.
|
||||
"""
|
||||
# Originally, `trust_remote_code` was used to load models with custom code.
|
||||
error_message = (
|
||||
error_message
|
||||
or f"The repository `{model_name}` contains custom code which must be executed to correctly load the model."
|
||||
)
|
||||
|
||||
if trust_remote_code is None:
|
||||
if has_local_code:
|
||||
trust_remote_code = False
|
||||
@ -674,8 +704,7 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
|
||||
signal.alarm(TIME_OUT_REMOTE_CODE)
|
||||
while trust_remote_code is None:
|
||||
answer = input(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
||||
f"Do you wish to run the custom code? [y/N] "
|
||||
)
|
||||
@ -687,8 +716,7 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
|
||||
except Exception:
|
||||
# OS which does not support signal.SIGALRM
|
||||
raise ValueError(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
finally:
|
||||
@ -701,9 +729,64 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
|
||||
|
||||
if has_remote_code and not has_local_code and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"Loading {model_name} requires you to execute the configuration file in that"
|
||||
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||
" set the option `trust_remote_code=True` to remove this error."
|
||||
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
|
||||
return trust_remote_code
|
||||
|
||||
|
||||
def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs):
|
||||
"""
|
||||
Tries to locate `requirements_file` in a local folder or repo, and confirms that the environment has all the
|
||||
python dependencies installed.
|
||||
|
||||
Args:
|
||||
path_or_repo_id (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
- a string, the *model id* of a model repo on huggingface.co.
|
||||
- a path to a *directory* potentially containing the file.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional arguments to pass to `cached_file`.
|
||||
"""
|
||||
failed = [] # error messages regarding requirements
|
||||
try:
|
||||
requirements = cached_file(path_or_repo_id=path_or_repo_id, filename=requirements_file, **kwargs)
|
||||
with open(requirements, "r") as f:
|
||||
requirements = f.readlines()
|
||||
|
||||
for requirement in requirements:
|
||||
requirement = requirement.strip()
|
||||
if not requirement or requirement.startswith("#"): # skip empty lines and comments
|
||||
continue
|
||||
|
||||
try:
|
||||
# e.g. "torch>2.6.0" -> "torch", ">", "2.6.0"
|
||||
package_name, delimiter, version_number = split_package_version(requirement)
|
||||
except ValueError: # e.g. "torch", as opposed to "torch>2.6.0"
|
||||
package_name = requirement
|
||||
delimiter, version_number = None, None
|
||||
|
||||
try:
|
||||
local_package_version = importlib.metadata.version(package_name)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
failed.append(f"{requirement} (installed: None)")
|
||||
continue
|
||||
|
||||
if delimiter is not None and version_number is not None:
|
||||
is_satisfied = VersionComparison.from_string(delimiter)(
|
||||
version.parse(local_package_version), version.parse(version_number)
|
||||
)
|
||||
else:
|
||||
is_satisfied = True
|
||||
|
||||
if not is_satisfied:
|
||||
failed.append(f"{requirement} (installed: {local_package_version})")
|
||||
|
||||
except OSError: # no requirements.txt
|
||||
pass
|
||||
|
||||
if failed:
|
||||
raise ImportError(
|
||||
f"Missing requirements in your local environment for `{path_or_repo_id}`:\n" + "\n".join(failed)
|
||||
)
|
||||
|
@ -28,7 +28,7 @@ from ..utils import is_sklearn_available
|
||||
if is_sklearn_available():
|
||||
from sklearn.metrics import roc_curve
|
||||
|
||||
from ..cache_utils import DynamicCache
|
||||
from ..cache_utils import Cache
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor
|
||||
|
||||
@ -1183,7 +1183,9 @@ class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
|
||||
def _crop_past_key_values(model, past_key_values, max_length):
|
||||
"""Crops the past key values up to a certain maximum length."""
|
||||
new_past = []
|
||||
if model.config.is_encoder_decoder:
|
||||
if isinstance(past_key_values, Cache):
|
||||
past_key_values.crop(max_length)
|
||||
elif model.config.is_encoder_decoder:
|
||||
for idx in range(len(past_key_values)):
|
||||
new_past.append(
|
||||
(
|
||||
@ -1204,8 +1206,6 @@ def _crop_past_key_values(model, past_key_values, max_length):
|
||||
else:
|
||||
for idx in range(len(past_key_values)):
|
||||
past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
|
||||
elif isinstance(past_key_values, DynamicCache):
|
||||
past_key_values.crop(max_length)
|
||||
elif past_key_values is not None:
|
||||
for idx in range(len(past_key_values)):
|
||||
if past_key_values[idx] != ([], []):
|
||||
|
@ -35,6 +35,7 @@ from ..utils import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -514,7 +515,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
raise err
|
||||
|
||||
# Validate the values of the attributes
|
||||
self.validate(is_init=True)
|
||||
self.validate()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.to_json_string(ignore_metadata=True))
|
||||
@ -576,9 +577,10 @@ class GenerationConfig(PushToHubMixin):
|
||||
if generation_mode in ("greedy_search", "sample"):
|
||||
generation_mode = GenerationMode.ASSISTED_GENERATION
|
||||
else:
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
||||
"is only supported with Greedy Search and Sample."
|
||||
"is only supported with Greedy Search and Sample. However, the base decoding mode (based on "
|
||||
f"current flags) is {generation_mode} -- some of the set flags will be ignored."
|
||||
)
|
||||
|
||||
# DoLa generation may extend some generation modes
|
||||
@ -586,13 +588,15 @@ class GenerationConfig(PushToHubMixin):
|
||||
if generation_mode in ("greedy_search", "sample"):
|
||||
generation_mode = GenerationMode.DOLA_GENERATION
|
||||
else:
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
"You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
|
||||
"is only supported with Greedy Search and Sample."
|
||||
"is only supported with Greedy Search and Sample. However, the base decoding mode (based on "
|
||||
f"current flags) is {generation_mode} -- some of the set flags will be ignored."
|
||||
)
|
||||
return generation_mode
|
||||
|
||||
def validate(self, is_init=False):
|
||||
@deprecate_kwarg("is_init", version="4.54.0")
|
||||
def validate(self, strict=False):
|
||||
"""
|
||||
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
||||
of parameterization that can be detected as incorrect from the configuration instance alone.
|
||||
@ -600,174 +604,24 @@ class GenerationConfig(PushToHubMixin):
|
||||
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
|
||||
other inputs and/or the model, such as parameters related to the generation length.
|
||||
|
||||
Arg:
|
||||
is_init (`bool`, *optional*, defaults to `False`):
|
||||
Whether the validation is performed during the initialization of the instance.
|
||||
Args:
|
||||
strict (bool): If True, raise an exception for any issues found. If False, only log issues.
|
||||
"""
|
||||
minor_issues = {} # format: {attribute_name: issue_description}
|
||||
|
||||
# Validation of individual attributes
|
||||
# 1. Validation of individual attributes
|
||||
# 1.1. Decoding attributes
|
||||
if self.early_stopping not in {True, False, "never"}:
|
||||
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
||||
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
|
||||
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
||||
if self.pad_token_id is not None and self.pad_token_id < 0:
|
||||
warnings.warn(
|
||||
minor_issues["pad_token_id"] = (
|
||||
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
|
||||
"generating, if there is padding. Please set `pad_token_id` explicitly as "
|
||||
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
|
||||
)
|
||||
|
||||
# Validation of attribute relations:
|
||||
fix_location = ""
|
||||
if is_init:
|
||||
fix_location = (
|
||||
" This was detected when initializing the generation config instance, which means the corresponding "
|
||||
"file may hold incorrect parameterization and should be fixed."
|
||||
)
|
||||
|
||||
# 1. detect sampling-only parameterization when not in sampling mode
|
||||
if self.do_sample is False:
|
||||
greedy_wrong_parameter_msg = (
|
||||
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
||||
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
||||
+ fix_location
|
||||
)
|
||||
if self.temperature is not None and self.temperature != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
|
||||
UserWarning,
|
||||
)
|
||||
if self.top_p is not None and self.top_p != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.min_p is not None:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.typical_p is not None and self.typical_p != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
|
||||
UserWarning,
|
||||
)
|
||||
if (
|
||||
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
||||
): # contrastive search uses top_k
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
|
||||
UserWarning,
|
||||
)
|
||||
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 2. detect beam-only parameterization when not in beam mode
|
||||
if self.num_beams is None:
|
||||
warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
|
||||
self.num_beams = 1
|
||||
|
||||
if self.num_beams == 1:
|
||||
single_beam_wrong_parameter_msg = (
|
||||
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
||||
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
|
||||
)
|
||||
if self.early_stopping is not False:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
|
||||
UserWarning,
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.length_penalty is not None and self.length_penalty != 1.0:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
|
||||
UserWarning,
|
||||
)
|
||||
if self.constraints is not None:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 3. detect incorrect parameterization specific to advanced beam modes
|
||||
else:
|
||||
# constrained beam search
|
||||
if self.constraints is not None or self.force_words_ids is not None:
|
||||
constrained_wrong_parameter_msg = (
|
||||
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
|
||||
"`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
|
||||
"`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
)
|
||||
# group beam search
|
||||
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
||||
group_error_prefix = (
|
||||
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
||||
"this generation mode, "
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
||||
if self.num_beams % self.num_beam_groups != 0:
|
||||
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
||||
if self.diversity_penalty == 0.0:
|
||||
raise ValueError(
|
||||
group_error_prefix
|
||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||
)
|
||||
# DoLa generation
|
||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||
warnings.warn(
|
||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
||||
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
||||
"DoLa decoding is `repetition_penalty>=1.2`.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
if self.num_beams == 1:
|
||||
if self.do_sample is False:
|
||||
raise ValueError(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
||||
f"(got {self.num_return_sequences})."
|
||||
)
|
||||
elif self.num_return_sequences > self.num_beams:
|
||||
raise ValueError(
|
||||
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# 5. check cache-related arguments
|
||||
# 1.2. Cache attributes
|
||||
if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
|
||||
raise ValueError(
|
||||
f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
|
||||
@ -784,6 +638,141 @@ class GenerationConfig(PushToHubMixin):
|
||||
if not isinstance(self.cache_config, cache_class):
|
||||
self.cache_config = cache_class.from_dict(self.cache_config)
|
||||
self.cache_config.validate()
|
||||
# 1.3. Performance attributes
|
||||
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
|
||||
raise ValueError(
|
||||
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
|
||||
"instance of `CompileConfig`."
|
||||
)
|
||||
# 1.4. Watermarking attributes
|
||||
if self.watermarking_config is not None:
|
||||
if not (
|
||||
isinstance(self.watermarking_config, WatermarkingConfig)
|
||||
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
||||
):
|
||||
minor_issues["watermarking_config"] = (
|
||||
"`watermarking_config` as a dict is deprecated and will be removed in v4.54.0. Please construct "
|
||||
"`watermarking_config` object with `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class."
|
||||
)
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||
self.watermarking_config.validate()
|
||||
|
||||
# 2. Validation of attribute combinations
|
||||
# 2.1. detect sampling-only parameterization when not in sampling mode
|
||||
if self.do_sample is False:
|
||||
greedy_wrong_parameter_msg = (
|
||||
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
||||
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
||||
)
|
||||
if self.temperature is not None and self.temperature != 1.0:
|
||||
minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="temperature", flag_value=self.temperature
|
||||
)
|
||||
if self.top_p is not None and self.top_p != 1.0:
|
||||
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
|
||||
if self.min_p is not None:
|
||||
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
|
||||
if self.typical_p is not None and self.typical_p != 1.0:
|
||||
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="typical_p", flag_value=self.typical_p
|
||||
)
|
||||
if (
|
||||
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
||||
): # contrastive search uses top_k
|
||||
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
|
||||
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
||||
minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
|
||||
)
|
||||
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
||||
minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="eta_cutoff", flag_value=self.eta_cutoff
|
||||
)
|
||||
|
||||
# 2.2. detect beam-only parameterization when not in beam mode
|
||||
if self.num_beams == 1:
|
||||
single_beam_wrong_parameter_msg = (
|
||||
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
||||
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
|
||||
)
|
||||
if self.early_stopping is not False:
|
||||
minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="early_stopping", flag_value=self.early_stopping
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
minor_issues["num_beam_groups"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
||||
minor_issues["diversity_penalty"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
)
|
||||
if self.length_penalty is not None and self.length_penalty != 1.0:
|
||||
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="length_penalty", flag_value=self.length_penalty
|
||||
)
|
||||
if self.constraints is not None:
|
||||
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="constraints", flag_value=self.constraints
|
||||
)
|
||||
# DoLa generation needs num_beams == 1
|
||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||
minor_issues["repetition_penalty"] = (
|
||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
||||
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
||||
"DoLa decoding is `repetition_penalty>=1.2`.",
|
||||
)
|
||||
|
||||
# 2.3. detect incorrect parameterization specific to advanced beam modes
|
||||
else:
|
||||
# constrained beam search
|
||||
if self.constraints is not None or self.force_words_ids is not None:
|
||||
constrained_wrong_parameter_msg = (
|
||||
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
|
||||
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
|
||||
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
)
|
||||
# group beam search
|
||||
elif self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
||||
group_error_prefix = (
|
||||
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
||||
"this generation mode, "
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
||||
if self.num_beams % self.num_beam_groups != 0:
|
||||
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
||||
if self.diversity_penalty == 0.0:
|
||||
raise ValueError(
|
||||
group_error_prefix
|
||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||
)
|
||||
|
||||
# 2.4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
if self.num_beams == 1:
|
||||
if self.do_sample is False:
|
||||
raise ValueError(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
||||
f"(got {self.num_return_sequences})."
|
||||
)
|
||||
elif self.num_return_sequences > self.num_beams:
|
||||
raise ValueError(
|
||||
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# 2.5. check cache-related arguments
|
||||
if self.use_cache is False:
|
||||
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used
|
||||
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
|
||||
@ -794,42 +783,20 @@ class GenerationConfig(PushToHubMixin):
|
||||
)
|
||||
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
|
||||
if getattr(self, arg_name) is not None:
|
||||
logger.warning_once(
|
||||
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
|
||||
minor_issues[arg_name] = no_cache_warning.format(
|
||||
cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)
|
||||
)
|
||||
|
||||
# 6. check watermarking arguments
|
||||
if self.watermarking_config is not None:
|
||||
if not (
|
||||
isinstance(self.watermarking_config, WatermarkingConfig)
|
||||
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
||||
):
|
||||
warnings.warn(
|
||||
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
|
||||
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||
self.watermarking_config.validate()
|
||||
|
||||
# 7. performances arguments
|
||||
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
|
||||
raise ValueError(
|
||||
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
|
||||
"instance of `CompileConfig`."
|
||||
)
|
||||
|
||||
# 8. other incorrect combinations
|
||||
# 2.6. other incorrect combinations
|
||||
if self.return_dict_in_generate is not True:
|
||||
for extra_output_flag in self.extra_output_flags:
|
||||
if getattr(self, extra_output_flag) is True:
|
||||
warnings.warn(
|
||||
minor_issues[extra_output_flag] = (
|
||||
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
|
||||
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
|
||||
UserWarning,
|
||||
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
|
||||
)
|
||||
|
||||
# 8. check common issue: passing `generate` arguments inside the generation config
|
||||
# 3. Check common issue: passing `generate` arguments inside the generation config
|
||||
generate_arguments = (
|
||||
"logits_processor",
|
||||
"stopping_criteria",
|
||||
@ -839,6 +806,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
"streamer",
|
||||
"negative_prompt_ids",
|
||||
"negative_prompt_attention_mask",
|
||||
"use_model_defaults",
|
||||
)
|
||||
for arg in generate_arguments:
|
||||
if hasattr(self, arg):
|
||||
@ -847,6 +815,30 @@ class GenerationConfig(PushToHubMixin):
|
||||
"`generate()` (or a pipeline) directly."
|
||||
)
|
||||
|
||||
# Finally, handle caught minor issues. With default parameterization, we will throw a minimal warning.
|
||||
if len(minor_issues) > 0:
|
||||
# Full list of issues with potential fixes
|
||||
info_message = []
|
||||
for attribute_name, issue_description in minor_issues.items():
|
||||
info_message.append(f"- `{attribute_name}`: {issue_description}")
|
||||
info_message = "\n".join(info_message)
|
||||
info_message += (
|
||||
"\nIf you're using a pretrained model, note that some of these attributes may be set through the "
|
||||
"model's `generation_config.json` file."
|
||||
)
|
||||
|
||||
if strict:
|
||||
raise ValueError("GenerationConfig is invalid: \n" + info_message)
|
||||
else:
|
||||
attributes_with_issues = list(minor_issues.keys())
|
||||
warning_message = (
|
||||
f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
|
||||
)
|
||||
if logger.getEffectiveLevel() >= logging.WARNING:
|
||||
warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
|
||||
logger.warning(warning_message)
|
||||
logger.info(info_message)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
@ -871,18 +863,13 @@ class GenerationConfig(PushToHubMixin):
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
|
||||
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
|
||||
# At save time, validate the instance enforcing strictness -- if any warning/exception would be thrown, we
|
||||
# refuse to save the instance.
|
||||
# This strictness is enforced to prevent bad configurations from being saved and re-used.
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.validate()
|
||||
if len(caught_warnings) > 0:
|
||||
raise ValueError(str([w.message for w in caught_warnings]))
|
||||
self.validate(strict=True)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
|
||||
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
|
||||
)
|
||||
raise ValueError(str(exc) + "\n\nFix these issues to save the configuration.")
|
||||
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
||||
|
@ -37,8 +37,9 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
||||
Additional stopping criteria specific kwargs.
|
||||
|
||||
Return:
|
||||
`torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`), where `True` indicates we stop generation
|
||||
for a particular row, `True` indicates we should continue.
|
||||
`torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`):
|
||||
`True` indicates we stop generation for a particular row.
|
||||
`False` indicates we should continue.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -23,12 +23,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from huggingface_hub import file_exists
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from transformers.generation.candidate_generator import AssistantVocabTranslatorCache
|
||||
|
||||
from ..cache_utils import (
|
||||
Cache,
|
||||
DynamicCache,
|
||||
@ -39,6 +38,12 @@ from ..cache_utils import (
|
||||
QuantizedCacheConfig,
|
||||
)
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..dynamic_module_utils import (
|
||||
check_python_requirements,
|
||||
get_cached_module_file,
|
||||
get_class_in_module,
|
||||
resolve_trust_remote_code,
|
||||
)
|
||||
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ..integrations.fsdp import is_fsdp_managed_module
|
||||
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
||||
@ -55,6 +60,7 @@ from ..utils import (
|
||||
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .candidate_generator import (
|
||||
AssistantVocabTranslatorCache,
|
||||
AssistedCandidateGenerator,
|
||||
AssistedCandidateGeneratorDifferentTokenizers,
|
||||
CandidateGenerator,
|
||||
@ -376,6 +382,73 @@ class GenerationMixin:
|
||||
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
||||
"""
|
||||
|
||||
def load_custom_generate(
|
||||
self,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Callable:
|
||||
"""
|
||||
Loads and returns a custom generate function, given a model repo.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
Can be either:
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
trust_remote_code (`bool`, *optional*):
|
||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
||||
execute code present on the Hub on your local machine.
|
||||
**kwargs:
|
||||
Additional keyword arguments for remote code loading.
|
||||
|
||||
Raises:
|
||||
OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory.
|
||||
|
||||
Returns:
|
||||
A callable that can be used to generate text.
|
||||
"""
|
||||
# Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
|
||||
is_local_code = os.path.exists(pretrained_model_name_or_path)
|
||||
has_custom_generate_folder = True
|
||||
if is_local_code:
|
||||
if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")):
|
||||
has_custom_generate_folder = False
|
||||
else:
|
||||
if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"):
|
||||
has_custom_generate_folder = False
|
||||
|
||||
if not has_custom_generate_folder:
|
||||
raise OSError(
|
||||
f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
|
||||
"`generate.py` file, can't load the custom generate function."
|
||||
)
|
||||
|
||||
# Handle opt-in `trust_remote_code` and related exceptions
|
||||
error_message = (
|
||||
f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
|
||||
"the default `generate` method."
|
||||
)
|
||||
resolve_trust_remote_code(
|
||||
trust_remote_code,
|
||||
pretrained_model_name_or_path,
|
||||
has_local_code=is_local_code,
|
||||
has_remote_code=not is_local_code,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
# Load the custom generate function
|
||||
check_python_requirements(
|
||||
pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
|
||||
)
|
||||
module = get_cached_module_file(
|
||||
pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
|
||||
)
|
||||
custom_generate_function = get_class_in_module("generate", module)
|
||||
return custom_generate_function
|
||||
|
||||
def _cache_dependant_input_preparation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@ -1679,16 +1752,21 @@ class GenerationMixin:
|
||||
use_model_defaults is None and model_base_version >= version.parse("4.50.0")
|
||||
):
|
||||
modified_values = {}
|
||||
default_generation_config = GenerationConfig()
|
||||
for key, default_value in default_generation_config.__dict__.items():
|
||||
global_default_generation_config = GenerationConfig()
|
||||
model_generation_config = self.generation_config
|
||||
# we iterate over the model's generation config: it may hold custom keys, which we'll want to copy
|
||||
for key, model_gen_config_value in model_generation_config.__dict__.items():
|
||||
if key.startswith("_") or key == "transformers_version": # metadata
|
||||
continue
|
||||
custom_gen_config_value = getattr(generation_config, key)
|
||||
model_gen_config_value = getattr(self.generation_config, key)
|
||||
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
|
||||
global_default_value = getattr(global_default_generation_config, key, None)
|
||||
custom_gen_config_value = getattr(generation_config, key, None)
|
||||
if (
|
||||
custom_gen_config_value == global_default_value
|
||||
and model_gen_config_value != global_default_value
|
||||
):
|
||||
modified_values[key] = model_gen_config_value
|
||||
setattr(generation_config, key, model_gen_config_value)
|
||||
if len(modified_values) > 0:
|
||||
if use_model_defaults is None and len(modified_values) > 0:
|
||||
logger.warning_once(
|
||||
f"`generation_config` default values have been modified to match model-specific defaults: "
|
||||
f"{modified_values}. If this is not desired, please set these values explicitly."
|
||||
@ -1711,6 +1789,8 @@ class GenerationMixin:
|
||||
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
|
||||
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
|
||||
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
|
||||
if "cache_position" in model_kwargs and model_kwargs["cache_position"]:
|
||||
return model_kwargs
|
||||
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
|
||||
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
|
||||
elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
|
||||
@ -2156,6 +2236,7 @@ class GenerationMixin:
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_model_defaults: Optional[bool] = None,
|
||||
custom_generate: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@ -2225,6 +2306,11 @@ class GenerationMixin:
|
||||
generation configuration (`model.generation_config`), as opposed to the global defaults
|
||||
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
|
||||
`True`.
|
||||
custom_generate (`str`, *optional*):
|
||||
A string containing the name of a huggingface.co repository. If provided, the custom `generate`
|
||||
function defined in that reposity's `custom_generate/generate.py` file will be executed instead of the
|
||||
standard `generate` method. Note that the logic is for generation is entirely defined in that
|
||||
repository, and the return type may be different from the standard `generate` method.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||
@ -2246,6 +2332,20 @@ class GenerationMixin:
|
||||
- [`~generation.GenerateEncoderDecoderOutput`],
|
||||
- [`~generation.GenerateBeamEncoderDecoderOutput`]
|
||||
"""
|
||||
# 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
|
||||
if custom_generate is not None:
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
|
||||
# they receive the same inputs as `generate`, only with `model` instead of `self`. They can access to
|
||||
# methods from `GenerationMixin` through `model`.
|
||||
global_keys_to_exclude = {"self", "kwargs"}
|
||||
generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
|
||||
generate_arguments.update(kwargs)
|
||||
|
||||
custom_generate_function = self.load_custom_generate(
|
||||
custom_generate, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
return custom_generate_function(model=self, **generate_arguments)
|
||||
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||
@ -4905,9 +5005,14 @@ class GenerationMixin:
|
||||
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)
|
||||
|
||||
if "past_key_values" not in model_kwargs:
|
||||
raise ValueError("Cannot use prefill chunkink without a cache")
|
||||
raise ValueError("Cannot use prefill chunking without a cache")
|
||||
|
||||
model_forward = self.forward
|
||||
|
||||
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
|
||||
if compile_forward:
|
||||
model_forward = self.get_compiled_call(generation_config.compile_config)
|
||||
|
||||
model_forward = self.get_compiled_call(generation_config.compile_config)
|
||||
attention_mask = model_kwargs.pop("attention_mask", None)
|
||||
|
||||
past_length = 0
|
||||
|
@ -18,11 +18,7 @@ from typing import Any, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .image_processing_utils import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
get_size_dict,
|
||||
)
|
||||
from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from .image_transforms import (
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
@ -233,6 +229,9 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
else:
|
||||
setattr(self, key, getattr(self, key, None))
|
||||
|
||||
# get valid kwargs names
|
||||
self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: "torch.Tensor",
|
||||
@ -249,7 +248,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
Image to resize.
|
||||
size (`SizeDict`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
|
||||
|
||||
Returns:
|
||||
@ -566,12 +565,16 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
return self.preprocess(images, *args, **kwargs)
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
|
||||
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
for kwarg_name in self._valid_kwargs_names:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
@ -603,7 +606,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
return self._preprocess(images=images, **kwargs)
|
||||
return self._preprocess(images, *args, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
@ -651,6 +654,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
def to_dict(self):
|
||||
encoder_dict = super().to_dict()
|
||||
encoder_dict.pop("_valid_processor_keys", None)
|
||||
encoder_dict.pop("_valid_kwargs_names", None)
|
||||
return encoder_dict
|
||||
|
||||
|
||||
|
@ -56,7 +56,9 @@ def to_channel_dimension_format(
|
||||
input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Converts `image` to the channel dimension format specified by `channel_dim`.
|
||||
Converts `image` to the channel dimension format specified by `channel_dim`. The input
|
||||
can have arbitrary number of leading dimensions. Only last three dimension will be permuted
|
||||
to format the `image`.
|
||||
|
||||
Args:
|
||||
image (`numpy.ndarray`):
|
||||
@ -80,9 +82,11 @@ def to_channel_dimension_format(
|
||||
return image
|
||||
|
||||
if target_channel_dim == ChannelDimension.FIRST:
|
||||
image = image.transpose((2, 0, 1))
|
||||
axes = list(range(image.ndim - 3)) + [image.ndim - 1, image.ndim - 3, image.ndim - 2]
|
||||
image = image.transpose(axes)
|
||||
elif target_channel_dim == ChannelDimension.LAST:
|
||||
image = image.transpose((1, 2, 0))
|
||||
axes = list(range(image.ndim - 3)) + [image.ndim - 2, image.ndim - 1, image.ndim - 3]
|
||||
image = image.transpose(axes)
|
||||
else:
|
||||
raise ValueError(f"Unsupported channel dimension format: {channel_dim}")
|
||||
|
||||
|
@ -15,11 +15,9 @@
|
||||
import base64
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from contextlib import redirect_stdout
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import Callable, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@ -27,9 +25,6 @@ from packaging import version
|
||||
|
||||
from .utils import (
|
||||
ExplicitEnum,
|
||||
is_av_available,
|
||||
is_cv2_available,
|
||||
is_decord_available,
|
||||
is_jax_tensor,
|
||||
is_numpy_array,
|
||||
is_tf_tensor,
|
||||
@ -37,7 +32,6 @@ from .utils import (
|
||||
is_torch_tensor,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
is_yt_dlp_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
to_numpy,
|
||||
@ -62,7 +56,6 @@ if is_vision_available():
|
||||
PILImageResampling = PIL.Image
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import io as torchvision_io
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
pil_torch_interpolation_mapping = {
|
||||
@ -89,18 +82,6 @@ ImageInput = Union[
|
||||
] # noqa
|
||||
|
||||
|
||||
VideoInput = Union[
|
||||
list["PIL.Image.Image"],
|
||||
"np.ndarray",
|
||||
"torch.Tensor",
|
||||
list["np.ndarray"],
|
||||
list["torch.Tensor"],
|
||||
list[list["PIL.Image.Image"]],
|
||||
list[list["np.ndarray"]],
|
||||
list[list["torch.Tensor"]],
|
||||
] # noqa
|
||||
|
||||
|
||||
class ChannelDimension(ExplicitEnum):
|
||||
FIRST = "channels_first"
|
||||
LAST = "channels_last"
|
||||
@ -116,14 +97,6 @@ class AnnotionFormat(ExplicitEnum):
|
||||
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoMetadata:
|
||||
total_num_frames: int
|
||||
fps: float
|
||||
duration: float
|
||||
video_backend: str
|
||||
|
||||
|
||||
AnnotationType = dict[str, Union[int, str, list[dict]]]
|
||||
|
||||
|
||||
@ -309,37 +282,6 @@ def make_nested_list_of_images(
|
||||
raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
|
||||
|
||||
|
||||
def make_batched_videos(videos) -> VideoInput:
|
||||
"""
|
||||
Ensure that the input is a list of videos.
|
||||
Args:
|
||||
videos (`VideoInput`):
|
||||
Video or videos to turn into a list of videos.
|
||||
Returns:
|
||||
list: A list of videos.
|
||||
"""
|
||||
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
||||
# case 1: nested batch of videos so we flatten it
|
||||
if not is_pil_image(videos[0][0]) and videos[0][0].ndim == 4:
|
||||
videos = [[video for batch_list in batched_videos for video in batch_list] for batched_videos in videos]
|
||||
# case 2: list of videos represented as list of video frames
|
||||
return videos
|
||||
|
||||
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
||||
if is_pil_image(videos[0]) or videos[0].ndim == 3:
|
||||
return [videos]
|
||||
elif videos[0].ndim == 4:
|
||||
return [list(video) for video in videos]
|
||||
|
||||
elif is_valid_image(videos):
|
||||
if is_pil_image(videos) or videos.ndim == 3:
|
||||
return [[videos]]
|
||||
elif videos.ndim == 4:
|
||||
return [list(videos)]
|
||||
|
||||
raise ValueError(f"Could not make batched video from {videos}")
|
||||
|
||||
|
||||
def to_numpy_array(img) -> np.ndarray:
|
||||
if not is_valid_image(img):
|
||||
raise ValueError(f"Invalid image type: {type(img)}")
|
||||
@ -371,6 +313,8 @@ def infer_channel_dimension_format(
|
||||
first_dim, last_dim = 0, 2
|
||||
elif image.ndim == 4:
|
||||
first_dim, last_dim = 1, 3
|
||||
elif image.ndim == 5:
|
||||
first_dim, last_dim = 2, 4
|
||||
else:
|
||||
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
|
||||
|
||||
@ -548,348 +492,6 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
|
||||
return image
|
||||
|
||||
|
||||
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
|
||||
"""
|
||||
A default sampling function that replicates the logic used in get_uniform_frame_indices,
|
||||
while optionally handling `fps` if `num_frames` is not provided.
|
||||
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
`VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly.
|
||||
fps (`int`, *optional*):
|
||||
Desired frames per second. Takes priority over num_frames if both are provided.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: Array of frame indices to sample.
|
||||
"""
|
||||
total_num_frames = metadata.total_num_frames
|
||||
video_fps = metadata.fps
|
||||
|
||||
# If num_frames is not given but fps is, calculate num_frames from fps
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
|
||||
f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
|
||||
)
|
||||
|
||||
if num_frames is not None:
|
||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
|
||||
else:
|
||||
indices = np.arange(0, total_num_frames, dtype=int)
|
||||
return indices
|
||||
|
||||
|
||||
def read_video_opencv(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode a video using the OpenCV backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import cv2
|
||||
requires_backends(read_video_opencv, ["cv2"])
|
||||
import cv2
|
||||
|
||||
video = cv2.VideoCapture(video_path)
|
||||
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
video_fps = video.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv"
|
||||
)
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
index = 0
|
||||
frames = []
|
||||
while video.isOpened():
|
||||
success, frame = video.read()
|
||||
if not success:
|
||||
break
|
||||
if index in indices:
|
||||
height, width, channel = frame.shape
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frames.append(frame[0:height, 0:width, 0:channel])
|
||||
if success:
|
||||
index += 1
|
||||
if index >= total_num_frames:
|
||||
break
|
||||
|
||||
video.release()
|
||||
metadata.frames_indices = indices
|
||||
return np.stack(frames), metadata
|
||||
|
||||
|
||||
def read_video_decord(
|
||||
video_path: str,
|
||||
sample_indices_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode a video using the Decord backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import from decord
|
||||
requires_backends(read_video_decord, ["decord"])
|
||||
from decord import VideoReader, cpu
|
||||
|
||||
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
|
||||
video_fps = vr.get_avg_fps()
|
||||
total_num_frames = len(vr)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
|
||||
)
|
||||
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
frames = vr.get_batch(indices).asnumpy()
|
||||
metadata.frames_indices = indices
|
||||
return frames, metadata
|
||||
|
||||
|
||||
def read_video_pyav(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with PyAV decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import av
|
||||
requires_backends(read_video_pyav, ["av"])
|
||||
import av
|
||||
|
||||
container = av.open(video_path)
|
||||
total_num_frames = container.streams.video[0].frames
|
||||
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
|
||||
)
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
frames = []
|
||||
container.seek(0)
|
||||
end_index = indices[-1]
|
||||
for i, frame in enumerate(container.decode(video=0)):
|
||||
if i > end_index:
|
||||
break
|
||||
if i >= 0 and i in indices:
|
||||
frames.append(frame)
|
||||
|
||||
video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
metadata.frames_indices = indices
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_torchvision(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with torchvision decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
video, _, info = torchvision_io.read_video(
|
||||
video_path,
|
||||
start_pts=0.0,
|
||||
end_pts=None,
|
||||
pts_unit="sec",
|
||||
output_format="THWC",
|
||||
)
|
||||
video_fps = info["video_fps"]
|
||||
total_num_frames = video.size(0)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames),
|
||||
fps=float(video_fps),
|
||||
duration=float(duration),
|
||||
video_backend="torchvision",
|
||||
)
|
||||
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
video = video[indices].contiguous().numpy()
|
||||
metadata.frames_indices = indices
|
||||
return video, metadata
|
||||
|
||||
|
||||
VIDEO_DECODERS = {
|
||||
"decord": read_video_decord,
|
||||
"opencv": read_video_opencv,
|
||||
"pyav": read_video_pyav,
|
||||
"torchvision": read_video_torchvision,
|
||||
}
|
||||
|
||||
|
||||
def load_video(
|
||||
video: Union[str, "VideoInput"],
|
||||
num_frames: Optional[int] = None,
|
||||
fps: Optional[int] = None,
|
||||
backend: str = "opencv",
|
||||
sample_indices_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
) -> np.array:
|
||||
"""
|
||||
Loads `video` to a numpy array.
|
||||
|
||||
Args:
|
||||
video (`str` or `VideoInput`):
|
||||
The video to convert to the numpy array format. Can be a link to video or local path.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
backend (`str`, *optional*, defaults to `"opencv"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
|
||||
The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
|
||||
indices at which the video should be sampled. For example:
|
||||
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, Dict]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- Metadata dictionary.
|
||||
"""
|
||||
|
||||
# If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
|
||||
if fps is not None and num_frames is not None and sample_indices_fn is None:
|
||||
raise ValueError(
|
||||
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
|
||||
)
|
||||
|
||||
# If user didn't pass a sampling function, create one on the fly with default logic
|
||||
if sample_indices_fn is None:
|
||||
|
||||
def sample_indices_fn_func(metadata, **fn_kwargs):
|
||||
return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
|
||||
|
||||
sample_indices_fn = sample_indices_fn_func
|
||||
|
||||
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
||||
if not is_yt_dlp_available():
|
||||
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
||||
# Lazy import from yt_dlp
|
||||
requires_backends(load_video, ["yt_dlp"])
|
||||
from yt_dlp import YoutubeDL
|
||||
|
||||
buffer = BytesIO()
|
||||
with redirect_stdout(buffer), YoutubeDL() as f:
|
||||
f.download([video])
|
||||
bytes_obj = buffer.getvalue()
|
||||
file_obj = BytesIO(bytes_obj)
|
||||
elif video.startswith("http://") or video.startswith("https://"):
|
||||
file_obj = BytesIO(requests.get(video).content)
|
||||
elif os.path.isfile(video):
|
||||
file_obj = video
|
||||
elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])):
|
||||
file_obj = None
|
||||
else:
|
||||
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
|
||||
|
||||
# can also load with decord, but not cv2/torchvision
|
||||
# both will fail in case of url links
|
||||
video_is_url = video.startswith("http://") or video.startswith("https://")
|
||||
if video_is_url and backend in ["opencv", "torchvision"]:
|
||||
raise ValueError(
|
||||
"If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
|
||||
)
|
||||
|
||||
if file_obj is None:
|
||||
return video
|
||||
|
||||
if (
|
||||
(not is_decord_available() and backend == "decord")
|
||||
or (not is_av_available() and backend == "pyav")
|
||||
or (not is_cv2_available() and backend == "opencv")
|
||||
or (not is_torchvision_available() and backend == "torchvision")
|
||||
):
|
||||
raise ImportError(
|
||||
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
||||
f"Make sure to install {backend} before loading the video."
|
||||
)
|
||||
|
||||
video_decoder = VIDEO_DECODERS[backend]
|
||||
video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
|
||||
return video, metadata
|
||||
|
||||
|
||||
def load_images(
|
||||
images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
|
||||
) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
|
||||
|
@ -124,7 +124,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
|
||||
|
||||
class BitLinear(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool,
|
||||
device=None,
|
||||
dtype=None,
|
||||
use_rms_norm: bool = False,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.in_features = in_features
|
||||
@ -150,6 +159,13 @@ class BitLinear(nn.Module):
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
# Optional RMSNorm (applied on the activations before quantization).
|
||||
self.rms_norm = None
|
||||
if use_rms_norm:
|
||||
from ..models.llama.modeling_llama import LlamaRMSNorm
|
||||
|
||||
self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
|
||||
|
||||
@torch.compile
|
||||
def activation_quant(self, input, num_bits=8):
|
||||
"""
|
||||
@ -180,6 +196,10 @@ class BitLinear(nn.Module):
|
||||
return out
|
||||
|
||||
def forward(self, input):
|
||||
# Apply RMSNorm on the input if requested.
|
||||
if self.rms_norm is not None:
|
||||
input = self.rms_norm(input)
|
||||
|
||||
w = self.weight
|
||||
w_quant = unpack_weights(w, dtype=self.dtype)
|
||||
input_quant, input_scale = self.activation_quant(input)
|
||||
@ -245,9 +265,17 @@ class AutoBitLinear(nn.Linear):
|
||||
device=None,
|
||||
dtype=None,
|
||||
online_quant: bool = False,
|
||||
use_rms_norm: bool = False,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
):
|
||||
super().__init__(in_features, out_features, bias)
|
||||
self.online_quant = online_quant
|
||||
# Optional RMSNorm
|
||||
self.rms_norm = None
|
||||
if use_rms_norm:
|
||||
from ..models.llama.modeling_llama import LlamaRMSNorm
|
||||
|
||||
self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
|
||||
if not online_quant:
|
||||
self.register_buffer(
|
||||
"weight_scale",
|
||||
@ -271,6 +299,10 @@ class AutoBitLinear(nn.Linear):
|
||||
return state_dict
|
||||
|
||||
def forward(self, input):
|
||||
# Optional RMSNorm on activations prior to quantization.
|
||||
if self.rms_norm is not None:
|
||||
input = self.rms_norm(input)
|
||||
|
||||
if self.online_quant:
|
||||
weight = WeightQuant.apply(self.weight)
|
||||
else:
|
||||
@ -318,6 +350,8 @@ def _replace_with_bitnet_linear(
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
online_quant=(quantization_config.quantization_mode == "online"),
|
||||
use_rms_norm=quantization_config.use_rms_norm,
|
||||
rms_norm_eps=quantization_config.rms_norm_eps,
|
||||
)
|
||||
if quantization_config.quantization_mode == "offline":
|
||||
model._modules[name].requires_grad_(False)
|
||||
@ -328,6 +362,8 @@ def _replace_with_bitnet_linear(
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
use_rms_norm=quantization_config.use_rms_norm,
|
||||
rms_norm_eps=quantization_config.rms_norm_eps,
|
||||
)
|
||||
model._modules[name].requires_grad_(False)
|
||||
has_been_replaced = True
|
||||
@ -363,7 +399,7 @@ def replace_with_bitnet_linear(
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
|
||||
Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
|
||||
Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
|
||||
for numerical stability reasons.
|
||||
current_key_name (`List[`str`]`, *optional*):
|
||||
An array to track the current key of the recursion. This is used to check whether the current key (part of
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -332,8 +332,10 @@ class FP8Linear(nn.Linear):
|
||||
if self.weight.element_size() > 1:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
else:
|
||||
# Context manager used to switch among the available cuda devices
|
||||
with torch.cuda.device(input.device):
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
with torch_accelerator_module.device(input.device):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
@ -343,9 +345,9 @@ class FP8Linear(nn.Linear):
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# preceding operations are ready before proceeding
|
||||
torch.cuda.synchronize()
|
||||
torch_accelerator_module.synchronize()
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
return output.to(dtype=input.dtype)
|
||||
|
@ -15,6 +15,7 @@
|
||||
Integrations with other Python libraries.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
@ -33,7 +34,7 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
|
||||
from .. import PreTrainedModel, TFPreTrainedModel
|
||||
from .. import PreTrainedModel, TFPreTrainedModel, TrainingArguments
|
||||
from .. import __version__ as version
|
||||
from ..utils import (
|
||||
PushToHubMixin,
|
||||
@ -929,13 +930,17 @@ class WandbCallback(TrainerCallback):
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model, **kwargs)
|
||||
|
||||
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
|
||||
def on_train_end(self, args: TrainingArguments, state, control, model=None, processing_class=None, **kwargs):
|
||||
if self._wandb is None:
|
||||
return
|
||||
if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
|
||||
from ..trainer import Trainer
|
||||
|
||||
fake_trainer = Trainer(args=args, model=model, processing_class=processing_class, eval_dataset=["fake"])
|
||||
args_for_fake = copy.deepcopy(args)
|
||||
args_for_fake.deepspeed = None
|
||||
fake_trainer = Trainer(
|
||||
args=args_for_fake, model=model, processing_class=processing_class, eval_dataset=["fake"]
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
fake_trainer.save_model(temp_dir)
|
||||
metadata = (
|
||||
|
@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
|
||||
return [single_size] * blocks
|
||||
|
||||
|
||||
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Get the TP style for a parameter from the TP plan.
|
||||
|
||||
The TP plan is a dictionary that maps parameter names to TP styles.
|
||||
The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
|
||||
"""
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
return tp_plan[generic_param_name]
|
||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
|
||||
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
str_to_torch_dtype = {
|
||||
"BOOL": torch.bool,
|
||||
"U8": torch.uint8,
|
||||
@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
|
||||
return tensor.to(str_to_torch_dtype[slice_dtype])
|
||||
|
||||
|
||||
def repack_weights(
|
||||
packed_parameter: torch.Tensor,
|
||||
sharded_dim: int, # The dimension index in the global tensor that was sharded
|
||||
world_size: int,
|
||||
num_blocks: int = 2,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
|
||||
|
||||
For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
|
||||
DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
|
||||
along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
|
||||
This is an inverse operation to get_packed_weights.
|
||||
|
||||
Args:
|
||||
reconstructed_tensor: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
|
||||
sharded_dim: The dimension index in the reconstructed_tensor that was originally sharded.
|
||||
world_size: The tensor parallel world size.
|
||||
num_packed_projs: The number of projections that were packed together (e.g., 2 for gate_up_proj).
|
||||
|
||||
Returns:
|
||||
The reordered tensor in canonical packed format.
|
||||
"""
|
||||
|
||||
if num_blocks != 2:
|
||||
raise ValueError(
|
||||
"Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation as we only pack gate and up projections together."
|
||||
)
|
||||
|
||||
actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
|
||||
total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
|
||||
original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
|
||||
shard_chunk_size = original_block_size_on_dim // world_size
|
||||
|
||||
prefix_shape = packed_parameter.shape[:actual_sharded_dim]
|
||||
suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
|
||||
|
||||
tensor_view = packed_parameter.view(
|
||||
*prefix_shape,
|
||||
world_size,
|
||||
num_blocks,
|
||||
shard_chunk_size,
|
||||
*suffix_shape,
|
||||
)
|
||||
|
||||
# Permute to bring num_packed_projs first, then world_size, then shard_chunk_size
|
||||
# This groups all chunks of G together, then all chunks of U together.
|
||||
# Target order of these middle dimensions: (num_packed_projs, world_size, shard_chunk_size)
|
||||
# Current order of view's middle dimensions: (world_size, num_packed_projs, shard_chunk_size)
|
||||
# Absolute indices of the dimensions to be permuted (world_size, num_packed_projs)
|
||||
axis_ws_abs = len(prefix_shape)
|
||||
axis_npp_abs = len(prefix_shape) + 1
|
||||
|
||||
permute_order = list(range(tensor_view.ndim))
|
||||
permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
|
||||
|
||||
tensor_permuted = tensor_view.permute(*permute_order)
|
||||
|
||||
# Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
|
||||
# The final shape should be the same as reconstructed_tensor.
|
||||
final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
|
||||
|
||||
return final_ordered_tensor
|
||||
|
||||
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
if dim == 0:
|
||||
size_ = empty_param.shape[0]
|
||||
@ -578,6 +659,49 @@ def translate_to_torch_parallel_style(style: str):
|
||||
raise ValueError(f"Unsupported parallel style value: {style}")
|
||||
|
||||
|
||||
def convert_local_tensor_to_dtensor(
|
||||
parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
|
||||
) -> DTensor:
|
||||
"""
|
||||
Converts a local variant of weights to a DTensor with corresponding placements. Shouldn't be done ever except of before saving the model.
|
||||
"""
|
||||
_, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
||||
tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
|
||||
if not tp_style:
|
||||
return parameter
|
||||
|
||||
if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
|
||||
return parameter
|
||||
# TODO: this logic should be wrapped in a function, this is copied from corresponding tp classes.
|
||||
if tp_style == "local_packed_rowwise":
|
||||
placements = [Shard(-1)]
|
||||
elif tp_style == "local_rowwise":
|
||||
if param_type == "bias":
|
||||
placements = [Replicate()]
|
||||
else:
|
||||
placements = [Shard(-1)]
|
||||
elif tp_style == "local_colwise":
|
||||
if param_type == "bias":
|
||||
placements = [Shard(-1)]
|
||||
else:
|
||||
placements = [Shard(-2)]
|
||||
return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
|
||||
|
||||
|
||||
def replace_state_dict_local_with_dtensor(
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
tp_plan: dict[str, str],
|
||||
device_mesh,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Replaces all tensors that were sharded with `local_*` strategy with DTensor to make determining their proper size possible.
|
||||
"""
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
|
||||
state_dict[key] = convert_local_tensor_to_dtensor(value, key, device_mesh, tp_plan)
|
||||
return state_dict
|
||||
|
||||
|
||||
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
|
||||
"""
|
||||
Add hooks to the module holding the layer. Meaning:
|
||||
@ -632,13 +756,9 @@ def shard_and_distribute_module(
|
||||
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
|
||||
tp_plan = model._tp_plan
|
||||
module_to_tp = model.get_submodule(param_name)
|
||||
current_module_plan = None
|
||||
rank = int(rank)
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
current_module_plan = tp_plan[generic_param_name]
|
||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
|
||||
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||
|
||||
current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
|
||||
|
||||
# Add hooks to the module if not done yet
|
||||
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
|
||||
@ -670,3 +790,34 @@ def shard_and_distribute_module(
|
||||
setattr(module_to_tp, param_type, param)
|
||||
# module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
|
||||
return param
|
||||
|
||||
|
||||
def verify_tp_plan(expected_keys: list[str], tp_plan: Optional[dict[str, str]]):
|
||||
"""
|
||||
Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied.
|
||||
"""
|
||||
|
||||
if tp_plan is None:
|
||||
return
|
||||
|
||||
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
|
||||
unsharded_layers = set(generic_keys)
|
||||
unused_rules = tp_plan
|
||||
|
||||
for key in generic_keys:
|
||||
param_name, _ = key.rsplit(".", 1) if "." in key else key
|
||||
generic_param_name = re.sub(r"\d+", "*", param_name)
|
||||
|
||||
if generic_param_name in tp_plan:
|
||||
unused_rules.pop(generic_param_name)
|
||||
unsharded_layers.discard(key)
|
||||
elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
|
||||
unused_rules.pop(parent_param_name)
|
||||
unsharded_layers.discard(key)
|
||||
else:
|
||||
pass # we couldn't find the rule for this parameter, so it's not sharded
|
||||
|
||||
if len(unused_rules) > 0:
|
||||
logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
|
||||
if len(unsharded_layers) > 0:
|
||||
logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
|
||||
|
@ -66,7 +66,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation"
|
||||
" instructions."
|
||||
)
|
||||
raise
|
||||
@ -360,7 +360,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation"
|
||||
" instructions."
|
||||
)
|
||||
raise
|
||||
|
@ -18,6 +18,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .cache_utils import Cache, EncoderDecoderCache
|
||||
from .utils import ModelOutput
|
||||
|
||||
|
||||
@ -131,11 +132,8 @@ class BaseModelOutputWithPast(ModelOutput):
|
||||
|
||||
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
|
||||
hidden_size)` is output.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
||||
@ -154,7 +152,7 @@ class BaseModelOutputWithPast(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
@ -222,11 +220,8 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
|
||||
|
||||
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
||||
@ -236,7 +231,7 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
pooler_output: Optional[torch.FloatTensor] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
@ -252,11 +247,8 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||
|
||||
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
|
||||
hidden_size)` is output.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
||||
@ -281,7 +273,7 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -298,9 +290,8 @@ class MoECausalLMOutputWithPast(ModelOutput):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
@ -328,7 +319,7 @@ class MoECausalLMOutputWithPast(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
z_loss: Optional[torch.FloatTensor] = None
|
||||
@ -376,11 +367,8 @@ class MoeModelOutputWithPast(ModelOutput):
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
||||
@ -404,7 +392,7 @@ class MoeModelOutputWithPast(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
router_logits: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@ -431,9 +419,8 @@ class MoeCausalLMOutputWithPast(ModelOutput):
|
||||
Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
|
||||
loss for Mixture of Experts models.
|
||||
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
@ -453,7 +440,7 @@ class MoeCausalLMOutputWithPast(ModelOutput):
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
aux_loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
router_logits: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@ -471,11 +458,8 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||
|
||||
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
|
||||
hidden_size)` is output.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
|
||||
@ -505,7 +489,7 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -524,10 +508,8 @@ class Seq2SeqModelOutput(ModelOutput):
|
||||
|
||||
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
|
||||
hidden_size)` is output.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -564,7 +546,7 @@ class Seq2SeqModelOutput(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -585,10 +567,8 @@ class Seq2SeqMoEModelOutput(ModelOutput):
|
||||
|
||||
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
|
||||
hidden_size)` is output.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -634,7 +614,7 @@ class Seq2SeqMoEModelOutput(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@ -684,9 +664,8 @@ class CausalLMOutputWithPast(ModelOutput):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
@ -705,7 +684,7 @@ class CausalLMOutputWithPast(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
@ -737,10 +716,8 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
|
||||
|
||||
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
cross-attention heads.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
|
||||
value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
|
||||
setting. Only relevant if `config.is_decoder = True`.
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
@ -748,7 +725,7 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -764,9 +741,8 @@ class SequenceClassifierOutputWithPast(ModelOutput):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
@ -785,7 +761,7 @@ class SequenceClassifierOutputWithPast(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[Cache] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
@ -829,10 +805,8 @@ class Seq2SeqLMOutput(ModelOutput):
|
||||
Language modeling loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -870,7 +844,7 @@ class Seq2SeqLMOutput(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -889,10 +863,8 @@ class Seq2SeqMoEOutput(ModelOutput):
|
||||
Language modeling loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -943,7 +915,7 @@ class Seq2SeqMoEOutput(ModelOutput):
|
||||
decoder_z_loss: Optional[torch.FloatTensor] = None
|
||||
encoder_aux_loss: Optional[torch.FloatTensor] = None
|
||||
decoder_aux_loss: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@ -1023,10 +995,8 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -1064,7 +1034,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -1177,10 +1147,8 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
Span-start scores (before SoftMax).
|
||||
end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
||||
Span-end scores (before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -1219,7 +1187,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
start_logits: Optional[torch.FloatTensor] = None
|
||||
end_logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -1508,10 +1476,8 @@ class Seq2SeqSpectrogramOutput(ModelOutput):
|
||||
Spectrogram generation loss.
|
||||
spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
|
||||
The predicted spectrogram.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -1549,7 +1515,7 @@ class Seq2SeqSpectrogramOutput(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
spectrogram: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -1570,10 +1536,8 @@ class Seq2SeqTSModelOutput(ModelOutput):
|
||||
|
||||
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
|
||||
hidden_size)` is output.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -1618,7 +1582,7 @@ class Seq2SeqTSModelOutput(ModelOutput):
|
||||
"""
|
||||
|
||||
last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
@ -1641,10 +1605,8 @@ class Seq2SeqTSPredictionOutput(ModelOutput):
|
||||
Distributional loss.
|
||||
params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`):
|
||||
Parameters of the chosen distribution.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
@ -1690,7 +1652,7 @@ class Seq2SeqTSPredictionOutput(ModelOutput):
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
params: Optional[Tuple[torch.FloatTensor]] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
past_key_values: Optional[EncoderDecoderCache] = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
@ -585,8 +585,8 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_
|
||||
loaded_pt_weights_data_ptr = {}
|
||||
missing_keys_pt = []
|
||||
for pt_weight_name, pt_weight in current_pt_params_dict.items():
|
||||
# Handle PyTorch shared weight ()not duplicated in TF 2.0
|
||||
if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
|
||||
# Handle PyTorch shared weight not duplicated in TF 2.0
|
||||
if pt_weight.data_ptr() in loaded_pt_weights_data_ptr and pt_weight.data_ptr() != 0:
|
||||
new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
|
||||
continue
|
||||
|
||||
|
@ -63,7 +63,11 @@ from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
_get_parameter_tp_plan,
|
||||
repack_weights,
|
||||
replace_state_dict_local_with_dtensor,
|
||||
shard_and_distribute_module,
|
||||
verify_tp_plan,
|
||||
)
|
||||
from .loss.loss_utils import LOSS_MAPPING
|
||||
from .pytorch_utils import ( # noqa: F401
|
||||
@ -122,6 +126,7 @@ from .utils import (
|
||||
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
|
||||
from .utils.import_utils import (
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
is_huggingface_hub_greater_or_equal,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_fx_proxy,
|
||||
is_torchdynamo_compiling,
|
||||
@ -167,6 +172,9 @@ _is_quantized = False
|
||||
_is_ds_init_called = False
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
|
||||
if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
|
||||
def is_fsdp_enabled():
|
||||
return (
|
||||
@ -219,22 +227,19 @@ TORCH_INIT_FUNCTIONS = {
|
||||
# DO NOT MODIFY, KEPT FOR BC ONLY
|
||||
VLMS = [
|
||||
"aria",
|
||||
"aya_vision",
|
||||
"ayavision",
|
||||
"emu3",
|
||||
"fuyu",
|
||||
"got_ocr2",
|
||||
"gotocr2",
|
||||
"gemma3",
|
||||
"internvl",
|
||||
"llava",
|
||||
"llava_next",
|
||||
"llava_next_video",
|
||||
"llava_onevision",
|
||||
"llava", # all llava prefixed models fall under this check
|
||||
"mistral3",
|
||||
"mllama",
|
||||
"paligemma",
|
||||
"qwen2_vl",
|
||||
"qwem2_5_vl",
|
||||
"video_llava",
|
||||
"qwen2vl",
|
||||
"qwen2_5_vl",
|
||||
"videollava",
|
||||
"vipllava",
|
||||
]
|
||||
|
||||
@ -883,6 +888,7 @@ def _get_resolved_checkpoint_files(
|
||||
user_agent: dict,
|
||||
revision: str,
|
||||
commit_hash: Optional[str],
|
||||
transformers_explicit_filename: Optional[str] = None,
|
||||
) -> Tuple[Optional[List[str]], Optional[Dict]]:
|
||||
"""Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
|
||||
checkpoints are sharded.
|
||||
@ -894,7 +900,11 @@ def _get_resolved_checkpoint_files(
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if is_local:
|
||||
if from_tf and os.path.isfile(
|
||||
if transformers_explicit_filename is not None:
|
||||
# If the filename is explicitly defined, load this by default.
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
|
||||
is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
|
||||
elif from_tf and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||
):
|
||||
# Load from a TF 1.0 checkpoint in priority if from_tf
|
||||
@ -982,7 +992,10 @@ def _get_resolved_checkpoint_files(
|
||||
resolved_archive_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
# set correct filename
|
||||
if from_tf:
|
||||
if transformers_explicit_filename is not None:
|
||||
filename = transformers_explicit_filename
|
||||
is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
|
||||
elif from_tf:
|
||||
filename = TF2_WEIGHTS_NAME
|
||||
elif from_flax:
|
||||
filename = FLAX_WEIGHTS_NAME
|
||||
@ -3407,6 +3420,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
if safe_serialization and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
||||
|
||||
# we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
|
||||
if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
|
||||
raise ImportError(
|
||||
"Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
|
||||
)
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
@ -3534,6 +3553,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
|
||||
# (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
|
||||
state_dict = self._fix_state_dict_keys_on_save(state_dict)
|
||||
# If model was sharded, we cannot properly determine sizes of tensors that `local_*` strategy was used,
|
||||
# therefore we replace them with DTensors that are equivalently sharded
|
||||
if self._tp_size is not None:
|
||||
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
|
||||
|
||||
if safe_serialization:
|
||||
# Safetensors does not allow tensor aliasing.
|
||||
@ -3542,7 +3565,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
for name, tensor in state_dict.items():
|
||||
# Sometimes in the state_dict we have non-tensor objects.
|
||||
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor):
|
||||
ptrs[id_tensor_storage(tensor)].append(name)
|
||||
else:
|
||||
# In the non-tensor case, fall back to the pointer of the object itself
|
||||
@ -3652,7 +3675,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {}
|
||||
for tensor in tensors:
|
||||
shard[tensor] = state_dict[tensor].contiguous()
|
||||
if isinstance(state_dict[tensor], DTensor):
|
||||
full_tensor = state_dict[tensor].full_tensor()
|
||||
# to get the correctly ordered tensor we need to repack if packed
|
||||
if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
|
||||
full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
|
||||
shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly
|
||||
else:
|
||||
shard[tensor] = state_dict[tensor].contiguous()
|
||||
# delete reference, see https://github.com/huggingface/transformers/pull/34890
|
||||
del state_dict[tensor]
|
||||
|
||||
@ -4107,6 +4137,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
|
||||
@ -4116,7 +4147,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
# Not used anymore -- remove them from the kwargs
|
||||
_ = kwargs.pop("resume_download", None)
|
||||
_ = kwargs.pop("trust_remote_code", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
_ = kwargs.pop("_fast_init", True)
|
||||
_ = kwargs.pop("low_cpu_mem_usage", None)
|
||||
@ -4364,6 +4394,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
model_kwargs = kwargs
|
||||
|
||||
transformers_explicit_filename = getattr(config, "transformers_weights", None)
|
||||
|
||||
if transformers_explicit_filename is not None:
|
||||
if not transformers_explicit_filename.endswith(
|
||||
".safetensors"
|
||||
) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
|
||||
raise ValueError(
|
||||
"The transformers file in the config seems to be incorrect: it is neither a safetensors file "
|
||||
"(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
|
||||
f"{transformers_explicit_filename}"
|
||||
)
|
||||
|
||||
pre_quantized = hasattr(config, "quantization_config")
|
||||
if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
|
||||
pre_quantized = False
|
||||
@ -4432,6 +4474,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
commit_hash=commit_hash,
|
||||
transformers_explicit_filename=transformers_explicit_filename,
|
||||
)
|
||||
|
||||
is_sharded = sharded_metadata is not None
|
||||
@ -4587,6 +4630,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
# record tp degree the model sharded to
|
||||
model._tp_size = tp_size
|
||||
model._device_mesh = device_mesh
|
||||
|
||||
# make sure token embedding weights are still tied if needed
|
||||
model.tie_weights()
|
||||
@ -4594,30 +4638,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
# If it is a model with generation capabilities, attempt to load the generation config
|
||||
# If it is a model with generation capabilities, attempt to load generation files (generation config,
|
||||
# custom generate function)
|
||||
if model.can_generate() and generation_config is not None:
|
||||
logger.info("The user-defined `generation_config` will be used to override the default generation config.")
|
||||
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
|
||||
elif model.can_generate() and pretrained_model_name_or_path is not None:
|
||||
repo_loading_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"local_files_only": local_files_only,
|
||||
"token": token,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
**kwargs,
|
||||
}
|
||||
# Load generation config
|
||||
try:
|
||||
model.generation_config = GenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
**kwargs,
|
||||
**repo_loading_kwargs,
|
||||
)
|
||||
except OSError:
|
||||
logger.info(
|
||||
"Generation config file not found, using a generation config created from the model config."
|
||||
)
|
||||
pass
|
||||
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
|
||||
if hasattr(model, "load_custom_generate"):
|
||||
try:
|
||||
custom_generate = model.load_custom_generate(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
|
||||
)
|
||||
model.generate = functools.partial(custom_generate, model=model)
|
||||
except OSError: # there is no custom generate function
|
||||
pass
|
||||
|
||||
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
|
||||
# harm performances)
|
||||
@ -4963,6 +5021,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
if hf_quantizer is not None:
|
||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||
|
||||
if logger.level >= logging.WARNING:
|
||||
verify_tp_plan(expected_keys, getattr(model_to_load, "_tp_plan", None))
|
||||
|
||||
# Warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None and not is_hqq_or_quark:
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
|
@ -367,15 +367,15 @@ class AlbertSdpaAttention(AlbertAttention):
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
|
||||
if self.position_embedding_type != "absolute" or output_attentions:
|
||||
logger.warning(
|
||||
"AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
|
||||
"non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
|
||||
"the eager attention implementation, but specifying the eager implementation will be required from "
|
||||
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
|
||||
'`attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(hidden_states, attention_mask, head_mask, output_attentions)
|
||||
return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
|
||||
|
||||
batch_size, seq_len, _ = hidden_states.size()
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
@ -1197,7 +1197,7 @@ class AriaModel(AriaPreTrainedModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_mask: torch.FloatTensor = None,
|
||||
pixel_mask: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: int = -1,
|
||||
):
|
||||
"""
|
||||
@ -1208,13 +1208,16 @@ class AriaModel(AriaPreTrainedModel):
|
||||
The tensors corresponding to the input images.
|
||||
pixel_mask (`torch.FloatTensor]`, *optional*):
|
||||
The tensors corresponding to the input image mask.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
|
||||
image_outputs = self.vision_tower(
|
||||
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
|
||||
|
@ -1325,7 +1325,7 @@ class AriaModel(LlavaModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_mask: torch.FloatTensor = None,
|
||||
pixel_mask: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: int = -1,
|
||||
):
|
||||
"""
|
||||
@ -1336,13 +1336,16 @@ class AriaModel(LlavaModel):
|
||||
The tensors corresponding to the input images.
|
||||
pixel_mask (`torch.FloatTensor]`, *optional*):
|
||||
The tensors corresponding to the input image mask.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
|
||||
image_outputs = self.vision_tower(
|
||||
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
|
||||
|
@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_tf_auto import *
|
||||
from .processing_auto import *
|
||||
from .tokenization_auto import *
|
||||
from .video_processing_auto import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
@ -452,7 +452,7 @@ class _BaseAutoModelClass:
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
config = kwargs.pop("config", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
trust_remote_code = kwargs.get("trust_remote_code", None)
|
||||
kwargs["_from_auto"] = True
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
@ -531,7 +531,6 @@ class _BaseAutoModelClass:
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
return_unused_kwargs=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
code_revision=code_revision,
|
||||
_commit_hash=commit_hash,
|
||||
**hub_kwargs,
|
||||
@ -549,6 +548,7 @@ class _BaseAutoModelClass:
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
|
||||
)
|
||||
kwargs["trust_remote_code"] = trust_remote_code
|
||||
|
||||
# Set the adapter kwargs
|
||||
kwargs["adapter_kwargs"] = adapter_kwargs
|
||||
@ -730,13 +730,13 @@ def add_generation_mixin_to_remote_model(model_class):
|
||||
|
||||
# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
|
||||
# `prepare_inputs_for_generation` method.
|
||||
has_custom_generate = hasattr(model_class, "generate") and "GenerationMixin" not in str(
|
||||
has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str(
|
||||
getattr(model_class, "generate")
|
||||
)
|
||||
has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
|
||||
getattr(model_class, "prepare_inputs_for_generation")
|
||||
)
|
||||
if has_custom_generate or has_custom_prepare_inputs:
|
||||
if has_custom_generate_in_class or has_custom_prepare_inputs:
|
||||
model_class_with_generation_mixin = type(
|
||||
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
|
||||
)
|
||||
|
@ -128,7 +128,7 @@ else:
|
||||
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
|
||||
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
|
||||
("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
|
||||
("phi4_multimodal", "Phi4MultimodalImageProcessorFast"),
|
||||
("phi4_multimodal", ("Phi4MultimodalImageProcessorFast",)),
|
||||
("pix2struct", ("Pix2StructImageProcessor",)),
|
||||
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
|
||||
("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")),
|
||||
@ -161,7 +161,7 @@ else:
|
||||
("upernet", ("SegformerImageProcessor",)),
|
||||
("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("videomae", ("VideoMAEImageProcessor",)),
|
||||
("vilt", ("ViltImageProcessor",)),
|
||||
("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
|
||||
("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("vit_hybrid", ("ViTHybridImageProcessor",)),
|
||||
|
@ -177,6 +177,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("lilt", "LiltModel"),
|
||||
("llama", "LlamaModel"),
|
||||
("llama4", "Llama4ForConditionalGeneration"),
|
||||
("llama4_text", "Llama4TextModel"),
|
||||
("llava", "LlavaModel"),
|
||||
("llava_next", "LlavaNextModel"),
|
||||
("llava_next_video", "LlavaNextVideoModel"),
|
||||
@ -1520,11 +1521,6 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("sam", "SamModel"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("sam_hq", "SamHQModel"),
|
||||
]
|
||||
)
|
||||
|
@ -28,7 +28,14 @@ from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...image_processing_utils import ImageProcessingMixin
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
|
||||
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging
|
||||
from ...utils import (
|
||||
FEATURE_EXTRACTOR_NAME,
|
||||
PROCESSOR_NAME,
|
||||
VIDEO_PROCESSOR_NAME,
|
||||
cached_file,
|
||||
logging,
|
||||
)
|
||||
from ...video_processing_utils import BaseVideoProcessor
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
@ -295,14 +302,31 @@ class AutoProcessor:
|
||||
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
||||
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
||||
|
||||
# If not found, let's check whether the processor class is saved in a feature extractor config
|
||||
if preprocessor_config_file is not None and processor_class is None:
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(
|
||||
pretrained_model_name_or_path, **kwargs
|
||||
# Saved as video processor
|
||||
if preprocessor_config_file is None:
|
||||
preprocessor_config_file = cached_file(
|
||||
pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs
|
||||
)
|
||||
processor_class = config_dict.get("processor_class", None)
|
||||
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
||||
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
||||
if preprocessor_config_file is not None:
|
||||
config_dict, _ = BaseVideoProcessor.get_video_processor_dict(
|
||||
pretrained_model_name_or_path, **kwargs
|
||||
)
|
||||
processor_class = config_dict.get("processor_class", None)
|
||||
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
||||
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
||||
|
||||
# Saved as feature extractor
|
||||
if preprocessor_config_file is None:
|
||||
preprocessor_config_file = cached_file(
|
||||
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
|
||||
)
|
||||
if preprocessor_config_file is not None and processor_class is None:
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(
|
||||
pretrained_model_name_or_path, **kwargs
|
||||
)
|
||||
processor_class = config_dict.get("processor_class", None)
|
||||
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
||||
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
||||
|
||||
if processor_class is None:
|
||||
# Next, let's check whether the processor class is saved in a tokenizer
|
||||
|
382
src/transformers/models/auto/video_processing_auto.py
Normal file
382
src/transformers/models/auto/video_processing_auto.py
Normal file
@ -0,0 +1,382 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""AutoVideoProcessor class."""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
# Build the list of all video processors
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import BaseVideoProcessor
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
AutoConfig,
|
||||
model_type_to_module_name,
|
||||
replace_list_option_in_docstrings,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# This significantly improves completion suggestion performance when
|
||||
# the transformers package is used with Microsoft's Pylance language server.
|
||||
VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
|
||||
else:
|
||||
VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("instructblip", "InstructBlipVideoVideoProcessor"),
|
||||
("instructblipvideo", "InstructBlipVideoVideoProcessor"),
|
||||
("internvl", "InternVLVideoProcessor"),
|
||||
("llava_next_video", "LlavaNextVideoVideoProcessor"),
|
||||
("llava_onevision", "LlavaOnevisionVideoProcessor"),
|
||||
("qwen2_5_omni", "Qwen2VLVideoProcessor"),
|
||||
("qwen2_5_vl", "Qwen2VLVideoProcessor"),
|
||||
("qwen2_vl", "Qwen2VLVideoProcessor"),
|
||||
("smolvlm", "SmolVLMVideoProcessor"),
|
||||
("video_llava", "VideoLlavaVideoProcessor"),
|
||||
]
|
||||
)
|
||||
|
||||
for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
|
||||
fast_video_processor_class = video_processors
|
||||
|
||||
# If the torchvision is not available, we set it to None
|
||||
if not is_torchvision_available():
|
||||
fast_video_processor_class = None
|
||||
|
||||
VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class
|
||||
|
||||
VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES)
|
||||
|
||||
|
||||
def video_processor_class_from_name(class_name: str):
|
||||
for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
|
||||
if class_name in extractors:
|
||||
module_name = model_type_to_module_name(module_name)
|
||||
|
||||
module = importlib.import_module(f".{module_name}", "transformers.models")
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
for _, extractor in VIDEO_PROCESSOR_MAPPING._extra_content.items():
|
||||
if getattr(extractor, "__name__", None) == class_name:
|
||||
return extractor
|
||||
|
||||
# We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
|
||||
# init and we return the proper dummy to get an appropriate error message.
|
||||
main_module = importlib.import_module("transformers")
|
||||
if hasattr(main_module, class_name):
|
||||
return getattr(main_module, class_name)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_video_processor_config(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
resume_download: Optional[bool] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads the video processor configuration from a pretrained model video processor configuration.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
||||
huggingface.co.
|
||||
- a path to a *directory* containing a configuration file saved using the
|
||||
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
||||
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
||||
cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
||||
exist.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||
Will be removed in v5 of Transformers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, will only try to load the video processor configuration from local files.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
`Dict`: The configuration of the video processor.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Download configuration from huggingface.co and cache.
|
||||
video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
# This model does not have a video processor config so the result will be an empty dict.
|
||||
video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base")
|
||||
|
||||
# Save a pretrained video processor locally and you can reload its config
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
video_processor.save_pretrained("video-processor-test")
|
||||
video_processor = get_video_processor_config("video-processor-test")
|
||||
```"""
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if token is not None:
|
||||
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
||||
token = use_auth_token
|
||||
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
VIDEO_PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
if resolved_config_file is None:
|
||||
logger.info(
|
||||
"Could not locate the video processor configuration file, will try to use the model config instead."
|
||||
)
|
||||
return {}
|
||||
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
return json.load(reader)
|
||||
|
||||
|
||||
@requires(backends=("vision", "torchvision"))
|
||||
class AutoVideoProcessor:
|
||||
r"""
|
||||
This is a generic video processor class that will be instantiated as one of the video processor classes of the
|
||||
library when created with the [`AutoVideoProcessor.from_pretrained`] class method.
|
||||
|
||||
This class cannot be instantiated directly using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"AutoVideoProcessor is designed to be instantiated "
|
||||
"using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES)
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
r"""
|
||||
Instantiate one of the video processor classes of the library from a pretrained model vocabulary.
|
||||
|
||||
The video processor class to instantiate is selected based on the `model_type` property of the config object
|
||||
(either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
|
||||
missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
|
||||
|
||||
List options
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a pretrained video_processor hosted inside a model repo on
|
||||
huggingface.co.
|
||||
- a path to a *directory* containing a video processor file saved using the
|
||||
[`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g.,
|
||||
`./my_model_directory/`.
|
||||
- a path or url to a saved video processor JSON *file*, e.g.,
|
||||
`./my_model_directory/preprocessor_config.json`.
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model video processor should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the video processor files and override the cached versions if
|
||||
they exist.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||
Will be removed in v5 of Transformers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
If `False`, then this function returns just the final video processor object. If `True`, then this
|
||||
functions returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
|
||||
consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
|
||||
`kwargs` which has not been used to update `video_processor` and is otherwise ignored.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
||||
execute code present on the Hub on your local machine.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The values in kwargs of any keys which are video processor attributes will be used to override the
|
||||
loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
|
||||
controlled by the `return_unused_kwargs` keyword parameter.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoVideoProcessor
|
||||
|
||||
>>> # Download video processor from huggingface.co and cache.
|
||||
>>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
|
||||
>>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*)
|
||||
>>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/")
|
||||
```"""
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if kwargs.get("token", None) is not None:
|
||||
raise ValueError(
|
||||
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
||||
)
|
||||
kwargs["token"] = use_auth_token
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
kwargs["_from_auto"] = True
|
||||
|
||||
config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
video_processor_class = config_dict.get("video_processor_type", None)
|
||||
video_processor_auto_map = None
|
||||
if "AutoVideoProcessor" in config_dict.get("auto_map", {}):
|
||||
video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"]
|
||||
|
||||
# If we still don't have the video processor class, check if we're loading from a previous feature extractor config
|
||||
# and if so, infer the video processor class from there.
|
||||
if video_processor_class is None and video_processor_auto_map is None:
|
||||
feature_extractor_class = config_dict.pop("feature_extractor_type", None)
|
||||
if feature_extractor_class is not None:
|
||||
video_processor_class = feature_extractor_class.replace("FeatureExtractor", "VideoProcessor")
|
||||
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
|
||||
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
|
||||
video_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "VideoProcessor")
|
||||
|
||||
# If we don't find the video processor class in the video processor config, let's try the model config.
|
||||
if video_processor_class is None and video_processor_auto_map is None:
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
# It could be in `config.video_processor_type``
|
||||
video_processor_class = getattr(config, "video_processor_type", None)
|
||||
if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map:
|
||||
video_processor_auto_map = config.auto_map["AutoVideoProcessor"]
|
||||
|
||||
if video_processor_class is not None:
|
||||
video_processor_class = video_processor_class_from_name(video_processor_class)
|
||||
|
||||
has_remote_code = video_processor_auto_map is not None
|
||||
has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
|
||||
)
|
||||
|
||||
if has_remote_code and trust_remote_code:
|
||||
class_ref = video_processor_auto_map
|
||||
video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
|
||||
_ = kwargs.pop("code_revision", None)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
video_processor_class.register_for_auto_class()
|
||||
return video_processor_class.from_dict(config_dict, **kwargs)
|
||||
elif video_processor_class is not None:
|
||||
return video_processor_class.from_dict(config_dict, **kwargs)
|
||||
# Last try: we use the VIDEO_PROCESSOR_MAPPING.
|
||||
elif type(config) in VIDEO_PROCESSOR_MAPPING:
|
||||
video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)]
|
||||
|
||||
if video_processor_class is not None:
|
||||
return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"This video processor cannot be instantiated. Please make sure you have `torchvision` installed."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a "
|
||||
f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
|
||||
f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES.keys())}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register(
|
||||
config_class,
|
||||
video_processor_class,
|
||||
exist_ok=False,
|
||||
):
|
||||
"""
|
||||
Register a new video processor for this class.
|
||||
|
||||
Args:
|
||||
config_class ([`PretrainedConfig`]):
|
||||
The configuration corresponding to the model to register.
|
||||
video_processor_class ([`BaseVideoProcessor`]):
|
||||
The video processor to register.
|
||||
"""
|
||||
VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok)
|
||||
|
||||
|
||||
__all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"]
|
@ -370,13 +370,16 @@ class AutoformerSinusoidalPositionalEmbedding(nn.Embedding):
|
||||
self.weight = nn.Parameter(out, requires_grad=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
|
||||
def forward(
|
||||
self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
positions = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(positions)
|
||||
if position_ids is None:
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(position_ids)
|
||||
|
||||
|
||||
# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesValueEmbedding with TimeSeries->Autoformer
|
||||
|
@ -213,8 +213,8 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -223,16 +223,25 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -137,12 +137,22 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
return processed_images
|
||||
|
||||
def _preprocess_images(
|
||||
self,
|
||||
images,
|
||||
**kwargs,
|
||||
):
|
||||
"""Preprocesses images."""
|
||||
kwargs["do_reduce_labels"] = False
|
||||
processed_images = self._preprocess(images=images, **kwargs)
|
||||
return processed_images
|
||||
|
||||
def _preprocess_segmentation_maps(
|
||||
self,
|
||||
segmentation_maps,
|
||||
**kwargs,
|
||||
):
|
||||
"""Preprocesses a single segmentation map."""
|
||||
"""Preprocesses segmentation maps."""
|
||||
processed_segmentation_maps = []
|
||||
for segmentation_map in segmentation_maps:
|
||||
segmentation_map = self._process_image(
|
||||
@ -215,7 +225,7 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
images = self._preprocess(
|
||||
images = self._preprocess_images(
|
||||
images=images,
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -1340,7 +1340,6 @@ class BigBirdAttention(nn.Module):
|
||||
attn_weights.value = self.self.value
|
||||
attn_weights.key = self.self.key
|
||||
self.self = attn_weights
|
||||
self.attention_type = value
|
||||
if not self.training:
|
||||
self.self.eval()
|
||||
|
||||
|
@ -24,8 +24,12 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_attn_mask_utils import (
|
||||
AttentionMaskConverter,
|
||||
_prepare_4d_attention_mask,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -36,10 +40,21 @@ from ...modeling_outputs import (
|
||||
Seq2SeqSequenceClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from .configuration_bigbird_pegasus import BigBirdPegasusConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 7, 1024]
|
||||
@ -69,13 +84,15 @@ class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding):
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
positions = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(positions)
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor = None):
|
||||
"""`input_ids' shape is expected to be [bsz x seqlen]."""
|
||||
|
||||
if position_ids is None:
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(position_ids)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BigBirdPegasus
|
||||
@ -1114,7 +1131,6 @@ class BigBirdPegasusEncoderAttention(nn.Module):
|
||||
if value == self.attention_type:
|
||||
return
|
||||
|
||||
self.attention_type = value
|
||||
if value == "original_full":
|
||||
# copy all weights to new full attention class
|
||||
attn_weights = BigBirdPegasusSelfAttention(self.config)
|
||||
@ -1136,7 +1152,6 @@ class BigBirdPegasusEncoderAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
band_mask=None,
|
||||
from_mask=None,
|
||||
@ -1147,12 +1162,11 @@ class BigBirdPegasusEncoderAttention(nn.Module):
|
||||
# Expand dims to enable multiplication in the self-attention module
|
||||
head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None
|
||||
|
||||
if self.config.attention_type == "original_full":
|
||||
if self.attention_type == "original_full":
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
else:
|
||||
@ -1178,6 +1192,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
is_causal: bool = False,
|
||||
config: Optional[BigBirdPegasusConfig] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -1194,73 +1209,74 @@ class BigBirdPegasusDecoderAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None and self.is_decoder:
|
||||
logger.warning_once(
|
||||
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
||||
"will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = query_states * self.scaling
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value.self_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value
|
||||
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = curr_past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
if is_cross_attention:
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
query_states = query_states.reshape(*proj_shape)
|
||||
key_states = key_states.reshape(*proj_shape)
|
||||
value_states = value_states.reshape(*proj_shape)
|
||||
|
||||
@ -1274,10 +1290,7 @@ class BigBirdPegasusDecoderAttention(nn.Module):
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
@ -1412,7 +1425,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
|
||||
|
||||
|
||||
class BigBirdPegasusDecoderLayer(nn.Module):
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = BigBirdPegasusDecoderAttention(
|
||||
@ -1421,6 +1434,7 @@ class BigBirdPegasusDecoderLayer(nn.Module):
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
bias=config.use_bias,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
@ -1433,6 +1447,7 @@ class BigBirdPegasusDecoderLayer(nn.Module):
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
bias=config.use_bias,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
||||
@ -1448,9 +1463,10 @@ class BigBirdPegasusDecoderLayer(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -1469,47 +1485,43 @@ class BigBirdPegasusDecoderLayer(nn.Module):
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||
cache in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states, self_attn_weights, past_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_present_key_value = None
|
||||
cross_attn_weights = None
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
@ -1525,7 +1537,7 @@ class BigBirdPegasusDecoderLayer(nn.Module):
|
||||
outputs += (self_attn_weights, cross_attn_weights)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
outputs += (past_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -1563,6 +1575,8 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_param_buffer_assignment = False
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@ -1574,6 +1588,9 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -1585,6 +1602,131 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
||||
"""
|
||||
@ -1914,7 +2056,9 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
)
|
||||
self.layers = nn.ModuleList([BigBirdPegasusDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[BigBirdPegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]
|
||||
)
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -1941,6 +2085,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@ -2006,6 +2151,9 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||
cache in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -2014,42 +2162,6 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||
positions = positions.to(inputs_embeds.device)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
@ -2057,11 +2169,73 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# initialize `past_key_values`
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
|
||||
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
|
||||
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
|
||||
|
||||
batch_size, seq_length = inputs_embeds.size()[:-1]
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if attention_mask is None and not is_torchdynamo_compiling():
|
||||
# required mask seq length can be calculated via length of past cache
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
|
||||
self_attn_cache = (
|
||||
past_key_values.self_attention_cache
|
||||
if isinstance(past_key_values, EncoderDecoderCache)
|
||||
else past_key_values
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
self_attn_cache,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
|
||||
)
|
||||
|
||||
# embed positions
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
position_ids = self.embed_positions(
|
||||
(batch_size, seq_length), past_key_values_length, position_ids=position_ids
|
||||
)
|
||||
position_ids = position_ids.to(inputs_embeds.device)
|
||||
hidden_states = inputs_embeds + position_ids
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
@ -2080,13 +2254,11 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
@ -2094,25 +2266,27 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[3 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
@ -2127,6 +2301,9 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = past_key_values.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
@ -2198,6 +2375,7 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, Seq2SeqModelOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -2269,6 +2447,7 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@ -2359,6 +2538,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, Seq2SeqLMOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -2432,6 +2612,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
lm_logits = self.lm_head(outputs[0])
|
||||
@ -2514,6 +2695,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -2559,6 +2741,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = outputs[0] # last hidden state
|
||||
|
||||
@ -2646,6 +2829,7 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -2683,6 +2867,7 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -2794,6 +2979,7 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||
@ -2842,6 +3028,7 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0])
|
||||
|
@ -23,8 +23,11 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_attn_mask_utils import (
|
||||
AttentionMaskConverter,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
@ -34,11 +37,19 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from .configuration_biogpt import BioGptConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -55,17 +66,23 @@ class BioGptLearnedPositionalEmbedding(nn.Embedding):
|
||||
self.offset = 2
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim)
|
||||
|
||||
def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.LongTensor,
|
||||
past_key_values_length: int = 0,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
attention_mask = attention_mask.long()
|
||||
if position_ids is None:
|
||||
attention_mask = attention_mask.long()
|
||||
|
||||
# create positions depending on attention_mask
|
||||
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
|
||||
# create positions depending on attention_mask
|
||||
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
|
||||
|
||||
# cut positions if `past_key_values_length` is > 0
|
||||
positions = positions[:, past_key_values_length:]
|
||||
# cut positions if `past_key_values_length` is > 0
|
||||
position_ids = positions[:, past_key_values_length:]
|
||||
|
||||
return super().forward(positions + self.offset)
|
||||
return super().forward(position_ids + self.offset)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt
|
||||
@ -95,6 +112,7 @@ class BioGptAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
is_causal: bool = False,
|
||||
config: Optional[BioGptConfig] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -111,73 +129,74 @@ class BioGptAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None and self.is_decoder:
|
||||
logger.warning_once(
|
||||
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
||||
"will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = query_states * self.scaling
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value.self_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value
|
||||
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = curr_past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
if is_cross_attention:
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
query_states = query_states.reshape(*proj_shape)
|
||||
key_states = key_states.reshape(*proj_shape)
|
||||
value_states = value_states.reshape(*proj_shape)
|
||||
|
||||
@ -191,10 +210,7 @@ class BioGptAttention(nn.Module):
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
@ -247,16 +263,17 @@ class BioGptSdpaAttention(BioGptAttention):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -264,8 +281,8 @@ class BioGptSdpaAttention(BioGptAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
@ -275,50 +292,55 @@ class BioGptSdpaAttention(BioGptAttention):
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states)
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value.self_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value
|
||||
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = curr_past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
if is_cross_attention:
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
|
||||
query_states = self._shape(query_states, tgt_len, bsz)
|
||||
causal_mask = None
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
||||
is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
@ -326,23 +348,16 @@ class BioGptSdpaAttention(BioGptAttention):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
@ -355,7 +370,7 @@ BIOGPT_ATTENTION_CLASSES = {
|
||||
|
||||
|
||||
class BioGptDecoderLayer(nn.Module):
|
||||
def __init__(self, config: BioGptConfig):
|
||||
def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
@ -365,6 +380,7 @@ class BioGptDecoderLayer(nn.Module):
|
||||
dropout=config.attention_probs_dropout_prob,
|
||||
is_decoder=True,
|
||||
is_causal=True,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.dropout = config.hidden_dropout_prob
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
@ -381,9 +397,10 @@ class BioGptDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -399,21 +416,22 @@ class BioGptDecoderLayer(nn.Module):
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||
cache in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states, self_attn_weights, past_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -434,7 +452,7 @@ class BioGptDecoderLayer(nn.Module):
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
outputs += (past_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -445,6 +463,8 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "biogpt"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -462,6 +482,131 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class BioGptModel(BioGptPreTrainedModel):
|
||||
@ -479,7 +624,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
)
|
||||
self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
|
||||
|
||||
self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||
self.layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -502,9 +647,11 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@ -515,52 +662,14 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input = input_ids
|
||||
input_shape = input.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
input = inputs_embeds[:, :, -1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
|
||||
raise ValueError(
|
||||
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
|
||||
f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
|
||||
)
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(attention_mask, past_key_values_length)
|
||||
|
||||
if self._use_sdpa and not output_attentions and head_mask is None:
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
else:
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
@ -569,10 +678,55 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# initialize past_key_values
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
|
||||
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
|
||||
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
|
||||
|
||||
batch_size, seq_length = inputs_embeds.size()[:-1]
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if attention_mask is None and not is_torchdynamo_compiling():
|
||||
# required mask seq length can be calculated via length of past cache
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
|
||||
self_attn_cache = (
|
||||
past_key_values.self_attention_cache
|
||||
if isinstance(past_key_values, EncoderDecoderCache)
|
||||
else past_key_values
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
self_attn_cache,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
position_ids = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
|
||||
|
||||
hidden_states = inputs_embeds + position_ids
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
@ -583,32 +737,32 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
@ -620,6 +774,8 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = past_key_values.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
@ -669,9 +825,11 @@ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
@ -689,9 +847,11 @@ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
@ -26,8 +26,12 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_attn_mask_utils import (
|
||||
AttentionMaskConverter,
|
||||
_prepare_4d_attention_mask,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -36,11 +40,22 @@ from ...modeling_outputs import (
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel
|
||||
from .configuration_blenderbot import BlenderbotConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -69,13 +84,16 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
def forward(
|
||||
self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
|
||||
):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
positions = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(positions)
|
||||
if position_ids is None:
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(position_ids)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot
|
||||
@ -105,6 +123,7 @@ class BlenderbotAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
is_causal: bool = False,
|
||||
config: Optional[BlenderbotConfig] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -121,73 +140,74 @@ class BlenderbotAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None and self.is_decoder:
|
||||
logger.warning_once(
|
||||
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
||||
"will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = query_states * self.scaling
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value.self_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value
|
||||
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = curr_past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
if is_cross_attention:
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
query_states = query_states.reshape(*proj_shape)
|
||||
key_states = key_states.reshape(*proj_shape)
|
||||
value_states = value_states.reshape(*proj_shape)
|
||||
|
||||
@ -201,10 +221,7 @@ class BlenderbotAttention(nn.Module):
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
@ -325,7 +342,7 @@ class BlenderbotEncoderLayer(nn.Module):
|
||||
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
|
||||
class BlenderbotDecoderLayer(nn.Module):
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
@ -336,6 +353,7 @@ class BlenderbotDecoderLayer(nn.Module):
|
||||
is_decoder=True,
|
||||
is_causal=True,
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
@ -348,6 +366,7 @@ class BlenderbotDecoderLayer(nn.Module):
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
||||
@ -362,9 +381,10 @@ class BlenderbotDecoderLayer(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -383,47 +403,43 @@ class BlenderbotDecoderLayer(nn.Module):
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||
cache in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states, self_attn_weights, past_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_present_key_value = None
|
||||
cross_attn_weights = None
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
@ -439,7 +455,7 @@ class BlenderbotDecoderLayer(nn.Module):
|
||||
outputs += (self_attn_weights, cross_attn_weights)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
outputs += (past_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -449,6 +465,8 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
config_class = BlenderbotConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@ -460,6 +478,9 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -472,6 +493,131 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
"""
|
||||
@ -674,7 +820,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
)
|
||||
self.layers = nn.ModuleList([BlenderbotDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[BlenderbotDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -701,6 +849,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@ -767,6 +916,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||
cache in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -775,52 +927,79 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
## retrieve input_ids and inputs_embeds
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# initialize `past_key_values`
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
|
||||
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
|
||||
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
|
||||
|
||||
batch_size, seq_length = inputs_embeds.size()[:-1]
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if attention_mask is None and not is_torchdynamo_compiling():
|
||||
# required mask seq length can be calculated via length of past cache
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
|
||||
self_attn_cache = (
|
||||
past_key_values.self_attention_cache
|
||||
if isinstance(past_key_values, EncoderDecoderCache)
|
||||
else past_key_values
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
self_attn_cache,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
|
||||
)
|
||||
|
||||
# embed positions
|
||||
position_ids = self.embed_positions(
|
||||
(batch_size, seq_length), past_key_values_length, position_ids=cache_position
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds + position_ids
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
@ -839,13 +1018,11 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
@ -853,25 +1030,27 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[3 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
@ -887,6 +1066,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = past_key_values.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
@ -963,6 +1145,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -1041,6 +1224,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@ -1139,6 +1323,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMi
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -1225,6 +1410,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMi
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
||||
|
||||
@ -1326,6 +1512,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||
@ -1375,6 +1562,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0])
|
||||
|
@ -24,8 +24,12 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_attn_mask_utils import (
|
||||
AttentionMaskConverter,
|
||||
_prepare_4d_attention_mask,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -34,10 +38,21 @@ from ...modeling_outputs import (
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from .configuration_blenderbot_small import BlenderbotSmallConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -67,13 +82,16 @@ class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding):
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
def forward(
|
||||
self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
|
||||
):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
positions = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(positions)
|
||||
if position_ids is None:
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(position_ids)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BlenderbotSmall
|
||||
@ -89,6 +107,7 @@ class BlenderbotSmallAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
is_causal: bool = False,
|
||||
config: Optional[BlenderbotSmallConfig] = None,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -105,73 +124,74 @@ class BlenderbotSmallAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None and self.is_decoder:
|
||||
logger.warning_once(
|
||||
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
||||
"will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = query_states * self.scaling
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value.self_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value
|
||||
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
if is_cross_attention and past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(current_states)
|
||||
value_states = self.v_proj(current_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = curr_past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
if is_cross_attention:
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
query_states = query_states.reshape(*proj_shape)
|
||||
key_states = key_states.reshape(*proj_shape)
|
||||
value_states = value_states.reshape(*proj_shape)
|
||||
|
||||
@ -185,10 +205,7 @@ class BlenderbotSmallAttention(nn.Module):
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
@ -237,7 +254,7 @@ class BlenderbotSmallAttention(nn.Module):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
|
||||
class BlenderbotSmallEncoderLayer(nn.Module):
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
@ -246,6 +263,7 @@ class BlenderbotSmallEncoderLayer(nn.Module):
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
@ -314,7 +332,7 @@ BLENDERBOT_SMALL_ATTENTION_CLASSES = {
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
|
||||
class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
@ -325,6 +343,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
is_decoder=True,
|
||||
is_causal=True,
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
@ -337,6 +356,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
||||
@ -351,9 +371,10 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -372,47 +393,42 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||
cache in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
# Self Attention
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states, self_attn_weights, past_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_present_key_value = None
|
||||
cross_attn_weights = None
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
@ -428,7 +444,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
outputs += (self_attn_weights, cross_attn_weights)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
outputs += (past_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -438,6 +454,8 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
config_class = BlenderbotSmallConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@ -449,6 +467,9 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -461,6 +482,131 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
"""
|
||||
@ -657,7 +803,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
)
|
||||
self.layers = nn.ModuleList([BlenderbotSmallDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[BlenderbotSmallDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]
|
||||
)
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@ -684,6 +832,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
cache_position=None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@ -749,6 +898,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||
cache in the correct position and to infer the complete sequence length.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -757,43 +909,6 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||
|
||||
# BlenderbotSmall applies layer norm on hidden_states
|
||||
inputs_embeds = self.layernorm_embedding(inputs_embeds)
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
@ -801,11 +916,76 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
inputs_embeds = inputs_embeds * self.embed_scale
|
||||
|
||||
# initialize `past_key_values`
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
|
||||
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
|
||||
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
|
||||
|
||||
batch_size, seq_length = inputs_embeds.size()[:-1]
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if attention_mask is None and not is_torchdynamo_compiling():
|
||||
# required mask seq length can be calculated via length of past cache
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
|
||||
self_attn_cache = (
|
||||
past_key_values.self_attention_cache
|
||||
if isinstance(past_key_values, EncoderDecoderCache)
|
||||
else past_key_values
|
||||
)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
self_attn_cache,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
|
||||
)
|
||||
|
||||
# embed positions
|
||||
position_ids = self.embed_positions(
|
||||
(batch_size, seq_length), past_key_values_length, position_ids=cache_position
|
||||
)
|
||||
|
||||
# BlenderbotSmall applies layer norm on hidden_states
|
||||
inputs_embeds = self.layernorm_embedding(inputs_embeds)
|
||||
hidden_states = inputs_embeds + position_ids
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
@ -824,13 +1004,11 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
@ -838,25 +1016,27 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[3 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
@ -869,6 +1049,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = past_key_values.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
@ -932,6 +1115,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -1010,6 +1194,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@ -1093,6 +1278,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, Ge
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
@ -1179,6 +1365,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, Ge
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
||||
|
||||
@ -1280,6 +1467,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||
@ -1329,6 +1517,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0])
|
||||
|
@ -1956,6 +1956,50 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
if hasattr(self.language_model, "_hf_hook"):
|
||||
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
interpolate_pos_encoding: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input images.
|
||||
"""
|
||||
# step 1: forward the images through the vision encoder,
|
||||
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=True,
|
||||
)
|
||||
image_embeds = vision_outputs[0]
|
||||
|
||||
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||
query_outputs = self.qformer(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
query_output = query_outputs[0]
|
||||
|
||||
# Qformer is kept in fp32, we downcast the output back if needed
|
||||
if query_output.dtype != image_embeds.dtype:
|
||||
query_output = query_output.to(image_embeds.dtype)
|
||||
|
||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
if return_dict:
|
||||
return language_model_inputs, vision_outputs, query_outputs
|
||||
return language_model_inputs
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2047,37 +2091,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# step 1: forward the images through the vision encoder,
|
||||
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True
|
||||
)
|
||||
image_embeds = vision_outputs[0]
|
||||
|
||||
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||
query_outputs = self.qformer(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
query_output = query_outputs[0]
|
||||
|
||||
# Qformer is kept in fp32, we downcast the output back if needed
|
||||
if query_output.dtype != image_embeds.dtype:
|
||||
query_output = query_output.to(image_embeds.dtype)
|
||||
|
||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
||||
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
||||
language_model_attention_mask = torch.ones(
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
@ -904,6 +904,12 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
self.embed_tokens = value
|
||||
|
||||
def get_image_tokens(self, pixel_values: torch.FloatTensor):
|
||||
logger.warning(
|
||||
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
|
||||
)
|
||||
return self.get_image_featues(pixel_values)
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
"""
|
||||
Tokenizes images into discrete tokens with VQGAN module. Converts
|
||||
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
|
||||
@ -957,7 +963,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_tokens(pixel_values)
|
||||
image_tokens = self.get_image_features(pixel_values)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
|
||||
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user