mirror of
https://github.com/huggingface/transformers.git
synced 2025-11-03 11:24:34 +08:00
Compare commits
283 Commits
fix-flash-
...
v4.51.1
| Author | SHA1 | Date | |
|---|---|---|---|
| 10baffb599 | |||
| 4a88ffae40 | |||
| f19aec737e | |||
| d8f0695e84 | |||
| d27c8c38f4 | |||
| 04c0cedcdf | |||
| 4f536ba0ae | |||
| 6b82af0a5b | |||
| 2bf3d4aca8 | |||
| a79b7abede | |||
| 0720e206c6 | |||
| 25b7f27234 | |||
| aa40fda346 | |||
| e94571580b | |||
| 84aa13dd85 | |||
| 0ef339ff1b | |||
| 46d73910d5 | |||
| 579135a2f6 | |||
| 8cd57eb731 | |||
| ebe47ce3e9 | |||
| 531e4fcf0e | |||
| a4e55fcff8 | |||
| 878562b68d | |||
| 8ebc435267 | |||
| ad3d157188 | |||
| 3d40bda30e | |||
| acbcb5d07d | |||
| 4ba0989eab | |||
| 352ec8ef22 | |||
| edd345b52e | |||
| b016de1ae4 | |||
| f74d7da836 | |||
| d130cd0e16 | |||
| 41b9b92b52 | |||
| 8dd0a2b89c | |||
| 15ac2b6ac5 | |||
| b552708694 | |||
| 2b84831a93 | |||
| 2d46a08b63 | |||
| 1b29409d89 | |||
| 8a828a747e | |||
| 3f6af96732 | |||
| 9a1c1fe7ed | |||
| 782d7d945d | |||
| afafb84b59 | |||
| 34ccfebf32 | |||
| f697b3f824 | |||
| 2099287a59 | |||
| a0803a9555 | |||
| 6ce238fe7a | |||
| 12048990a9 | |||
| 98601cc818 | |||
| c9302c0983 | |||
| 2056287940 | |||
| 3e96a0c32b | |||
| 199d7adf10 | |||
| 126abe3461 | |||
| 3d133cc557 | |||
| e90d55ebcc | |||
| cbfa14823b | |||
| 7613cf1a45 | |||
| 32c12aaec3 | |||
| 764ab0d46a | |||
| c94c6ed397 | |||
| e94d607c8b | |||
| adfc91cd46 | |||
| 6f5dc9c82e | |||
| a165458901 | |||
| ed95493ce0 | |||
| 211e4dc9a4 | |||
| 800510c67b | |||
| 41f5c3216c | |||
| bc2dea3f54 | |||
| 35253076f4 | |||
| bf41e54fc8 | |||
| 3249c5dc15 | |||
| 24e311f42b | |||
| 897ff9af0e | |||
| c0bd8048a5 | |||
| 60b75d99b6 | |||
| fac70ff3c0 | |||
| ae34bd75fd | |||
| 8f6b27eb5c | |||
| 737cbd2109 | |||
| 3a6ab46a0b | |||
| 4b13a02920 | |||
| 786d9c5ed9 | |||
| a1e389e637 | |||
| f304318f5f | |||
| 8805600406 | |||
| e686fed635 | |||
| a03cee7a1d | |||
| 3b07ca78bb | |||
| 475664e2c6 | |||
| 0710e9b1e8 | |||
| f99c279d20 | |||
| d1efaf0318 | |||
| 19919689b2 | |||
| d0b65bb479 | |||
| ad63d20dff | |||
| 286393fbb1 | |||
| 4705b04c74 | |||
| 2b4734bd49 | |||
| bd41b9c1ac | |||
| 6acd5aecb3 | |||
| 0d6a60fe55 | |||
| b7fc2daf8b | |||
| bab605dd04 | |||
| 9fd9476005 | |||
| 257bc670fb | |||
| 2bea6bf24e | |||
| a86dad56bc | |||
| d6064754ea | |||
| 581cf96e0c | |||
| eca74d1367 | |||
| 52cc204dd7 | |||
| aa3778afc2 | |||
| c90e6e9625 | |||
| 1fcaad6df9 | |||
| 3af425d4c6 | |||
| 064cd7cdac | |||
| 348f3285c5 | |||
| d6b3c7486b | |||
| 6cc9c8d7d1 | |||
| 4cc65e990f | |||
| 41a0e58e5b | |||
| de77f5b1ec | |||
| 8c5e29bad5 | |||
| 471cf1de63 | |||
| 29f322d04d | |||
| fb8e6c50e4 | |||
| e97c760006 | |||
| c7bc79bd2a | |||
| d1eafe8d4e | |||
| 0e56fb69a2 | |||
| 7e813f9cf0 | |||
| 92429057d9 | |||
| 279c2e302a | |||
| d13c390d01 | |||
| d6d930a64b | |||
| 927ce1d39f | |||
| 49b5ab6a27 | |||
| 5b08db8844 | |||
| 3a8ec8c467 | |||
| 2b550c47b2 | |||
| 44715225e3 | |||
| 79d6f9fd70 | |||
| 13d36e89fe | |||
| 021006e1b0 | |||
| 788e1092e9 | |||
| ad5d40de9c | |||
| 8084b26294 | |||
| b56d8f07e4 | |||
| 78afa1c537 | |||
| 181d453069 | |||
| e7139d06f5 | |||
| be37d34f44 | |||
| ab4656f6b7 | |||
| ba531278ca | |||
| a844297088 | |||
| d68a91aebf | |||
| 121830ab47 | |||
| a41677a68b | |||
| 3dce98a437 | |||
| ebd2029483 | |||
| 69632aadb7 | |||
| c6814b4ee8 | |||
| bc1c90a755 | |||
| 80b4c5dcc9 | |||
| 0f733110a6 | |||
| 19085c28da | |||
| 69bcb86c58 | |||
| be2c0e7bff | |||
| 4303d88c09 | |||
| 47e5432805 | |||
| 2b8a15cc3f | |||
| 91455c1825 | |||
| 48385aa4f4 | |||
| 5932606d8e | |||
| 2be2984462 | |||
| 00d077267a | |||
| a6ecb54159 | |||
| cbf924b76c | |||
| 340500b1a9 | |||
| 9e125d9a2e | |||
| 57f551c78d | |||
| a41e08aa19 | |||
| e28be7a692 | |||
| 48da44be24 | |||
| fe4ca2f4a7 | |||
| c9d1e5238a | |||
| d253de6d58 | |||
| beb9b5b022 | |||
| dd3933dd65 | |||
| 90e2df5d55 | |||
| 4542b8fb27 | |||
| 523f6e743c | |||
| 3f9ff19b4e | |||
| f94b0c59f2 | |||
| 2638d54e78 | |||
| b8aadc31d5 | |||
| 6321876b5b | |||
| 94f487626a | |||
| f19d018bff | |||
| 62116c967f | |||
| 26c83490d2 | |||
| 0adbc873d0 | |||
| 6bb8565f0c | |||
| 949cca4061 | |||
| 97d2f9d8ae | |||
| 6a2627918d | |||
| 9e771bf402 | |||
| ecd60d01c3 | |||
| 42c489f2ae | |||
| 068b663f90 | |||
| 1d3f35f30a | |||
| 6515c25953 | |||
| 66291778dd | |||
| 730d2a52e7 | |||
| 1a374799ce | |||
| ce091b1bda | |||
| 3e8f0fbf44 | |||
| 055afdb6bb | |||
| 487dab1b2b | |||
| a63e92e2f0 | |||
| 8124a234ca | |||
| cf8091c017 | |||
| 388e6659bf | |||
| b47d9b2f8a | |||
| 8e97b44087 | |||
| 63380b77d4 | |||
| 957b05b413 | |||
| f0d5b2ff04 | |||
| 1ddb64937c | |||
| e7337ee7be | |||
| 8b479e39bb | |||
| 3f03c379d2 | |||
| 8f64b177f6 | |||
| 94555437e2 | |||
| 8733297b41 | |||
| b815fae359 | |||
| 9be4728af8 | |||
| 51bd0ceb9e | |||
| 107fedc1e2 | |||
| 258dd9cc69 | |||
| f39f4960f3 | |||
| 63c3116530 | |||
| 7c233980f4 | |||
| b11050d6a2 | |||
| e8d960329e | |||
| fef8b7f8e9 | |||
| 0fe0bae0a8 | |||
| a861db01e5 | |||
| b9374a0763 | |||
| 4fa91b1be5 | |||
| 706703bba6 | |||
| 179d02ffb8 | |||
| 12f2ebef63 | |||
| 00915d3041 | |||
| 14b597f518 | |||
| 30580f035b | |||
| db1d4c5a0b | |||
| 7baf00089a | |||
| 3017536ebf | |||
| e959530b8f | |||
| bd92073692 | |||
| 7426d02ea8 | |||
| 19b9d8ae13 | |||
| 7f5077e536 | |||
| cbfb8d7b27 | |||
| ac1a1b66b9 | |||
| cff4caa0c1 | |||
| e3af4fec91 | |||
| c8a2b25f91 | |||
| 8e67230860 | |||
| 27361bd218 | |||
| da7d64f4ff | |||
| 2256875a77 | |||
| 9e94801146 | |||
| c53d53da89 | |||
| fc8764c9a6 | |||
| f263e88dcf | |||
| 6f3e0b68e0 |
@ -154,7 +154,7 @@ jobs:
|
||||
path: ~/transformers/installed.txt
|
||||
- run: python -c "from transformers import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
|
||||
- run: ruff check examples tests src utils
|
||||
- run: ruff format tests src utils --check
|
||||
- run: ruff format examples tests src utils --check
|
||||
- run: python utils/custom_init_isort.py --check_only
|
||||
- run: python utils/sort_auto_mappings.py --check_only
|
||||
- run: python utils/check_doc_toc.py
|
||||
|
||||
@ -30,7 +30,7 @@ COMMON_ENV_VARIABLES = {
|
||||
"RUN_PIPELINE_TESTS": False,
|
||||
}
|
||||
# Disable the use of {"s": None} as the output is way too long, causing the navigation on CircleCI impractical
|
||||
COMMON_PYTEST_OPTIONS = {"max-worker-restart": 0, "dist": "loadfile", "vvv": None, "rsfE":None}
|
||||
COMMON_PYTEST_OPTIONS = {"max-worker-restart": 0, "vvv": None, "rsfE":None}
|
||||
DEFAULT_DOCKER_IMAGE = [{"image": "cimg/python:3.8.12"}]
|
||||
|
||||
# Strings that commonly appear in the output of flaky tests when they fail. These are used with `pytest-rerunfailures`
|
||||
@ -171,6 +171,7 @@ class CircleCIJob:
|
||||
"command": f"TESTS=$(circleci tests split --split-by=timings {self.job_name}_test_list.txt) && echo $TESTS > splitted_tests.txt && echo $TESTS | tr ' ' '\n'" if self.parallelism else f"awk '{{printf \"%s \", $0}}' {self.job_name}_test_list.txt > splitted_tests.txt"
|
||||
}
|
||||
},
|
||||
{"run": {"name": "fetch hub objects before pytest", "command": "python3 utils/fetch_hub_objects_for_ci.py"}},
|
||||
{"run": {
|
||||
"name": "Run tests",
|
||||
"command": f"({timeout_cmd} python3 -m pytest {marker_cmd} -n {self.pytest_num_workers} {junit_flags} {repeat_on_failure_flags} {' '.join(pytest_flags)} $(cat splitted_tests.txt) | tee tests_output.txt)"}
|
||||
@ -206,6 +207,9 @@ torch_job = CircleCIJob(
|
||||
generate_job = CircleCIJob(
|
||||
"generate",
|
||||
docker_image=[{"image": "huggingface/transformers-torch-light"}],
|
||||
# networkx==3.3 (after #36957) cause some issues
|
||||
# TODO: remove this once it works directly
|
||||
install_steps=["uv venv && uv pip install . && uv pip install networkx==3.2.1"],
|
||||
marker="generate",
|
||||
parallelism=6,
|
||||
)
|
||||
@ -269,6 +273,7 @@ examples_torch_job = CircleCIJob(
|
||||
docker_image=[{"image":"huggingface/transformers-examples-torch"}],
|
||||
# TODO @ArthurZucker remove this once docker is easier to build
|
||||
install_steps=["uv venv && uv pip install . && uv pip install -r examples/pytorch/_tests_requirements.txt"],
|
||||
pytest_num_workers=4,
|
||||
)
|
||||
|
||||
|
||||
@ -276,6 +281,7 @@ examples_tensorflow_job = CircleCIJob(
|
||||
"examples_tensorflow",
|
||||
additional_env={"OMP_NUM_THREADS": 8},
|
||||
docker_image=[{"image":"huggingface/transformers-examples-tf"}],
|
||||
pytest_num_workers=2,
|
||||
)
|
||||
|
||||
|
||||
@ -326,6 +332,9 @@ repo_utils_job = CircleCIJob(
|
||||
non_model_job = CircleCIJob(
|
||||
"non_model",
|
||||
docker_image=[{"image": "huggingface/transformers-torch-light"}],
|
||||
# networkx==3.3 (after #36957) cause some issues
|
||||
# TODO: remove this once it works directly
|
||||
install_steps=["uv venv && uv pip install . && uv pip install networkx==3.2.1"],
|
||||
marker="not generate",
|
||||
parallelism=6,
|
||||
)
|
||||
@ -355,9 +364,9 @@ doc_test_job = CircleCIJob(
|
||||
pytest_num_workers=1,
|
||||
)
|
||||
|
||||
REGULAR_TESTS = [torch_job, tf_job, flax_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job] # fmt: skip
|
||||
EXAMPLES_TESTS = [examples_torch_job, examples_tensorflow_job]
|
||||
PIPELINE_TESTS = [pipelines_torch_job, pipelines_tf_job]
|
||||
REGULAR_TESTS = [torch_job, flax_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job] # fmt: skip
|
||||
EXAMPLES_TESTS = [examples_torch_job]
|
||||
PIPELINE_TESTS = [pipelines_torch_job]
|
||||
REPO_UTIL_TESTS = [repo_utils_job]
|
||||
DOC_TESTS = [doc_test_job]
|
||||
ALL_TESTS = REGULAR_TESTS + EXAMPLES_TESTS + PIPELINE_TESTS + REPO_UTIL_TESTS + DOC_TESTS + [custom_tokenizers_job] + [exotic_models_job] # fmt: skip
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
4
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -48,11 +48,11 @@ body:
|
||||
- pipelines: @Rocketknight1
|
||||
- tensorflow: @gante and @Rocketknight1
|
||||
- tokenizers: @ArthurZucker and @itazap
|
||||
- trainer: @muellerzr @SunMarc
|
||||
- trainer: @zach-huggingface @SunMarc
|
||||
|
||||
Integrations:
|
||||
|
||||
- deepspeed: HF Trainer/Accelerate: @muellerzr
|
||||
- deepspeed: HF Trainer/Accelerate: @SunMarc @zach-huggingface
|
||||
- ray/raytune: @richardliaw, @amogkam
|
||||
- Big Model Inference: @SunMarc
|
||||
- quantization (bitsandbytes, autogpt): @SunMarc @MekkCyber
|
||||
|
||||
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -51,12 +51,12 @@ Library:
|
||||
- pipelines: @Rocketknight1
|
||||
- tensorflow: @gante and @Rocketknight1
|
||||
- tokenizers: @ArthurZucker
|
||||
- trainer: @muellerzr and @SunMarc
|
||||
- trainer: @zach-huggingface and @SunMarc
|
||||
- chat templates: @Rocketknight1
|
||||
|
||||
Integrations:
|
||||
|
||||
- deepspeed: HF Trainer/Accelerate: @muellerzr
|
||||
- deepspeed: HF Trainer/Accelerate: @SunMarc @zach-huggingface
|
||||
- ray/raytune: @richardliaw, @amogkam
|
||||
- Big Model Inference: @SunMarc
|
||||
- quantization (bitsandbytes, autogpt): @SunMarc @MekkCyber
|
||||
|
||||
10
.github/scripts/assign_reviewers.py
vendored
10
.github/scripts/assign_reviewers.py
vendored
@ -22,12 +22,16 @@ from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
def pattern_to_regex(pattern):
|
||||
start_anchor = pattern.startswith("/")
|
||||
pattern = re.escape(pattern)
|
||||
if pattern.startswith("/"):
|
||||
start_anchor = True
|
||||
pattern = re.escape(pattern[1:])
|
||||
else:
|
||||
start_anchor = False
|
||||
pattern = re.escape(pattern)
|
||||
# Replace `*` with "any number of non-slash characters"
|
||||
pattern = pattern.replace(r"\*", "[^/]*")
|
||||
if start_anchor:
|
||||
pattern = "^" + pattern
|
||||
pattern = r"^\/?" + pattern # Allow an optional leading slash after the start of the string
|
||||
return pattern
|
||||
|
||||
def get_file_owners(file_path, codeowners_lines):
|
||||
|
||||
6
.github/scripts/codeowners_for_review_action
vendored
6
.github/scripts/codeowners_for_review_action
vendored
@ -14,7 +14,7 @@ docs/ @stevhliu
|
||||
# Owners of subsections of the library
|
||||
/src/transformers/generation/ @gante
|
||||
/src/transformers/pipeline/ @Rocketknight1 @yonigozlan
|
||||
/src/transformers/integrations/ @SunMarc @MekkCyber @muellerzr
|
||||
/src/transformers/integrations/ @SunMarc @MekkCyber @zach-huggingface
|
||||
/src/transformers/quantizers/ @SunMarc @MekkCyber
|
||||
tests/ @ydshieh
|
||||
tests/generation/ @gante
|
||||
@ -27,8 +27,8 @@ tests/generation/ @gante
|
||||
# Specific files come after the sections/globs, so they take priority
|
||||
/.circleci/config.yml @ArthurZucker @ydshieh
|
||||
/utils/tests_fetcher.py @ydshieh
|
||||
trainer.py @muellerzr @SunMarc
|
||||
trainer_utils.py @muellerzr @SunMarc
|
||||
trainer.py @zach-huggingface @SunMarc
|
||||
trainer_utils.py @zach-huggingface @SunMarc
|
||||
/utils/modular_model_converter.py @Cyrilvallez @ArthurZucker
|
||||
|
||||
# Owners of individual models are specific / high priority, and so they come last
|
||||
|
||||
1
.github/workflows/build_pr_documentation.yml
vendored
1
.github/workflows/build_pr_documentation.yml
vendored
@ -15,4 +15,3 @@ jobs:
|
||||
pr_number: ${{ github.event.number }}
|
||||
package: transformers
|
||||
languages: ar de en es fr hi it ko pt tr zh ja te
|
||||
custom_container: huggingface/transformers-doc-builder
|
||||
|
||||
2
.github/workflows/change_pr_to_draft.yml
vendored
2
.github/workflows/change_pr_to_draft.yml
vendored
@ -22,4 +22,4 @@ jobs:
|
||||
run: |
|
||||
echo $PR_NUMBER
|
||||
gh pr ready $PR_NUMBER --repo $REPO --undo
|
||||
gh pr comment $PR_NUMBER --repo $REPO --body "Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the \`Ready for review\` button (at the bottom of the PR page)."
|
||||
gh pr comment $PR_NUMBER --repo $REPO --body "Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the \`Ready for review\` button (at the bottom of the PR page). This will assign reviewers and trigger CI."
|
||||
|
||||
2
.github/workflows/push-important-models.yml
vendored
2
.github/workflows/push-important-models.yml
vendored
@ -27,7 +27,7 @@ jobs:
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
|
||||
uses: tj-actions/changed-files@1c8e6069583811afb28f97afeaf8e7da80c6be5c
|
||||
with:
|
||||
files: src/transformers/models/**
|
||||
|
||||
|
||||
2
.github/workflows/self-comment-ci.yml
vendored
2
.github/workflows/self-comment-ci.yml
vendored
@ -29,7 +29,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
name: Get PR number
|
||||
# For security: only allow team members to run
|
||||
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
|
||||
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
|
||||
outputs:
|
||||
PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
|
||||
steps:
|
||||
|
||||
4
.github/workflows/self-push-caller.yml
vendored
4
.github/workflows/self-push-caller.yml
vendored
@ -25,7 +25,7 @@ jobs:
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v41
|
||||
uses: tj-actions/changed-files@1c8e6069583811afb28f97afeaf8e7da80c6be5c
|
||||
|
||||
- name: Was setup changed
|
||||
id: was_changed
|
||||
@ -51,4 +51,4 @@ jobs:
|
||||
needs: build-docker-containers
|
||||
steps:
|
||||
- name: Trigger push CI via workflow_run
|
||||
run: echo "Trigger push CI via workflow_run"
|
||||
run: echo "Trigger push CI via workflow_run"
|
||||
|
||||
2
.github/workflows/update_metdata.yml
vendored
2
.github/workflows/update_metdata.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
- name: Setup environment
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install datasets pandas==2.0.3
|
||||
pip install datasets pandas
|
||||
pip install .[torch,tf,flax]
|
||||
|
||||
- name: Update metadata
|
||||
|
||||
@ -263,9 +263,9 @@ You are not required to read the following guidelines before opening an issue. H
|
||||
But if you're replying to a comment that happened some comments back it's always a good practice to quote just the relevant lines you're replying it. The `>` is used for quoting, or you can always use the menu to do so. For example your editor box will look like:
|
||||
|
||||
```
|
||||
> How big is your gpu cluster?
|
||||
> How big is your GPU cluster?
|
||||
|
||||
Our cluster is made of 256 gpus.
|
||||
Our cluster is made of 256 GPUs.
|
||||
```
|
||||
|
||||
If you are addressing multiple comments, quote the relevant parts of each before your answer. Some people use the same comment to do multiple replies, others separate them into separate comments. Either way works. The latter approach helps for linking to a specific comment.
|
||||
|
||||
386
README.md
386
README.md
@ -25,6 +25,7 @@ limitations under the License.
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.com/models"><img alt="Checkpoints on Hub" src="https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen"></a>
|
||||
<a href="https://circleci.com/gh/huggingface/transformers"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/main"></a>
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue"></a>
|
||||
<a href="https://huggingface.co/docs/transformers/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/transformers/index.svg?down_color=red&down_message=offline&up_message=online"></a>
|
||||
@ -54,275 +55,254 @@ limitations under the License.
|
||||
</h4>
|
||||
|
||||
<h3 align="center">
|
||||
<p>State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow</p>
|
||||
<p>State-of-the-art pretrained models for inference and training</p>
|
||||
</h3>
|
||||
|
||||
<h3 align="center">
|
||||
<a href="https://hf.co/course"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/course_banner.png"></a>
|
||||
</h3>
|
||||
|
||||
🤗 Transformers provides thousands of pretrained models to perform tasks on different modalities such as text, vision, and audio.
|
||||
Transformers is a library of pretrained text, computer vision, audio, video, and multimodal models for inference and training. Use Transformers to fine-tune models on your data, build inference applications, and for generative AI use cases across multiple modalities.
|
||||
|
||||
These models can be applied on:
|
||||
There are over 500K+ Transformers [model checkpoints](https://huggingface.co/models?library=transformers&sort=trending) on the [Hugging Face Hub](https://huggingface.com/models) you can use.
|
||||
|
||||
* 📝 Text, for tasks like text classification, information extraction, question answering, summarization, translation, and text generation, in over 100 languages.
|
||||
* 🖼️ Images, for tasks like image classification, object detection, and segmentation.
|
||||
* 🗣️ Audio, for tasks like speech recognition and audio classification.
|
||||
Explore the [Hub](https://huggingface.com/) today to find a model and use Transformers to help you get started right away.
|
||||
|
||||
Transformer models can also perform tasks on **several modalities combined**, such as table question answering, optical character recognition, information extraction from scanned documents, video classification, and visual question answering.
|
||||
## Installation
|
||||
|
||||
🤗 Transformers provides APIs to quickly download and use those pretrained models on a given text, fine-tune them on your own datasets and then share them with the community on our [model hub](https://huggingface.co/models). At the same time, each python module defining an architecture is fully standalone and can be modified to enable quick research experiments.
|
||||
Transformers works with Python 3.9+ [PyTorch](https://pytorch.org/get-started/locally/) 2.0+, [TensorFlow](https://www.tensorflow.org/install/pip) 2.6+, and [Flax](https://flax.readthedocs.io/en/latest/) 0.4.1+.
|
||||
|
||||
🤗 Transformers is backed by the three most popular deep learning libraries — [Jax](https://jax.readthedocs.io/en/latest/), [PyTorch](https://pytorch.org/) and [TensorFlow](https://www.tensorflow.org/) — with a seamless integration between them. It's straightforward to train your models with one before loading them for inference with the other.
|
||||
Create and activate a virtual environment with [venv](https://docs.python.org/3/library/venv.html) or [uv](https://docs.astral.sh/uv/), a fast Rust-based Python package and project manager.
|
||||
|
||||
## Online demos
|
||||
```py
|
||||
# venv
|
||||
python -m venv .my-env
|
||||
source .my-env/bin/activate
|
||||
|
||||
You can test most of our models directly on their pages from the [model hub](https://huggingface.co/models). We also offer [private model hosting, versioning, & an inference API](https://huggingface.co/pricing) for public and private models.
|
||||
|
||||
Here are a few examples:
|
||||
|
||||
In Natural Language Processing:
|
||||
- [Masked word completion with BERT](https://huggingface.co/google-bert/bert-base-uncased?text=Paris+is+the+%5BMASK%5D+of+France)
|
||||
- [Named Entity Recognition with Electra](https://huggingface.co/dbmdz/electra-large-discriminator-finetuned-conll03-english?text=My+name+is+Sarah+and+I+live+in+London+city)
|
||||
- [Text generation with Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
|
||||
- [Natural Language Inference with RoBERTa](https://huggingface.co/FacebookAI/roberta-large-mnli?text=The+dog+was+lost.+Nobody+lost+any+animal)
|
||||
- [Summarization with BART](https://huggingface.co/facebook/bart-large-cnn?text=The+tower+is+324+metres+%281%2C063+ft%29+tall%2C+about+the+same+height+as+an+81-storey+building%2C+and+the+tallest+structure+in+Paris.+Its+base+is+square%2C+measuring+125+metres+%28410+ft%29+on+each+side.+During+its+construction%2C+the+Eiffel+Tower+surpassed+the+Washington+Monument+to+become+the+tallest+man-made+structure+in+the+world%2C+a+title+it+held+for+41+years+until+the+Chrysler+Building+in+New+York+City+was+finished+in+1930.+It+was+the+first+structure+to+reach+a+height+of+300+metres.+Due+to+the+addition+of+a+broadcasting+aerial+at+the+top+of+the+tower+in+1957%2C+it+is+now+taller+than+the+Chrysler+Building+by+5.2+metres+%2817+ft%29.+Excluding+transmitters%2C+the+Eiffel+Tower+is+the+second+tallest+free-standing+structure+in+France+after+the+Millau+Viaduct)
|
||||
- [Question answering with DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased-distilled-squad?text=Which+name+is+also+used+to+describe+the+Amazon+rainforest+in+English%3F&context=The+Amazon+rainforest+%28Portuguese%3A+Floresta+Amaz%C3%B4nica+or+Amaz%C3%B4nia%3B+Spanish%3A+Selva+Amaz%C3%B3nica%2C+Amazon%C3%ADa+or+usually+Amazonia%3B+French%3A+For%C3%AAt+amazonienne%3B+Dutch%3A+Amazoneregenwoud%29%2C+also+known+in+English+as+Amazonia+or+the+Amazon+Jungle%2C+is+a+moist+broadleaf+forest+that+covers+most+of+the+Amazon+basin+of+South+America.+This+basin+encompasses+7%2C000%2C000+square+kilometres+%282%2C700%2C000+sq+mi%29%2C+of+which+5%2C500%2C000+square+kilometres+%282%2C100%2C000+sq+mi%29+are+covered+by+the+rainforest.+This+region+includes+territory+belonging+to+nine+nations.+The+majority+of+the+forest+is+contained+within+Brazil%2C+with+60%25+of+the+rainforest%2C+followed+by+Peru+with+13%25%2C+Colombia+with+10%25%2C+and+with+minor+amounts+in+Venezuela%2C+Ecuador%2C+Bolivia%2C+Guyana%2C+Suriname+and+French+Guiana.+States+or+departments+in+four+nations+contain+%22Amazonas%22+in+their+names.+The+Amazon+represents+over+half+of+the+planet%27s+remaining+rainforests%2C+and+comprises+the+largest+and+most+biodiverse+tract+of+tropical+rainforest+in+the+world%2C+with+an+estimated+390+billion+individual+trees+divided+into+16%2C000+species)
|
||||
- [Translation with T5](https://huggingface.co/google-t5/t5-base?text=My+name+is+Wolfgang+and+I+live+in+Berlin)
|
||||
|
||||
In Computer Vision:
|
||||
- [Image classification with ViT](https://huggingface.co/google/vit-base-patch16-224)
|
||||
- [Object Detection with DETR](https://huggingface.co/facebook/detr-resnet-50)
|
||||
- [Semantic Segmentation with SegFormer](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512)
|
||||
- [Panoptic Segmentation with Mask2Former](https://huggingface.co/facebook/mask2former-swin-large-coco-panoptic)
|
||||
- [Depth Estimation with Depth Anything](https://huggingface.co/docs/transformers/main/model_doc/depth_anything)
|
||||
- [Video Classification with VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)
|
||||
- [Universal Segmentation with OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_dinat_large)
|
||||
|
||||
In Audio:
|
||||
- [Automatic Speech Recognition with Whisper](https://huggingface.co/openai/whisper-large-v3)
|
||||
- [Keyword Spotting with Wav2Vec2](https://huggingface.co/superb/wav2vec2-base-superb-ks)
|
||||
- [Audio Classification with Audio Spectrogram Transformer](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)
|
||||
|
||||
In Multimodal tasks:
|
||||
- [Table Question Answering with TAPAS](https://huggingface.co/google/tapas-base-finetuned-wtq)
|
||||
- [Visual Question Answering with ViLT](https://huggingface.co/dandelin/vilt-b32-finetuned-vqa)
|
||||
- [Image captioning with LLaVa](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
|
||||
- [Zero-shot Image Classification with SigLIP](https://huggingface.co/google/siglip-so400m-patch14-384)
|
||||
- [Document Question Answering with LayoutLM](https://huggingface.co/impira/layoutlm-document-qa)
|
||||
- [Zero-shot Video Classification with X-CLIP](https://huggingface.co/docs/transformers/model_doc/xclip)
|
||||
- [Zero-shot Object Detection with OWLv2](https://huggingface.co/docs/transformers/en/model_doc/owlv2)
|
||||
- [Zero-shot Image Segmentation with CLIPSeg](https://huggingface.co/docs/transformers/model_doc/clipseg)
|
||||
- [Automatic Mask Generation with SAM](https://huggingface.co/docs/transformers/model_doc/sam)
|
||||
|
||||
|
||||
## 100 projects using Transformers
|
||||
|
||||
Transformers is more than a toolkit to use pretrained models: it's a community of projects built around it and the
|
||||
Hugging Face Hub. We want Transformers to enable developers, researchers, students, professors, engineers, and anyone
|
||||
else to build their dream projects.
|
||||
|
||||
In order to celebrate the 100,000 stars of transformers, we have decided to put the spotlight on the
|
||||
community, and we have created the [awesome-transformers](./awesome-transformers.md) page which lists 100
|
||||
incredible projects built in the vicinity of transformers.
|
||||
|
||||
If you own or use a project that you believe should be part of the list, please open a PR to add it!
|
||||
|
||||
## Serious about AI in your organisation? Build faster with the Hugging Face Enterprise Hub.
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/enterprise">
|
||||
<img alt="Hugging Face Enterprise Hub" src="https://github.com/user-attachments/assets/247fb16d-d251-4583-96c4-d3d76dda4925">
|
||||
</a><br>
|
||||
|
||||
## Quick tour
|
||||
|
||||
To immediately use a model on a given input (text, image, audio, ...), we provide the `pipeline` API. Pipelines group together a pretrained model with the preprocessing that was used during that model's training. Here is how to quickly use a pipeline to classify positive versus negative texts:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# Allocate a pipeline for sentiment-analysis
|
||||
>>> classifier = pipeline('sentiment-analysis')
|
||||
>>> classifier('We are very happy to introduce pipeline to the transformers repository.')
|
||||
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]
|
||||
# uv
|
||||
uv venv .my-env
|
||||
source .my-env/bin/activate
|
||||
```
|
||||
|
||||
The second line of code downloads and caches the pretrained model used by the pipeline, while the third evaluates it on the given text. Here, the answer is "positive" with a confidence of 99.97%.
|
||||
Install Transformers in your virtual environment.
|
||||
|
||||
Many tasks have a pre-trained `pipeline` ready to go, in NLP but also in computer vision and speech. For example, we can easily extract detected objects in an image:
|
||||
```py
|
||||
# pip
|
||||
pip install transformers
|
||||
|
||||
``` python
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# Download an image with cute cats
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
|
||||
>>> image_data = requests.get(url, stream=True).raw
|
||||
>>> image = Image.open(image_data)
|
||||
|
||||
# Allocate a pipeline for object detection
|
||||
>>> object_detector = pipeline('object-detection')
|
||||
>>> object_detector(image)
|
||||
[{'score': 0.9982201457023621,
|
||||
'label': 'remote',
|
||||
'box': {'xmin': 40, 'ymin': 70, 'xmax': 175, 'ymax': 117}},
|
||||
{'score': 0.9960021376609802,
|
||||
'label': 'remote',
|
||||
'box': {'xmin': 333, 'ymin': 72, 'xmax': 368, 'ymax': 187}},
|
||||
{'score': 0.9954745173454285,
|
||||
'label': 'couch',
|
||||
'box': {'xmin': 0, 'ymin': 1, 'xmax': 639, 'ymax': 473}},
|
||||
{'score': 0.9988006353378296,
|
||||
'label': 'cat',
|
||||
'box': {'xmin': 13, 'ymin': 52, 'xmax': 314, 'ymax': 470}},
|
||||
{'score': 0.9986783862113953,
|
||||
'label': 'cat',
|
||||
'box': {'xmin': 345, 'ymin': 23, 'xmax': 640, 'ymax': 368}}]
|
||||
# uv
|
||||
uv pip install transformers
|
||||
```
|
||||
|
||||
Here, we get a list of objects detected in the image, with a box surrounding the object and a confidence score. Here is the original image on the left, with the predictions displayed on the right:
|
||||
Install Transformers from source if you want the latest changes in the library or are interested in contributing. However, the *latest* version may not be stable. Feel free to open an [issue](https://github.com/huggingface/transformers/issues) if you encounter an error.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/huggingface/transformers.git
|
||||
cd transformers
|
||||
pip install .
|
||||
```
|
||||
|
||||
## Quickstart
|
||||
|
||||
Get started with Transformers right away with the [Pipeline](https://huggingface.co/docs/transformers/pipeline_tutorial) API. The `Pipeline` is a high-level inference class that supports text, audio, vision, and multimodal tasks. It handles preprocessing the input and returns the appropriate output.
|
||||
|
||||
Instantiate a pipeline and specify model to use for text generation. The model is downloaded and cached so you can easily reuse it again. Finally, pass some text to prompt the model.
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="Qwen/Qwen2.5-1.5B")
|
||||
pipeline("the secret to baking a really good cake is ")
|
||||
[{'generated_text': 'the secret to baking a really good cake is 1) to use the right ingredients and 2) to follow the recipe exactly. the recipe for the cake is as follows: 1 cup of sugar, 1 cup of flour, 1 cup of milk, 1 cup of butter, 1 cup of eggs, 1 cup of chocolate chips. if you want to make 2 cakes, how much sugar do you need? To make 2 cakes, you will need 2 cups of sugar.'}]
|
||||
```
|
||||
|
||||
To chat with a model, the usage pattern is the same. The only difference is you need to construct a chat history (the input to `Pipeline`) between you and the system.
|
||||
|
||||
> [!TIP]
|
||||
> You can also chat with a model directly from the command line.
|
||||
> ```shell
|
||||
> transformers-cli chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
|
||||
> ```
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."},
|
||||
{"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"}
|
||||
]
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
|
||||
response = pipeline(chat, max_new_tokens=512)
|
||||
print(response[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
Expand the examples below to see how `Pipeline` works for different modalities and tasks.
|
||||
|
||||
<details>
|
||||
<summary>Automatic speech recognition</summary>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="automatic-speech-recognition", model="openai/whisper-large-v3")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
|
||||
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Image classification</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png" width="400"></a>
|
||||
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample_post_processed.png" width="400"></a>
|
||||
<a><img src="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"></a>
|
||||
</h3>
|
||||
|
||||
You can learn more about the tasks supported by the `pipeline` API in [this tutorial](https://huggingface.co/docs/transformers/task_summary).
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
In addition to `pipeline`, to download and use any of the pretrained models on your given task, all it takes is three lines of code. Here is the PyTorch version:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
pipeline = pipeline(task="image-classification", model="facebook/dinov2-small-imagenet1k-1-layer")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
|
||||
[{'label': 'macaw', 'score': 0.997848391532898},
|
||||
{'label': 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
|
||||
'score': 0.0016551691805943847},
|
||||
{'label': 'lorikeet', 'score': 0.00018523589824326336},
|
||||
{'label': 'African grey, African gray, Psittacus erithacus',
|
||||
'score': 7.85409429227002e-05},
|
||||
{'label': 'quail', 'score': 5.502637941390276e-05}]
|
||||
```
|
||||
|
||||
And here is the equivalent code for TensorFlow:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, TFAutoModel
|
||||
</details>
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
<details>
|
||||
<summary>Visual question answering</summary>
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="tf")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="visual-question-answering", model="Salesforce/blip-vqa-base")
|
||||
pipeline(
|
||||
image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg",
|
||||
question="What is in the image?",
|
||||
)
|
||||
[{'answer': 'statue of liberty'}]
|
||||
```
|
||||
|
||||
The tokenizer is responsible for all the preprocessing the pretrained model expects and can be called directly on a single string (as in the above examples) or a list. It will output a dictionary that you can use in downstream code or simply directly pass to your model using the ** argument unpacking operator.
|
||||
</details>
|
||||
|
||||
The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) or a [TensorFlow `tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model) (depending on your backend) which you can use as usual. [This tutorial](https://huggingface.co/docs/transformers/training) explains how to integrate such a model into a classic PyTorch or TensorFlow training loop, or how to use our `Trainer` API to quickly fine-tune on a new dataset.
|
||||
|
||||
## Why should I use transformers?
|
||||
## Why should I use Transformers?
|
||||
|
||||
1. Easy-to-use state-of-the-art models:
|
||||
- High performance on natural language understanding & generation, computer vision, and audio tasks.
|
||||
- Low barrier to entry for educators and practitioners.
|
||||
- High performance on natural language understanding & generation, computer vision, audio, video, and multimodal tasks.
|
||||
- Low barrier to entry for researchers, engineers, and developers.
|
||||
- Few user-facing abstractions with just three classes to learn.
|
||||
- A unified API for using all our pretrained models.
|
||||
|
||||
1. Lower compute costs, smaller carbon footprint:
|
||||
- Researchers can share trained models instead of always retraining.
|
||||
- Practitioners can reduce compute time and production costs.
|
||||
- Dozens of architectures with over 400,000 pretrained models across all modalities.
|
||||
- Share trained models instead of training from scratch.
|
||||
- Reduce compute time and production costs.
|
||||
- Dozens of model architectures with 1M+ pretrained checkpoints across all modalities.
|
||||
|
||||
1. Choose the right framework for every part of a model's lifetime:
|
||||
1. Choose the right framework for every part of a models lifetime:
|
||||
- Train state-of-the-art models in 3 lines of code.
|
||||
- Move a single model between TF2.0/PyTorch/JAX frameworks at will.
|
||||
- Seamlessly pick the right framework for training, evaluation, and production.
|
||||
- Move a single model between PyTorch/JAX/TF2.0 frameworks at will.
|
||||
- Pick the right framework for training, evaluation, and production.
|
||||
|
||||
1. Easily customize a model or an example to your needs:
|
||||
- We provide examples for each architecture to reproduce the results published by its original authors.
|
||||
- Model internals are exposed as consistently as possible.
|
||||
- Model files can be used independently of the library for quick experiments.
|
||||
|
||||
## Why shouldn't I use transformers?
|
||||
<a target="_blank" href="https://huggingface.co/enterprise">
|
||||
<img alt="Hugging Face Enterprise Hub" src="https://github.com/user-attachments/assets/247fb16d-d251-4583-96c4-d3d76dda4925">
|
||||
</a><br>
|
||||
|
||||
## Why shouldn't I use Transformers?
|
||||
|
||||
- This library is not a modular toolbox of building blocks for neural nets. The code in the model files is not refactored with additional abstractions on purpose, so that researchers can quickly iterate on each of the models without diving into additional abstractions/files.
|
||||
- The training API is not intended to work on any model but is optimized to work with the models provided by the library. For generic machine learning loops, you should use another library (possibly, [Accelerate](https://huggingface.co/docs/accelerate)).
|
||||
- While we strive to present as many use cases as possible, the scripts in our [examples folder](https://github.com/huggingface/transformers/tree/main/examples) are just that: examples. It is expected that they won't work out-of-the-box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs.
|
||||
- The training API is optimized to work with PyTorch models provided by Transformers. For generic machine learning loops, you should use another library like [Accelerate](https://huggingface.co/docs/accelerate).
|
||||
- The [example scripts]((https://github.com/huggingface/transformers/tree/main/examples)) are only *examples*. They may not necessarily work out-of-the-box on your specific use case and you'll need to adapt the code for it to work.
|
||||
|
||||
## Installation
|
||||
## 100 projects using Transformers
|
||||
|
||||
### With pip
|
||||
Transformers is more than a toolkit to use pretrained models, it's a community of projects built around it and the
|
||||
Hugging Face Hub. We want Transformers to enable developers, researchers, students, professors, engineers, and anyone
|
||||
else to build their dream projects.
|
||||
|
||||
This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, and TensorFlow 2.6+.
|
||||
In order to celebrate Transformers 100,000 stars, we wanted to put the spotlight on the
|
||||
community with the [awesome-transformers](./awesome-transformers.md) page which lists 100
|
||||
incredible projects built with Transformers.
|
||||
|
||||
You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
If you own or use a project that you believe should be part of the list, please open a PR to add it!
|
||||
|
||||
First, create a virtual environment with the version of Python you're going to use and activate it.
|
||||
## Example models
|
||||
|
||||
**macOS/Linux**
|
||||
You can test most of our models directly on their [Hub model pages](https://huggingface.co/models).
|
||||
|
||||
```python -m venv env
|
||||
source env/bin/activate
|
||||
```
|
||||
Expand each modality below to see a few example models for various use cases.
|
||||
|
||||
**Windows**
|
||||
<details>
|
||||
<summary>Audio</summary>
|
||||
|
||||
``` python -m venv env
|
||||
env\Scripts\activate
|
||||
```
|
||||
- Audio classification with [Whisper](https://huggingface.co/openai/whisper-large-v3-turbo)
|
||||
- Automatic speech recognition with [Moonshine](https://huggingface.co/UsefulSensors/moonshine)
|
||||
- Keyword spotting with [Wav2Vec2](https://huggingface.co/superb/wav2vec2-base-superb-ks)
|
||||
- Speech to speech generation with [Moshi](https://huggingface.co/kyutai/moshiko-pytorch-bf16)
|
||||
- Text to audio with [MusicGen](https://huggingface.co/facebook/musicgen-large)
|
||||
- Text to speech with [Bark](https://huggingface.co/suno/bark)
|
||||
|
||||
To use 🤗 Transformers, you must install at least one of Flax, PyTorch, or TensorFlow. Refer to the official installation guides for platform-specific commands:
|
||||
</details>
|
||||
|
||||
[TensorFlow installation page](https://www.tensorflow.org/install/),
|
||||
[PyTorch installation page](https://pytorch.org/get-started/locally/#start-locally) and/or [Flax](https://github.com/google/flax#quick-install) and [Jax](https://github.com/google/jax#installation)
|
||||
<details>
|
||||
<summary>Computer vision</summary>
|
||||
|
||||
When one of those backends has been installed, 🤗 Transformers can be installed using pip as follows:
|
||||
- Automatic mask generation with [SAM](https://huggingface.co/facebook/sam-vit-base)
|
||||
- Depth estimation with [DepthPro](https://huggingface.co/apple/DepthPro-hf)
|
||||
- Image classification with [DINO v2](https://huggingface.co/facebook/dinov2-base)
|
||||
- Keypoint detection with [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor)
|
||||
- Keypoint matching with [SuperGlue](https://huggingface.co/magic-leap-community/superglue)
|
||||
- Object detection with [RT-DETRv2](https://huggingface.co/PekingU/rtdetr_v2_r50vd)
|
||||
- Pose Estimation with [VitPose](https://huggingface.co/usyd-community/vitpose-base-simple)
|
||||
- Universal segmentation with [OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_swin_large)
|
||||
- Video classification with [VideoMAE](https://huggingface.co/MCG-NJU/videomae-large)
|
||||
|
||||
```
|
||||
pip install transformers
|
||||
```
|
||||
</details>
|
||||
|
||||
If you'd like to play with the examples or need the bleeding edge of the code and can't wait for a new release, you must [install the library from source](https://huggingface.co/docs/transformers/installation#installing-from-source).
|
||||
<details>
|
||||
<summary>Multimodal</summary>
|
||||
|
||||
```
|
||||
git clone https://github.com/huggingface/transformers.git
|
||||
cd transformers
|
||||
pip install .
|
||||
```
|
||||
- Audio or text to text with [Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B)
|
||||
- Document question answering with [LayoutLMv3](https://huggingface.co/microsoft/layoutlmv3-base)
|
||||
- Image or text to text with [Qwen-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
- Image captioning [BLIP-2](https://huggingface.co/Salesforce/blip2-opt-2.7b)
|
||||
- OCR-based document understanding with [GOT-OCR2](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)
|
||||
- Table question answering with [TAPAS](https://huggingface.co/google/tapas-base)
|
||||
- Unified multimodal understanding and generation with [Emu3](https://huggingface.co/BAAI/Emu3-Gen)
|
||||
- Vision to text with [Llava-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf)
|
||||
- Visual question answering with [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
|
||||
- Visual referring expression segmentation with [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224)
|
||||
|
||||
### With conda
|
||||
</details>
|
||||
|
||||
🤗 Transformers can be installed using conda as follows:
|
||||
<details>
|
||||
<summary>NLP</summary>
|
||||
|
||||
```shell script
|
||||
conda install conda-forge::transformers
|
||||
```
|
||||
- Masked word completion with [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base)
|
||||
- Named entity recognition with [Gemma](https://huggingface.co/google/gemma-2-2b)
|
||||
- Question answering with [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
|
||||
- Summarization with [BART](https://huggingface.co/facebook/bart-large-cnn)
|
||||
- Translation with [T5](https://huggingface.co/google-t5/t5-base)
|
||||
- Text generation with [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B)
|
||||
- Text classification with [Qwen](https://huggingface.co/Qwen/Qwen2.5-0.5B)
|
||||
|
||||
> **_NOTE:_** Installing `transformers` from the `huggingface` channel is deprecated.
|
||||
|
||||
Follow the installation pages of Flax, PyTorch or TensorFlow to see how to install them with conda.
|
||||
|
||||
> **_NOTE:_** On Windows, you may be prompted to activate Developer Mode in order to benefit from caching. If this is not an option for you, please let us know in [this issue](https://github.com/huggingface/huggingface_hub/issues/1062).
|
||||
|
||||
## Model architectures
|
||||
|
||||
**[All the model checkpoints](https://huggingface.co/models)** provided by 🤗 Transformers are seamlessly integrated from the huggingface.co [model hub](https://huggingface.co/models), where they are uploaded directly by [users](https://huggingface.co/users) and [organizations](https://huggingface.co/organizations).
|
||||
|
||||
Current number of checkpoints: 
|
||||
|
||||
🤗 Transformers currently provides the following architectures: see [here](https://huggingface.co/docs/transformers/model_summary) for a high-level summary of each them.
|
||||
|
||||
To check if each model has an implementation in Flax, PyTorch or TensorFlow, or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks).
|
||||
|
||||
These implementations have been tested on several datasets (see the example scripts) and should match the performance of the original implementations. You can find more details on performance in the Examples section of the [documentation](https://github.com/huggingface/transformers/tree/main/examples).
|
||||
|
||||
|
||||
## Learn more
|
||||
|
||||
| Section | Description |
|
||||
|-|-|
|
||||
| [Documentation](https://huggingface.co/docs/transformers/) | Full API documentation and tutorials |
|
||||
| [Task summary](https://huggingface.co/docs/transformers/task_summary) | Tasks supported by 🤗 Transformers |
|
||||
| [Preprocessing tutorial](https://huggingface.co/docs/transformers/preprocessing) | Using the `Tokenizer` class to prepare data for the models |
|
||||
| [Training and fine-tuning](https://huggingface.co/docs/transformers/training) | Using the models provided by 🤗 Transformers in a PyTorch/TensorFlow training loop and the `Trainer` API |
|
||||
| [Quick tour: Fine-tuning/usage scripts](https://github.com/huggingface/transformers/tree/main/examples) | Example scripts for fine-tuning models on a wide range of tasks |
|
||||
| [Model sharing and uploading](https://huggingface.co/docs/transformers/model_sharing) | Upload and share your fine-tuned models with the community |
|
||||
</details>
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
## Writing metrics to the database
|
||||
|
||||
`MetricRecorder` is thread-safe, in the sense of the python [`Thread`](https://docs.python.org/3/library/threading.html#threading.Thread). This means you can start a background thread to do the readings on the device measurements while not blocking the main thread to execute the model measurements.
|
||||
`MetricsRecorder` is thread-safe, in the sense of the python [`Thread`](https://docs.python.org/3/library/threading.html#threading.Thread). This means you can start a background thread to do the readings on the device measurements while not blocking the main thread to execute the model measurements.
|
||||
|
||||
cf [`llama.py`](./llama.py) to see an example of this in practice.
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ import importlib.util
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict
|
||||
import psycopg2
|
||||
import sys
|
||||
|
||||
from psycopg2.extras import Json
|
||||
|
||||
@ -118,7 +118,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
with torch.no_grad():
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate,
|
||||
@ -144,7 +144,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate,
|
||||
@ -187,7 +187,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
# TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate + 10,
|
||||
@ -204,7 +204,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
time_to_first_token = end - start
|
||||
logger.info(f"completed first compile generation in: {time_to_first_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
all_generated_tokens += next_token.tolist()
|
||||
|
||||
cache_position = torch.tensor([seq_length], device=device)
|
||||
### First compile, decoding
|
||||
@ -215,9 +215,9 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
time_to_second_token = end - start
|
||||
logger.info(f"completed second compile generation in: {time_to_first_token}s")
|
||||
logger.info(f"completed second compile generation in: {time_to_second_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
all_generated_tokens += next_token.tolist()
|
||||
|
||||
### Second compile, decoding
|
||||
start = perf_counter()
|
||||
@ -227,15 +227,15 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
time_to_third_token = end - start
|
||||
logger.info(f"completed third compile forward in: {time_to_first_token}s")
|
||||
logger.info(f"completed third compile forward in: {time_to_third_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
all_generated_tokens += next_token.tolist()
|
||||
|
||||
### Using cuda graphs decoding
|
||||
|
||||
start = perf_counter()
|
||||
for _ in range(1, num_tokens_to_generate):
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
all_generated_tokens += next_token.tolist()
|
||||
next_token = decode_one_token(
|
||||
model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
|
||||
)
|
||||
@ -254,7 +254,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
@ -271,7 +271,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
@ -287,7 +287,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
@ -298,12 +298,12 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
output = model.generate(**inputs, past_key_values=past_key_values)
|
||||
end = perf_counter()
|
||||
third_compile_generate_time = end - start
|
||||
logger.info(f"completed second compile generation in: {third_compile_generate_time}s")
|
||||
logger.info(f"completed third compile generation in: {third_compile_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
@ -313,7 +313,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
output = model.generate(**inputs, past_key_values=past_key_values)
|
||||
end = perf_counter()
|
||||
fourth_compile_generate_time = end - start
|
||||
logger.info(f"completed second compile generation in: {fourth_compile_generate_time}s")
|
||||
logger.info(f"completed fourth compile generation in: {fourth_compile_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
metrics_recorder.collect_model_measurements(
|
||||
|
||||
@ -46,10 +46,6 @@ NOT_DEVICE_TESTS = {
|
||||
"test_keep_in_fp32_modules",
|
||||
"test_gradient_checkpointing_backward_compatibility",
|
||||
"test_gradient_checkpointing_enable_disable",
|
||||
"test_save_load_fast_init_from_base",
|
||||
"test_fast_init_context_manager",
|
||||
"test_fast_init_tied_embeddings",
|
||||
"test_save_load_fast_init_to_base",
|
||||
"test_torch_save_load",
|
||||
"test_initialization",
|
||||
"test_forward_signature",
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
|
||||
In this folder you will find various docker files, and some subfolders.
|
||||
- dockerfiles (ex: `consistency.dockerfile`) present under `~/docker` are used for our "fast" CIs. You should be able to use them for tasks that only need CPU. For example `torch-light` is a very light weights container (703MiB).
|
||||
- subfloder contain dockerfiles used for our `slow` CIs, which *can* be used for GPU tasks, but they are **BIG** as they were not specifically designed for a single model / single task. Thus the `~/docker/transformers-pytorch-gpu` includes additional dependencies to allow us to run ALL model tests (say `librosa` or `tesseract`, which you do not need to run LLMs)
|
||||
- subfolders contain dockerfiles used for our `slow` CIs, which *can* be used for GPU tasks, but they are **BIG** as they were not specifically designed for a single model / single task. Thus the `~/docker/transformers-pytorch-gpu` includes additional dependencies to allow us to run ALL model tests (say `librosa` or `tesseract`, which you do not need to run LLMs)
|
||||
|
||||
Note that in both case, you need to run `uv pip install -e .`, which should take around 5 seconds. We do it outside the dockerfile for the need of our CI: we checkout a new branch each time, and the `transformers` code is thus updated.
|
||||
|
||||
We are open to contribution, and invite the community to create dockerfiles with potential arguments that properly choose extras depending on the model's dependencies! :hugs:
|
||||
We are open to contribution, and invite the community to create dockerfiles with potential arguments that properly choose extras depending on the model's dependencies! :hugs:
|
||||
|
||||
@ -5,12 +5,12 @@ ARG REF=main
|
||||
RUN apt-get update && apt-get install -y time git g++ pkg-config make git-lfs
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools GitPython
|
||||
RUN pip install --no-cache-dir --upgrade 'torch' 'torchaudio' 'torchvision' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir --upgrade 'torch' 'torchaudio' 'torchvision' --index-url https://download.pytorch.org/whl/cpu
|
||||
# tensorflow pin matching setup.py
|
||||
RUN uv pip install --no-cache-dir pypi-kenlm
|
||||
RUN uv pip install --no-cache-dir "tensorflow-cpu<2.16" "tf-keras<2.16"
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[flax,quality,testing,torch-speech,vision]"
|
||||
RUN git lfs install
|
||||
|
||||
RUN pip uninstall -y transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
RUN uv pip uninstall transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
@ -21,7 +21,7 @@ RUN uv pip install --no-cache-dir --no-deps accelerate --extra-index-url https:
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[ja,testing,sentencepiece,jieba,spacy,ftfy,rjieba]" unidic unidic-lite
|
||||
# spacy is not used so not tested. Causes to failures. TODO fix later
|
||||
RUN python3 -m unidic download
|
||||
RUN pip uninstall -y transformers
|
||||
RUN uv pip uninstall transformers
|
||||
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
RUN apt remove -y g++ cmake xz-utils libprotobuf-dev protobuf-compiler
|
||||
|
||||
@ -7,7 +7,7 @@ RUN apt-get install -y g++ cmake
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv
|
||||
RUN uv pip install --no-cache-dir -U pip setuptools albumentations seqeval
|
||||
RUN pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,testing,sentencepiece,tf-speech,vision]"
|
||||
RUN uv pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,testing,sentencepiece,tf-speech,vision]"
|
||||
RUN uv pip install --no-cache-dir "protobuf==3.20.3"
|
||||
RUN pip uninstall -y transformers
|
||||
RUN uv pip uninstall transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@ -5,8 +5,8 @@ USER root
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing]" seqeval albumentations jiwer
|
||||
RUN pip uninstall -y transformers
|
||||
RUN uv pip uninstall transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@ -5,13 +5,13 @@ USER root
|
||||
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git libgl1-mesa-glx libgl1 g++ tesseract-ocr
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir --no-deps timm accelerate
|
||||
RUN pip install -U --upgrade-strategy eager --no-cache-dir pytesseract python-Levenshtein opencv-python nltk
|
||||
# RUN uv pip install --no-cache-dir natten==0.15.1+torch210cpu -f https://shi-labs.com/natten/wheels
|
||||
RUN pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[testing, vision]" 'scikit-learn' 'torch-stft' 'nose' 'dataset'
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[testing, vision]" 'scikit-learn' 'torch-stft' 'nose' 'dataset'
|
||||
# RUN git clone https://github.com/facebookresearch/detectron2.git
|
||||
# RUN python3 -m pip install --no-cache-dir -e detectron2
|
||||
RUN pip install 'git+https://github.com/facebookresearch/detectron2.git@92ae9f0b92aba5867824b4f12aa06a22a60a45d3'
|
||||
RUN pip uninstall -y transformers
|
||||
RUN uv pip install 'git+https://github.com/facebookresearch/detectron2.git@92ae9f0b92aba5867824b4f12aa06a22a60a45d3' --no-build-isolation
|
||||
RUN uv pip uninstall transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@ -5,6 +5,6 @@ USER root
|
||||
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git g++ cmake
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN pip install --no-cache-dir "scipy<1.13" "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[flax,testing,sentencepiece,flax-speech,vision]"
|
||||
RUN pip uninstall -y transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
RUN uv pip install --no-cache-dir "scipy<1.13" "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[flax,testing,sentencepiece,flax-speech,vision]"
|
||||
RUN uv pip uninstall transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
@ -5,6 +5,6 @@ USER root
|
||||
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git cmake g++
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision]"
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision]"
|
||||
RUN uv pip install --no-cache-dir "protobuf==3.20.3" tensorflow_probability
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@ -5,7 +5,7 @@ USER root
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git pkg-config openssh-client git
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing]"
|
||||
RUN pip uninstall -y transformers
|
||||
RUN uv pip uninstall transformers
|
||||
|
||||
@ -6,4 +6,4 @@ RUN apt-get update && apt-get install -y time git
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip install uv && uv venv
|
||||
RUN uv pip install --no-cache-dir -U pip setuptools GitPython "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[ruff]" urllib3
|
||||
RUN apt-get install -y jq curl && apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
RUN apt-get install -y jq curl && apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@ -6,7 +6,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-de
|
||||
RUN apt-get install -y cmake
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,testing,sentencepiece,tf-speech,vision]"
|
||||
RUN uv pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,testing,sentencepiece,tf-speech,vision]"
|
||||
RUN uv pip install --no-cache-dir "protobuf==3.20.3"
|
||||
RUN pip uninstall -y transformers
|
||||
RUN uv pip uninstall transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
@ -6,11 +6,11 @@ RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git g++
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN uv pip install --no-deps accelerate
|
||||
RUN pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN pip install --no-cache-dir "scipy<1.13" "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[flax,audio,sklearn,sentencepiece,vision,testing]"
|
||||
RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir "scipy<1.13" "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[flax,audio,sklearn,sentencepiece,vision,testing]"
|
||||
|
||||
|
||||
# RUN pip install --no-cache-dir "scipy<1.13" "transformers[flax,testing,sentencepiece,flax-speech,vision]"
|
||||
|
||||
RUN pip uninstall -y transformers
|
||||
RUN uv pip uninstall transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
@ -5,7 +5,7 @@ USER root
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git git-lfs
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken,num2words]"
|
||||
RUN pip uninstall -y transformers
|
||||
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken,num2words,video]"
|
||||
RUN uv pip uninstall transformers
|
||||
|
||||
@ -7,13 +7,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-de
|
||||
ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN uv pip install --no-cache-dir --no-deps accelerate --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
RUN pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN git lfs install
|
||||
|
||||
RUN uv pip install --no-cache-dir pypi-kenlm
|
||||
RUN pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,sentencepiece,vision,testing]"
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,sentencepiece,vision,testing]"
|
||||
RUN uv pip install --no-cache-dir "protobuf==3.20.3" librosa
|
||||
|
||||
|
||||
RUN pip uninstall -y transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
RUN uv pip uninstall transformers
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
@ -57,7 +57,8 @@ RUN python3 -m pip uninstall -y ninja
|
||||
|
||||
# For `dinat` model
|
||||
# The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent)
|
||||
RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels
|
||||
# pin `0.17.4` otherwise `cannot import name 'natten2dav' from 'natten.functional'`
|
||||
RUN python3 -m pip install --no-cache-dir natten==0.17.4+torch250cu121 -f https://shi-labs.com/natten/wheels
|
||||
|
||||
# For `nougat` tokenizer
|
||||
RUN python3 -m pip install --no-cache-dir python-Levenshtein
|
||||
|
||||
@ -79,6 +79,9 @@ RUN git clone https://github.com/NetEase-FuXi/EETQ.git && cd EETQ/ && git submod
|
||||
# Add compressed-tensors for quantization testing
|
||||
RUN python3 -m pip install --no-cache-dir compressed-tensors
|
||||
|
||||
# Add AMD Quark for quantization testing
|
||||
RUN python3 -m pip install --no-cache-dir amd-quark
|
||||
|
||||
# Add transformers in editable mode
|
||||
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch]
|
||||
|
||||
|
||||
@ -156,7 +156,7 @@ Die [`pipeline`] kann jedes Modell aus dem [Model Hub](https://huggingface.co/mo
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Use the [`AutoModelForSequenceClassification`] and [`AutoTokenizer`] to load the pretrained model and it's associated tokenizer (more on an `AutoClass` below):
|
||||
Use the [`AutoModelForSequenceClassification`] and [`AutoTokenizer`] to load the pretrained model and its associated tokenizer (more on an `AutoClass` below):
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
@ -166,7 +166,7 @@ Use the [`AutoModelForSequenceClassification`] and [`AutoTokenizer`] to load the
|
||||
```
|
||||
</pt>
|
||||
<tf>
|
||||
Use the [`TFAutoModelForSequenceClassification`] and [`AutoTokenizer`] to load the pretrained model and it's associated tokenizer (more on an `TFAutoClass` below):
|
||||
Use the [`TFAutoModelForSequenceClassification`] and [`AutoTokenizer`] to load the pretrained model and its associated tokenizer (more on an `TFAutoClass` below):
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
|
||||
@ -222,7 +222,7 @@ Anschließend wandelt der Tokenizer die Token in Zahlen um, um einen Tensor als
|
||||
Der Tokenizer gibt ein Wörterbuch zurück, das Folgendes enthält:
|
||||
|
||||
* [input_ids](./glossary#input-ids): numerische Repräsentationen Ihrer Token.
|
||||
* [atttention_mask](.glossary#attention-mask): gibt an, welche Token beachtet werden sollen.
|
||||
* [attention_mask](.glossary#attention-mask): gibt an, welche Token beachtet werden sollen.
|
||||
|
||||
Genau wie die [`pipeline`] akzeptiert der Tokenizer eine Liste von Eingaben. Darüber hinaus kann der Tokenizer den Text auch auffüllen und kürzen, um einen Stapel mit einheitlicher Länge zurückzugeben:
|
||||
|
||||
|
||||
@ -29,6 +29,8 @@
|
||||
title: The Transformer model family
|
||||
- local: attention
|
||||
title: Attention mechanisms
|
||||
- local: attention_interface
|
||||
title: Customizing attention function
|
||||
title: Models
|
||||
- sections:
|
||||
- local: fast_tokenizers
|
||||
@ -187,6 +189,8 @@
|
||||
title: Optimum
|
||||
- local: quantization/quanto
|
||||
title: Quanto
|
||||
- local: quantization/quark
|
||||
title: Quark
|
||||
- local: quantization/torchao
|
||||
title: torchao
|
||||
- local: quantization/spqr
|
||||
@ -411,6 +415,8 @@
|
||||
title: DeBERTa
|
||||
- local: model_doc/deberta-v2
|
||||
title: DeBERTa-v2
|
||||
- local: model_doc/deepseek_v3
|
||||
title: DeepSeek-V3
|
||||
- local: model_doc/dialogpt
|
||||
title: DialoGPT
|
||||
- local: model_doc/diffllama
|
||||
@ -501,6 +507,8 @@
|
||||
title: Llama2
|
||||
- local: model_doc/llama3
|
||||
title: Llama3
|
||||
- local: model_doc/llama4
|
||||
title: Llama4
|
||||
- local: model_doc/longformer
|
||||
title: Longformer
|
||||
- local: model_doc/longt5
|
||||
@ -529,6 +537,8 @@
|
||||
title: MegatronGPT2
|
||||
- local: model_doc/mistral
|
||||
title: Mistral
|
||||
- local: model_doc/mistral3
|
||||
title: Mistral3
|
||||
- local: model_doc/mixtral
|
||||
title: Mixtral
|
||||
- local: model_doc/mluke
|
||||
@ -579,6 +589,8 @@
|
||||
title: Phi
|
||||
- local: model_doc/phi3
|
||||
title: Phi-3
|
||||
- local: model_doc/phi4_multimodal
|
||||
title: Phi4 Multimodal
|
||||
- local: model_doc/phimoe
|
||||
title: PhiMoE
|
||||
- local: model_doc/phobert
|
||||
@ -593,6 +605,10 @@
|
||||
title: Qwen2
|
||||
- local: model_doc/qwen2_moe
|
||||
title: Qwen2MoE
|
||||
- local: model_doc/qwen3
|
||||
title: Qwen3
|
||||
- local: model_doc/qwen3_moe
|
||||
title: Qwen3MoE
|
||||
- local: model_doc/rag
|
||||
title: RAG
|
||||
- local: model_doc/realm
|
||||
@ -731,6 +747,8 @@
|
||||
title: NAT
|
||||
- local: model_doc/poolformer
|
||||
title: PoolFormer
|
||||
- local: model_doc/prompt_depth_anything
|
||||
title: Prompt Depth Anything
|
||||
- local: model_doc/pvt
|
||||
title: Pyramid Vision Transformer (PVT)
|
||||
- local: model_doc/pvt_v2
|
||||
@ -979,6 +997,8 @@
|
||||
title: Qwen2VL
|
||||
- local: model_doc/sam
|
||||
title: Segment Anything
|
||||
- local: model_doc/shieldgemma2
|
||||
title: ShieldGemma2
|
||||
- local: model_doc/siglip
|
||||
title: SigLIP
|
||||
- local: model_doc/siglip2
|
||||
@ -1038,6 +1058,8 @@
|
||||
- sections:
|
||||
- local: internal/modeling_utils
|
||||
title: Custom Layers and Utilities
|
||||
- local: internal/model_debugging_utils
|
||||
title: Utilities for Model Debugging
|
||||
- local: internal/pipelines_utils
|
||||
title: Utilities for pipelines
|
||||
- local: internal/tokenization_utils
|
||||
|
||||
128
docs/source/en/attention_interface.md
Normal file
128
docs/source/en/attention_interface.md
Normal file
@ -0,0 +1,128 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Attention Interface
|
||||
|
||||
This page describes how to use the `AttentionInterface` in order to register custom attention functions to use with
|
||||
supported models.
|
||||
|
||||
## Customizing attention function
|
||||
|
||||
Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping.
|
||||
By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
|
||||
[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention)
|
||||
as well as `eager`, which is a simple matrix multiplication without any optimization on top.
|
||||
This is the setting you can usually choose when instantiating a model:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B"
|
||||
|
||||
# Here, using flash attention as an example
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
|
||||
```
|
||||
|
||||
But what if you wanted to create your own attention function? Or simply play around with existing ones, adding
|
||||
a few statements here and there? You can now do so with the `AttentionInterface`! Here is an example:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AttentionInterface
|
||||
from transformers.integrations.sdpa_attention import sdpa_attention_forward
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B"
|
||||
|
||||
def my_new_sdpa(*args, **kwargs):
|
||||
print("I just entered the attention computation")
|
||||
return sdpa_attention_forward(*args, **kwargs)
|
||||
|
||||
AttentionInterface.register("my_new_sdpa", my_new_sdpa)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
|
||||
# Try running the forward with the new attention function
|
||||
model(torch.ones(1, 5, dtype=int))
|
||||
```
|
||||
|
||||
You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times).
|
||||
|
||||
## Dynamically switching attention function
|
||||
|
||||
You could dynamically change the model's attention function as well, by overriding the `config._attn_implementation` field:
|
||||
|
||||
```python
|
||||
# Back to use original sdpa implementation
|
||||
model.config._attn_implementation = "sdpa"
|
||||
|
||||
model(torch.ones(1, 5, dtype=int))
|
||||
```
|
||||
|
||||
and it will stop printing the statements, as it now uses the `sdpa` attention.
|
||||
This allows to quickly change an attention function, without needing to reload the model!
|
||||
|
||||
## What about new args needed in my custom attention function?
|
||||
|
||||
But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
|
||||
`AttentionInterface` propagate kwargs all the way to the Attention layers, and to the used attention function. That way,
|
||||
you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AttentionInterface
|
||||
from transformers.integrations.sdpa_attention import sdpa_attention_forward
|
||||
import torch
|
||||
|
||||
def custom_attention(
|
||||
module: torch.nn.Module, # required arg
|
||||
query: torch.Tensor, # required arg
|
||||
key: torch.Tensor, # required arg
|
||||
value: torch.Tensor, # required arg
|
||||
attention_mask: Optional[torch.Tensor], # required arg
|
||||
a_new_kwargs = None, # You can now add as many kwargs as you need
|
||||
another_new_kwargs = None, # You can now add as many kwargs as you need
|
||||
**kwargs, # You need to accept **kwargs as models will pass other args
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]
|
||||
... # do your magic!
|
||||
return attn_output, attn_weights # attn_weights are optional here
|
||||
|
||||
AttentionInterface.register("custom", custom_attention)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
|
||||
# Forward pass with the new kwargs
|
||||
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
|
||||
```
|
||||
|
||||
If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
|
||||
|
||||
## Accessing current available implementations
|
||||
|
||||
Most of the time, you will simply need to `register` a new function. If, however, you need to access an existing one,
|
||||
and/or perform a few checks, the prefered way is to use the global `ALL_ATTENTION_FUNCTIONS`. It behaves the same way you
|
||||
would expect from a usual Python dictionary:
|
||||
|
||||
```python
|
||||
>>> from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
>>> list(ALL_ATTENTION_FUNCTIONS.keys())
|
||||
>>> ['flash_attention_2', 'flex_attention', 'sdpa']
|
||||
|
||||
>>> ALL_ATTENTION_FUNCTIONS["sdpa"]
|
||||
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
|
||||
|
||||
>>> ALL_ATTENTION_FUNCTIONS.get("sdpa", None)
|
||||
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
|
||||
|
||||
# You can also globally `register` a new function directly on it
|
||||
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
|
||||
```
|
||||
@ -9,7 +9,7 @@ Unless required by applicable law or agreed to in writing, software distributed
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
@ -62,7 +62,7 @@ for _ in range(max_new_tokens):
|
||||
# Greedily sample one next token
|
||||
next_token_ids = outputs.logits[:, -1:].argmax(-1)
|
||||
generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
|
||||
# Prepare inputs for the next generation step by leaaving unprocessed tokens, in our case we have only one new token
|
||||
# Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token
|
||||
# and expanding attn mask for the new token, as explained above
|
||||
attention_mask = inputs["attention_mask"]
|
||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
|
||||
@ -88,7 +88,7 @@ model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", to
|
||||
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
|
||||
|
||||
# `return_dict_in_generate=True` is required to return the cache and `return_legacy_cache` forces the returned cache
|
||||
# in the the legacy format
|
||||
# in the legacy format
|
||||
generation_outputs = model.generate(**inputs, return_dict_in_generate=True, return_legacy_cache=True, max_new_tokens=5)
|
||||
|
||||
cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
|
||||
|
||||
@ -9,7 +9,7 @@ Unless required by applicable law or agreed to in writing, software distributed
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
@ -18,7 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
Multimodal model chat templates expect a similar [template](./chat_templating) as text-only models. It needs `messages` that includes a dictionary of the `role` and `content`.
|
||||
|
||||
Multimodal templates are included in the [Processor](./processors) class and requires an additional `type` key for specifying whether the included content is an image, video, or text.
|
||||
Multimodal templates are included in the [Processor](./processors) class and require an additional `type` key for specifying whether the included content is an image, video, or text.
|
||||
|
||||
This guide will show you how to format chat templates for multimodal models as well as some best practices for configuring the template
|
||||
|
||||
@ -109,7 +109,7 @@ These inputs are now ready to be used in [`~GenerationMixin.generate`].
|
||||
|
||||
Some vision models also support video inputs. The message format is very similar to the format for [image inputs](#image-inputs).
|
||||
|
||||
- The content `"type"` should be `"video"` to indicate the the content is a video.
|
||||
- The content `"type"` should be `"video"` to indicate the content is a video.
|
||||
- For videos, it can be a link to the video (`"url"`) or it could be a file path (`"path"`). Videos loaded from a URL can only be decoded with [PyAV](https://pyav.basswood-io.com/docs/stable/) or [Decord](https://github.com/dmlc/decord).
|
||||
|
||||
> [!WARNING]
|
||||
@ -141,7 +141,7 @@ Pass `messages` to [`~ProcessorMixin.apply_chat_template`] to tokenize the input
|
||||
|
||||
The `video_load_backend` parameter refers to a specific framework to load a video. It supports [PyAV](https://pyav.basswood-io.com/docs/stable/), [Decord](https://github.com/dmlc/decord), [OpenCV](https://github.com/opencv/opencv), and [torchvision](https://pytorch.org/vision/stable/index.html).
|
||||
|
||||
The examples below uses Decord as the backend because it is a bit faster than PyAV.
|
||||
The examples below use Decord as the backend because it is a bit faster than PyAV.
|
||||
|
||||
<hfoptions id="sampling">
|
||||
<hfoption id="fixed number of frames">
|
||||
|
||||
@ -131,7 +131,7 @@ class ResnetModel(PreTrainedModel):
|
||||
</hfoption>
|
||||
<hfoption id="ResnetModelForImageClassification">
|
||||
|
||||
The `forward` method needs to be rewrittten to calculate the loss for each logit if labels are available. Otherwise, the ResNet model class is the same.
|
||||
The `forward` method needs to be rewritten to calculate the loss for each logit if labels are available. Otherwise, the ResNet model class is the same.
|
||||
|
||||
> [!TIP]
|
||||
> Add `config_class` to the model class to enable [AutoClass](#autoclass-support) support.
|
||||
|
||||
@ -9,7 +9,7 @@ Unless required by applicable law or agreed to in writing, software distributed
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
@ -56,7 +56,7 @@ deepspeed --num_gpus 2 trainer-program.py ...
|
||||
|
||||
### Order of GPUs
|
||||
|
||||
To select specific GPUs to use and their order, configure the the `CUDA_VISIBLE_DEVICES` environment variable. It is easiest to set the environment variable in `~/bashrc` or another startup config file. `CUDA_VISIBLE_DEVICES` is used to map which GPUs are used. For example, if there are 4 GPUs (0, 1, 2, 3) and you only want to run GPUs 0 and 2:
|
||||
To select specific GPUs to use and their order, configure the `CUDA_VISIBLE_DEVICES` environment variable. It is easiest to set the environment variable in `~/bashrc` or another startup config file. `CUDA_VISIBLE_DEVICES` is used to map which GPUs are used. For example, if there are 4 GPUs (0, 1, 2, 3) and you only want to run GPUs 0 and 2:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,2 torchrun trainer-program.py ...
|
||||
|
||||
@ -43,4 +43,3 @@ Transformers is designed for developers and machine learning engineers and resea
|
||||
</a>
|
||||
</div>
|
||||
|
||||
Join us on the Hugging Face [Hub](https://huggingface.co/), [Discord](https://discord.com/invite/JfAtkvEtRb), or [forum](https://discuss.huggingface.co/) to collaborate and build models, datasets, and applications together.
|
||||
|
||||
@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Installation
|
||||
|
||||
Transformers works with [PyTorch](https://pytorch.org/get-started/locally/), [TensorFlow 2.0](https://www.tensorflow.org/install/pip), and [Flax](https://flax.readthedocs.io/en/latest/). It has been tested on Python 3.6+, PyTorch 1.1.0+, TensorFlow 2.0+, and Flax.
|
||||
Transformers works with [PyTorch](https://pytorch.org/get-started/locally/), [TensorFlow 2.0](https://www.tensorflow.org/install/pip), and [Flax](https://flax.readthedocs.io/en/latest/). It has been tested on Python 3.9+, PyTorch 2.0+, TensorFlow 2.6+, and Flax 0.4.1+.
|
||||
|
||||
## Virtual environment
|
||||
|
||||
@ -33,7 +33,7 @@ Create and activate a virtual environment in your project directory with [venv](
|
||||
|
||||
```bash
|
||||
python -m venv .env
|
||||
source ./env/bin/activate
|
||||
source .env/bin/activate
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
@ -43,7 +43,7 @@ source ./env/bin/activate
|
||||
|
||||
```bash
|
||||
uv venv .env
|
||||
source ./env/bin/activate
|
||||
source .env/bin/activate
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
71
docs/source/en/internal/model_debugging_utils.md
Normal file
71
docs/source/en/internal/model_debugging_utils.md
Normal file
@ -0,0 +1,71 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Model debugging toolboxes
|
||||
|
||||
This page lists all the debugging and model adding tools used by the library, as well as the utility functions it provides for it.
|
||||
|
||||
Most of those are only useful if you are adding new models in the library.
|
||||
|
||||
|
||||
## Model addition debuggers
|
||||
|
||||
|
||||
### Model addition debugger - context manager for model adders
|
||||
|
||||
This context manager is a power user tool intended for model adders.
|
||||
It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json.
|
||||
To note, this context manager enforces `torch.inference_mode()`.
|
||||
|
||||
### Rationale
|
||||
|
||||
Because when porting models to transformers, even from python to python, model adders often have to do a lot of manual operations, involving saving and loading tensors, comparing dtypes, etc. This small tool can hopefully shave off some time.
|
||||
|
||||
### Usage
|
||||
|
||||
Add this context manager as follows to debug a model:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
import requests
|
||||
from transformers import LlavaProcessor, LlavaForConditionalGeneration
|
||||
torch.random.manual_seed(673)
|
||||
|
||||
# load pretrained model and processor
|
||||
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||
processor = LlavaProcessor.from_pretrained(model_id)
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_id, low_cpu_mem_usage=True)
|
||||
|
||||
# create random image input
|
||||
random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())
|
||||
|
||||
# prompt
|
||||
prompt = "<image>Describe this image."
|
||||
|
||||
# process inputs
|
||||
inputs = processor(text=prompt, images=random_image, return_tensors="pt")
|
||||
|
||||
# call forward method (not .generate!)
|
||||
with model_addition_debugger_context(model, "optional_path_to_your_output_file.json"):
|
||||
output = model.forward(**inputs)
|
||||
|
||||
```
|
||||
|
||||
|
||||
[[autodoc]] model_addition_debugger
|
||||
|
||||
[[autodoc]] model_addition_debugger_context
|
||||
@ -16,10 +16,18 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Custom Layers and Utilities
|
||||
|
||||
This page lists all the custom layers used by the library, as well as the utility functions it provides for modeling.
|
||||
This page lists all the custom layers used by the library, as well as the utility functions and classes it provides for modeling.
|
||||
|
||||
Most of those are only useful if you are studying the code of the models in the library.
|
||||
|
||||
## Attention Functions
|
||||
|
||||
[[autodoc]] AttentionInterface
|
||||
- register
|
||||
|
||||
## Rotary Position Embedding Functions
|
||||
|
||||
[[autodoc]] dynamic_rope_update
|
||||
|
||||
## Pytorch custom modules
|
||||
|
||||
|
||||
@ -93,7 +93,7 @@ model.generation_config.max_new_tokens = 16
|
||||
|
||||
past_key_values = StaticCache(
|
||||
config=model.config,
|
||||
batch_size=1,
|
||||
max_batch_size=1,
|
||||
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
|
||||
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
|
||||
device=model.device,
|
||||
@ -159,7 +159,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
with torch.no_grad():
|
||||
past_key_values = StaticCache(
|
||||
config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||
)
|
||||
cache_position = torch.arange(seq_length, device=torch_device)
|
||||
generated_ids = torch.zeros(
|
||||
|
||||
@ -22,9 +22,6 @@ The `.optimization` module provides:
|
||||
- several schedules in the form of schedule objects that inherit from `_LRSchedule`:
|
||||
- a gradient accumulation class to accumulate the gradients of multiple batches
|
||||
|
||||
## AdamW (PyTorch)
|
||||
|
||||
[[autodoc]] AdamW
|
||||
|
||||
## AdaFactor (PyTorch)
|
||||
|
||||
|
||||
@ -88,3 +88,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
|
||||
## FineGrainedFP8Config
|
||||
|
||||
[[autodoc]] FineGrainedFP8Config
|
||||
|
||||
## QuarkConfig
|
||||
|
||||
[[autodoc]] QuarkConfig
|
||||
|
||||
@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -14,159 +14,85 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# BERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# BERT
|
||||
|
||||
The BERT model was proposed in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a
|
||||
bidirectional transformer pretrained using a combination of masked language modeling objective and next sentence
|
||||
prediction on a large corpus comprising the Toronto Book Corpus and Wikipedia.
|
||||
[BERT](https://huggingface.co/papers/1810.04805) is a bidirectional transformer pretrained on unlabeled text to predict masked tokens in a sentence and to predict whether one sentence follows another. The main idea is that by randomly masking some tokens, the model can train on text to the left and right, giving it a more thorough understanding. BERT is also very versatile because its learned language representations can be adapted for other NLP tasks by fine-tuning an additional layer or head.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original BERT checkpoints under the [BERT](https://huggingface.co/collections/google/bert-release-64ff5e7a4be99045d1896dbc) collection.
|
||||
|
||||
*We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations
|
||||
from Transformers. Unlike recent language representation models, BERT is designed to pre-train deep bidirectional
|
||||
representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a result,
|
||||
the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models
|
||||
for a wide range of tasks, such as question answering and language inference, without substantial task-specific
|
||||
architecture modifications.*
|
||||
> [!TIP]
|
||||
> Click on the BERT models in the right sidebar for more examples of how to apply BERT to different language tasks.
|
||||
|
||||
*BERT is conceptually simple and empirically powerful. It obtains new state-of-the-art results on eleven natural
|
||||
language processing tasks, including pushing the GLUE score to 80.5% (7.7% point absolute improvement), MultiNLI
|
||||
accuracy to 86.7% (4.6% absolute improvement), SQuAD v1.1 question answering Test F1 to 93.2 (1.5 point absolute
|
||||
improvement) and SQuAD v2.0 Test F1 to 83.1 (5.1 point absolute improvement).*
|
||||
The example below demonstrates how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The original code can be found [here](https://github.com/google-research/bert).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Usage tips
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
- BERT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
|
||||
the left.
|
||||
- BERT was trained with the masked language modeling (MLM) and next sentence prediction (NSP) objectives. It is
|
||||
efficient at predicting masked tokens and at NLU in general, but is not optimal for text generation.
|
||||
- Corrupts the inputs by using random masking, more precisely, during pretraining, a given percentage of tokens (usually 15%) is masked by:
|
||||
|
||||
* a special mask token with probability 0.8
|
||||
* a random token different from the one masked with probability 0.1
|
||||
* the same token with probability 0.1
|
||||
|
||||
- The model must predict the original sentence, but has a second objective: inputs are two sentences A and B (with a separation token in between). With probability 50%, the sentences are consecutive in the corpus, in the remaining 50% they are not related. The model has to predict if the sentences are consecutive or not.
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```
|
||||
from transformers import BertModel
|
||||
|
||||
model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
...
|
||||
pipeline = pipeline(
|
||||
task="fill-mask",
|
||||
model="google-bert/bert-base-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create [MASK] through a process known as photosynthesis.")
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
On a local benchmark (A100-80GB, CPUx12, RAM 96.6GB, PyTorch 2.2.0, OS Ubuntu 22.04) with `float16`, we saw the
|
||||
following speedups during training and inference.
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
#### Training
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google-bert/bert-base-uncased",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"google-bert/bert-base-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to("cuda")
|
||||
|
||||
|batch_size|seq_len|Time per batch (eager - s)|Time per batch (sdpa - s)|Speedup (%)|Eager peak mem (MB)|sdpa peak mem (MB)|Mem saving (%)|
|
||||
|----------|-------|--------------------------|-------------------------|-----------|-------------------|------------------|--------------|
|
||||
|4 |256 |0.023 |0.017 |35.472 |939.213 |764.834 |22.800 |
|
||||
|4 |512 |0.023 |0.018 |23.687 |1970.447 |1227.162 |60.569 |
|
||||
|8 |256 |0.023 |0.018 |23.491 |1594.295 |1226.114 |30.028 |
|
||||
|8 |512 |0.035 |0.025 |43.058 |3629.401 |2134.262 |70.054 |
|
||||
|16 |256 |0.030 |0.024 |25.583 |2874.426 |2134.262 |34.680 |
|
||||
|16 |512 |0.064 |0.044 |46.223 |6964.659 |3961.013 |75.830 |
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits
|
||||
|
||||
#### Inference
|
||||
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
|
||||
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
|
||||
predicted_token = tokenizer.decode(predicted_token_id)
|
||||
|
||||
|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%)|Mem eager (MB)|Mem BT (MB)|Mem saved (%)|
|
||||
|----------|-------|----------------------------|---------------------------|-----------|--------------|-----------|-------------|
|
||||
|1 |128 |5.736 |4.987 |15.022 |282.661 |282.924 |-0.093 |
|
||||
|1 |256 |5.689 |4.945 |15.055 |298.686 |298.948 |-0.088 |
|
||||
|2 |128 |6.154 |4.982 |23.521 |314.523 |314.785 |-0.083 |
|
||||
|2 |256 |6.201 |4.949 |25.303 |347.546 |347.033 |0.148 |
|
||||
|4 |128 |6.049 |4.987 |21.305 |378.895 |379.301 |-0.107 |
|
||||
|4 |256 |6.285 |5.364 |17.166 |443.209 |444.382 |-0.264 |
|
||||
print(f"The predicted token is: {predicted_token}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "Plants create [MASK] through a process known as photosynthesis." | transformers-cli run --task fill-mask --model google-bert/bert-base-uncased --device 0
|
||||
```
|
||||
|
||||
## Resources
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
## Notes
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
|
||||
- A blog post on [BERT Text Classification in a different language](https://www.philschmid.de/bert-text-classification-in-a-different-language).
|
||||
- A notebook for [Finetuning BERT (and friends) for multi-label text classification](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb).
|
||||
- A notebook on how to [Finetune BERT for multi-label classification using PyTorch](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb). 🌎
|
||||
- A notebook on how to [warm-start an EncoderDecoder model with BERT for summarization](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb).
|
||||
- [`BertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb).
|
||||
- [`TFBertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb).
|
||||
- [`FlaxBertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification_flax.ipynb).
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
|
||||
- A blog post on how to use [Hugging Face Transformers with Keras: Fine-tune a non-English BERT for Named Entity Recognition](https://www.philschmid.de/huggingface-transformers-keras-tf).
|
||||
- A notebook for [Finetuning BERT for named-entity recognition](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/BERT/Custom_Named_Entity_Recognition_with_BERT_only_first_wordpiece.ipynb) using only the first wordpiece of each word in the word label during tokenization. To propagate the label of the word to all wordpieces, see this [version](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Custom_Named_Entity_Recognition_with_BERT.ipynb) of the notebook instead.
|
||||
- [`BertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/token-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb).
|
||||
- [`TFBertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/token-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification-tf.ipynb).
|
||||
- [`FlaxBertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/token-classification).
|
||||
- [Token classification](https://huggingface.co/course/chapter7/2?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
|
||||
<PipelineTag pipeline="fill-mask"/>
|
||||
|
||||
- [`BertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#robertabertdistilbert-and-masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
|
||||
- [`TFBertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_mlmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
|
||||
- [`FlaxBertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb).
|
||||
- [Masked language modeling](https://huggingface.co/course/chapter7/3?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
|
||||
<PipelineTag pipeline="question-answering"/>
|
||||
|
||||
- [`BertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb).
|
||||
- [`TFBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/question-answering) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering-tf.ipynb).
|
||||
- [`FlaxBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/question-answering).
|
||||
- [Question answering](https://huggingface.co/course/chapter7/7?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
|
||||
**Multiple choice**
|
||||
- [`BertForMultipleChoice`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/multiple-choice) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice.ipynb).
|
||||
- [`TFBertForMultipleChoice`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/multiple-choice) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice-tf.ipynb).
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
|
||||
⚡️ **Inference**
|
||||
- A blog post on how to [Accelerate BERT inference with Hugging Face Transformers and AWS Inferentia](https://huggingface.co/blog/bert-inferentia-sagemaker).
|
||||
- A blog post on how to [Accelerate BERT inference with DeepSpeed-Inference on GPUs](https://www.philschmid.de/bert-deepspeed-inference).
|
||||
|
||||
⚙️ **Pretraining**
|
||||
- A blog post on [Pre-Training BERT with Hugging Face Transformers and Habana Gaudi](https://www.philschmid.de/pre-training-bert-habana).
|
||||
|
||||
🚀 **Deploy**
|
||||
- A blog post on how to [Convert Transformers to ONNX with Hugging Face Optimum](https://www.philschmid.de/convert-transformers-to-onnx).
|
||||
- A blog post on how to [Setup Deep Learning environment for Hugging Face Transformers with Habana Gaudi on AWS](https://www.philschmid.de/getting-started-habana-gaudi#conclusion).
|
||||
- A blog post on [Autoscaling BERT with Hugging Face Transformers, Amazon SageMaker and Terraform module](https://www.philschmid.de/terraform-huggingface-amazon-sagemaker-advanced).
|
||||
- A blog post on [Serverless BERT with HuggingFace, AWS Lambda, and Docker](https://www.philschmid.de/serverless-bert-with-huggingface-aws-lambda-docker).
|
||||
- A blog post on [Hugging Face Transformers BERT fine-tuning using Amazon SageMaker and Training Compiler](https://www.philschmid.de/huggingface-amazon-sagemaker-training-compiler).
|
||||
- A blog post on [Task-specific knowledge distillation for BERT using Transformers & Amazon SageMaker](https://www.philschmid.de/knowledge-distillation-bert-transformers).
|
||||
- Inputs should be padded on the right because BERT uses absolute position embeddings.
|
||||
|
||||
## BertConfig
|
||||
|
||||
@ -181,35 +107,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
- create_token_type_ids_from_sequences
|
||||
- save_vocabulary
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
## BertTokenizerFast
|
||||
|
||||
[[autodoc]] BertTokenizerFast
|
||||
|
||||
</pt>
|
||||
<tf>
|
||||
|
||||
## TFBertTokenizer
|
||||
|
||||
[[autodoc]] TFBertTokenizer
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
## Bert specific outputs
|
||||
|
||||
[[autodoc]] models.bert.modeling_bert.BertForPreTrainingOutput
|
||||
|
||||
[[autodoc]] models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
|
||||
|
||||
[[autodoc]] models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
## BertModel
|
||||
|
||||
[[autodoc]] BertModel
|
||||
@ -255,8 +156,9 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] BertForQuestionAnswering
|
||||
- forward
|
||||
|
||||
</pt>
|
||||
<tf>
|
||||
## TFBertTokenizer
|
||||
|
||||
[[autodoc]] TFBertTokenizer
|
||||
|
||||
## TFBertModel
|
||||
|
||||
@ -303,9 +205,6 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] TFBertForQuestionAnswering
|
||||
- call
|
||||
|
||||
</tf>
|
||||
<jax>
|
||||
|
||||
## FlaxBertModel
|
||||
|
||||
[[autodoc]] FlaxBertModel
|
||||
@ -351,7 +250,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] FlaxBertForQuestionAnswering
|
||||
- __call__
|
||||
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
## Bert specific outputs
|
||||
|
||||
[[autodoc]] models.bert.modeling_bert.BertForPreTrainingOutput
|
||||
|
||||
[[autodoc]] models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
|
||||
|
||||
[[autodoc]] models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput
|
||||
@ -14,221 +14,77 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# CLIP
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# CLIP
|
||||
|
||||
The CLIP model was proposed in [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh,
|
||||
Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. CLIP
|
||||
(Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be
|
||||
instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing
|
||||
for the task, similarly to the zero-shot capabilities of GPT-2 and 3.
|
||||
[CLIP](https://huggingface.co/papers/2103.00020) is a is a multimodal vision and language model motivated by overcoming the fixed number of object categories when training a computer vision model. CLIP learns about images directly from raw text by jointly training on 400M (image, text) pairs. Pretraining on this scale enables zero-shot transfer to downstream tasks. CLIP uses an image encoder and text encoder to get visual features and text features. Both features are projected to a latent space with the same number of dimensions and their dot product gives a similarity score.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original CLIP checkpoints under the [OpenAI](https://huggingface.co/openai?search_models=clip) organization.
|
||||
|
||||
*State-of-the-art computer vision systems are trained to predict a fixed set of predetermined object categories. This
|
||||
restricted form of supervision limits their generality and usability since additional labeled data is needed to specify
|
||||
any other visual concept. Learning directly from raw text about images is a promising alternative which leverages a
|
||||
much broader source of supervision. We demonstrate that the simple pre-training task of predicting which caption goes
|
||||
with which image is an efficient and scalable way to learn SOTA image representations from scratch on a dataset of 400
|
||||
million (image, text) pairs collected from the internet. After pre-training, natural language is used to reference
|
||||
learned visual concepts (or describe new ones) enabling zero-shot transfer of the model to downstream tasks. We study
|
||||
the performance of this approach by benchmarking on over 30 different existing computer vision datasets, spanning tasks
|
||||
such as OCR, action recognition in videos, geo-localization, and many types of fine-grained object classification. The
|
||||
model transfers non-trivially to most tasks and is often competitive with a fully supervised baseline without the need
|
||||
for any dataset specific training. For instance, we match the accuracy of the original ResNet-50 on ImageNet zero-shot
|
||||
without needing to use any of the 1.28 million training examples it was trained on. We release our code and pre-trained
|
||||
model weights at this https URL.*
|
||||
> [!TIP]
|
||||
> Click on the CLIP models in the right sidebar for more examples of how to apply CLIP to different image and language tasks.
|
||||
|
||||
This model was contributed by [valhalla](https://huggingface.co/valhalla). The original code can be found [here](https://github.com/openai/CLIP).
|
||||
The example below demonstrates how to calculate similarity scores between multiple text descriptions and an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
## Usage tips and example
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
CLIP is a multi-modal vision and language model. It can be used for image-text similarity and for zero-shot image
|
||||
classification. CLIP uses a ViT like transformer to get visual features and a causal language model to get the text
|
||||
features. Both the text and visual features are then projected to a latent space with identical dimension. The dot
|
||||
product between the projected image and text features is then used as a similar score.
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
To feed images to the Transformer encoder, each image is split into a sequence of fixed-size non-overlapping patches,
|
||||
which are then linearly embedded. A [CLS] token is added to serve as representation of an entire image. The authors
|
||||
also add absolute position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder.
|
||||
The [`CLIPImageProcessor`] can be used to resize (or rescale) and normalize images for the model.
|
||||
|
||||
The [`CLIPTokenizer`] is used to encode the text. The [`CLIPProcessor`] wraps
|
||||
[`CLIPImageProcessor`] and [`CLIPTokenizer`] into a single instance to both
|
||||
encode the text and prepare the images. The following example shows how to get the image-text similarity scores using
|
||||
[`CLIPProcessor`] and [`CLIPModel`].
|
||||
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
clip = pipeline(
|
||||
task="zero-shot-image-classification",
|
||||
model="openai/clip-vit-base-patch32",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=0
|
||||
)
|
||||
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]
|
||||
clip("http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=labels)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
### Combining CLIP and Flash Attention 2
|
||||
```py
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, AutoModel
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2.
|
||||
model = AutoModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16, attn_implementation="sdpa")
|
||||
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]
|
||||
|
||||
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
|
||||
|
||||
outputs = model(**inputs)
|
||||
logits_per_image = outputs.logits_per_image
|
||||
probs = logits_per_image.softmax(dim=1)
|
||||
most_likely_idx = probs.argmax(dim=1).item()
|
||||
most_likely_label = labels[most_likely_idx]
|
||||
print(f"Most likely label: {most_likely_label} with probability: {probs[0][most_likely_idx].item():.3f}")
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip warning={true}>
|
||||
## Notes
|
||||
|
||||
For small batch sizes, you might notice a slowdown in your model when using flash attention. Refer to the section [Expected speedups with Flash Attention and SDPA](#Expected-speedups-with-Flash-Attention-and-SDPA) below and select an appropriate attention implementation.
|
||||
|
||||
</Tip>
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
>>> device = "cuda"
|
||||
>>> torch_dtype = torch.float16
|
||||
|
||||
>>> model = CLIPModel.from_pretrained(
|
||||
... "openai/clip-vit-base-patch32",
|
||||
... attn_implementation="flash_attention_2",
|
||||
... device_map=device,
|
||||
... torch_dtype=torch_dtype,
|
||||
... )
|
||||
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
|
||||
>>> inputs.to(device)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... with torch.autocast(device):
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
>>> print(probs)
|
||||
tensor([[0.9946, 0.0052]], device='cuda:0', dtype=torch.float16)
|
||||
```
|
||||
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```python
|
||||
from transformers import CLIPModel
|
||||
|
||||
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
### Expected speedups with Flash Attention and SDPA
|
||||
|
||||
On a local benchmark (NVIDIA A10G, PyTorch 2.3.1+cu121) with `float16`, we saw the following speedups during inference for `"openai/clip-vit-large-patch14"` checkpoint ([code](https://gist.github.com/qubvel/ac691a54e54f9fae8144275f866a7ff8)):
|
||||
|
||||
#### CLIPTextModel
|
||||
|
||||
| Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 4 | 0.009 | 0.012 | 0.737 | 0.007 | 1.269 |
|
||||
| 16 | 0.009 | 0.014 | 0.659 | 0.008 | 1.187 |
|
||||
| 32 | 0.018 | 0.021 | 0.862 | 0.016 | 1.142 |
|
||||
| 64 | 0.034 | 0.034 | 1.001 | 0.03 | 1.163 |
|
||||
| 128 | 0.063 | 0.058 | 1.09 | 0.054 | 1.174 |
|
||||
|
||||

|
||||
|
||||
#### CLIPVisionModel
|
||||
|
||||
| Image batch size | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|-------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 1 | 0.016 | 0.013 | 1.247 | 0.012 | 1.318 |
|
||||
| 4 | 0.025 | 0.021 | 1.198 | 0.021 | 1.202 |
|
||||
| 16 | 0.093 | 0.075 | 1.234 | 0.075 | 1.24 |
|
||||
| 32 | 0.181 | 0.147 | 1.237 | 0.146 | 1.241 |
|
||||
|
||||

|
||||
|
||||
#### CLIPModel
|
||||
|
||||
| Image batch size | Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|-------------------:|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 1 | 4 | 0.025 | 0.026 | 0.954 | 0.02 | 1.217 |
|
||||
| 1 | 16 | 0.026 | 0.028 | 0.918 | 0.02 | 1.287 |
|
||||
| 1 | 64 | 0.042 | 0.046 | 0.906 | 0.036 | 1.167 |
|
||||
| 4 | 4 | 0.028 | 0.033 | 0.849 | 0.024 | 1.189 |
|
||||
| 4 | 16 | 0.034 | 0.035 | 0.955 | 0.029 | 1.169 |
|
||||
| 4 | 64 | 0.059 | 0.055 | 1.072 | 0.05 | 1.179 |
|
||||
| 16 | 4 | 0.096 | 0.088 | 1.091 | 0.078 | 1.234 |
|
||||
| 16 | 16 | 0.102 | 0.09 | 1.129 | 0.083 | 1.224 |
|
||||
| 16 | 64 | 0.127 | 0.11 | 1.157 | 0.105 | 1.218 |
|
||||
| 32 | 4 | 0.185 | 0.159 | 1.157 | 0.149 | 1.238 |
|
||||
| 32 | 16 | 0.19 | 0.162 | 1.177 | 0.154 | 1.233 |
|
||||
| 32 | 64 | 0.216 | 0.181 | 1.19 | 0.176 | 1.228 |
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with CLIP.
|
||||
|
||||
- [Fine tuning CLIP with Remote Sensing (Satellite) images and captions](https://huggingface.co/blog/fine-tune-clip-rsicd), a blog post about how to fine-tune CLIP with [RSICD dataset](https://github.com/201528014227051/RSICD_optimal) and comparison of performance changes due to data augmentation.
|
||||
- This [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/contrastive-image-text) shows how to train a CLIP-like vision-text dual encoder model using a pre-trained vision and text encoder using [COCO dataset](https://cocodataset.org/#home).
|
||||
|
||||
<PipelineTag pipeline="image-to-text"/>
|
||||
|
||||
- A [notebook](https://colab.research.google.com/drive/1tuoAC5F4sC7qid56Z0ap-stR3rwdk0ZV?usp=sharing) on how to use a pretrained CLIP for inference with beam search for image captioning. 🌎
|
||||
|
||||
**Image retrieval**
|
||||
|
||||
- A [notebook](https://colab.research.google.com/drive/1bLVwVKpAndpEDHqjzxVPr_9nGrSbuOQd?usp=sharing) on image retrieval using pretrained CLIP and computing MRR(Mean Reciprocal Rank) score. 🌎
|
||||
- A [notebook](https://colab.research.google.com/github/deep-diver/image_search_with_natural_language/blob/main/notebooks/Image_Search_CLIP.ipynb) on image retrieval and showing the similarity score. 🌎
|
||||
- A [notebook](https://colab.research.google.com/drive/1xO-wC_m_GNzgjIBQ4a4znvQkvDoZJvH4?usp=sharing) on how to map images and texts to the same vector space using Multilingual CLIP. 🌎
|
||||
- A [notebook](https://colab.research.google.com/github/vivien000/clip-demo/blob/master/clip.ipynb#scrollTo=uzdFhRGqiWkR) on how to run CLIP on semantic image search using [Unsplash](https://unsplash.com) and [TMDB](https://www.themoviedb.org/) datasets. 🌎
|
||||
|
||||
**Explainability**
|
||||
|
||||
- A [notebook](https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb) on how to visualize similarity between input token and image segment. 🌎
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we will review it.
|
||||
The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
- Use [`CLIPImageProcessor`] to resize (or rescale) and normalizes images for the model.
|
||||
|
||||
## CLIPConfig
|
||||
|
||||
|
||||
@ -14,108 +14,154 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# CodeLlama
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# CodeLlama
|
||||
|
||||
The Code Llama model was proposed in [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve.
|
||||
[Code Llama](https://huggingface.co/papers/2308.12950) is a specialized family of large language models based on [Llama 2](./llama2) for coding tasks. It comes in different flavors - general code, Python-specific, and instruction-following variant - all available in 7B, 13B, 34B, and 70B parameters. Code Llama models can generate, explain, and even fill in missing parts of your code (called "infilling"). It can also handle very long contexts with stable generation up to 100k tokens, even though it was trained on sequences of 16K tokens.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original Code Llama checkpoints under the [Code Llama](https://huggingface.co/collections/meta-llama/code-llama-family-661da32d0a9d678b6f55b933) collection.
|
||||
|
||||
*We release Code Llama, a family of large language models for code based on Llama 2 providing state-of-the-art performance among open models, infilling capabilities, support for large input contexts, and zero-shot instruction following ability for programming tasks. We provide multiple flavors to cover a wide range of applications: foundation models (Code Llama), Python specializations (Code Llama - Python), and instruction-following models (Code Llama - Instruct) with 7B, 13B and 34B parameters each. All models are trained on sequences of 16k tokens and show improvements on inputs with up to 100k tokens. 7B and 13B Code Llama and Code Llama - Instruct variants support infilling based on surrounding content. Code Llama reaches state-of-the-art performance among open models on several code benchmarks, with scores of up to 53% and 55% on HumanEval and MBPP, respectively. Notably, Code Llama - Python 7B outperforms Llama 2 70B on HumanEval and MBPP, and all our models outperform every other publicly available model on MultiPL-E. We release Code Llama under a permissive license that allows for both research and commercial use.*
|
||||
> [!TIP]
|
||||
> Click on the Code Llama models in the right sidebar for more examples of how to apply Code Llama to different coding tasks.
|
||||
|
||||
Check out all Code Llama model checkpoints [here](https://huggingface.co/models?search=code_llama) and the officially released ones in the [Meta Llama org](https://huggingface.co/meta-llama).
|
||||
The example below demonstrates how to generate code with [`Pipeline`], or the [`AutoModel`], and from the command line.
|
||||
|
||||
This model was contributed by [ArthurZucker](https://huggingface.co/ArthurZ). The original code of the authors can be found [here](https://github.com/facebookresearch/llama).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
## Usage tips and examples
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model="meta-llama/CodeLlama-7b-hf",
|
||||
torch_dtype=torch.float16,
|
||||
device_map=0
|
||||
)
|
||||
|
||||
<Tip warning={true}>
|
||||
# basic code generation
|
||||
result = pipe("# Function to calculate the factorial of a number\ndef factorial(n):", max_new_tokens=256)
|
||||
print(result[0]['generated_text'])
|
||||
|
||||
The `Llama2` family models, on which Code Llama is based, were trained using `bfloat16`, but the original inference uses `float16`. Let's look at the different precisions:
|
||||
# infilling
|
||||
infill_result = pipe("def remove_non_ascii(s: str) -> str:\n \"\"\" <FILL_ME>\n return result", max_new_tokens=200)
|
||||
print(infill_result[0]['generated_text'])
|
||||
```
|
||||
|
||||
* `float32`: PyTorch convention on model initialization is to load models in `float32`, no matter with which `dtype` the model weights were stored. `transformers` also follows this convention for consistency with PyTorch. This will be picked by default. If you want the `AutoModel` API to load the checkpoints with the storage weights type, you must specify `torch_dtype="auto"`, e.g. `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`.
|
||||
* `bfloat16`: Code Llama was trained with this precision, so we recommend using it for further training or fine-tuning.
|
||||
* `float16`: We recommend running inference using this precision, as it's usually faster than `bfloat16`, and evaluation metrics show no discernible degradation with respect to `bfloat16`. You can also run inference using `bfloat16`, and we recommend you check inference results with both `float16` and `bfloat16` after fine-tuning.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
As mentioned above, the `dtype` of the storage weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model using. The reason is that the model will first be downloaded (using the `dtype` of the checkpoints online) and then will be casted to the default `dtype` of `torch` (becomes `torch.float32`). If there is a specified `torch_dtype`, it will be used instead.
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
</Tip>
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/CodeLlama-7b-hf",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
# basic code generation
|
||||
prompt = "# Function to calculate the factorial of a number\ndef factorial(n):"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
|
||||
Tips:
|
||||
- The infilling task is supported out of the box. You should be using the `tokenizer.fill_token` where you want your input to be filled.
|
||||
- The model conversion script is the same as for the `Llama2` family:
|
||||
output = model.generate(
|
||||
**input_ids,
|
||||
max_new_tokens=256,
|
||||
cache_implementation="static"
|
||||
)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
Here is a sample usage:
|
||||
# infilling
|
||||
infill_prompt = "def remove_non_ascii(s: str) -> str:\n \"\"\" <FILL_ME>\n return result"
|
||||
input_ids = tokenizer(infill_prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
filled_output = model.generate(**input_ids, max_new_tokens=200)
|
||||
filled_text = tokenizer.decode(filled_output[0], skip_special_tokens=True)
|
||||
print(filled_text)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
echo -e "# Function to calculate the factorial of a number\ndef factorial(n):" | transformers-cli run --task text-generation --model meta-llama/CodeLlama-7b-hf --device 0
|
||||
```
|
||||
|
||||
Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
After conversion, the model and tokenizer can be loaded via:
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
```python
|
||||
>>> from transformers import LlamaForCausalLM, CodeLlamaTokenizer
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to 4-bits.
|
||||
|
||||
>>> tokenizer = CodeLlamaTokenizer.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
>>> PROMPT = '''def remove_non_ascii(s: str) -> str:
|
||||
... """ <FILL_ME>
|
||||
... return result
|
||||
... '''
|
||||
>>> input_ids = tokenizer(PROMPT, return_tensors="pt")["input_ids"]
|
||||
>>> generated_ids = model.generate(input_ids, max_new_tokens=128)
|
||||
```py
|
||||
# pip install bitsandbytes
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, CodeLlamaTokenizer, BitsAndBytesConfig
|
||||
|
||||
>>> filling = tokenizer.batch_decode(generated_ids[:, input_ids.shape[1]:], skip_special_tokens = True)[0]
|
||||
>>> print(PROMPT.replace("<FILL_ME>", filling))
|
||||
def remove_non_ascii(s: str) -> str:
|
||||
""" Remove non-ASCII characters from a string.
|
||||
<BLANKLINE>
|
||||
Args:
|
||||
s: The string to remove non-ASCII characters from.
|
||||
<BLANKLINE>
|
||||
Returns:
|
||||
The string with non-ASCII characters removed.
|
||||
"""
|
||||
result = ""
|
||||
for c in s:
|
||||
if ord(c) < 128:
|
||||
result += c
|
||||
return result
|
||||
<BLANKLINE>
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True)
|
||||
tokenizer = CodeLlamaTokenizer.from_pretrained("meta-llama/CodeLlama-34b-hf")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/CodeLlama-34b-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
prompt = "# Write a Python function to check if a string is a palindrome\ndef is_palindrome(s):"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, max_new_tokens=200, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
If you only want the infilled part:
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
>>> import torch
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
>>> generator = pipeline("text-generation",model="meta-llama/CodeLlama-7b-hf",torch_dtype=torch.float16, device_map="auto")
|
||||
>>> generator('def remove_non_ascii(s: str) -> str:\n """ <FILL_ME>\n return result', max_new_tokens = 128)
|
||||
[{'generated_text': 'def remove_non_ascii(s: str) -> str:\n """ <FILL_ME>\n return resultRemove non-ASCII characters from a string. """\n result = ""\n for c in s:\n if ord(c) < 128:\n result += c'}]
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
visualizer = AttentionMaskVisualizer("meta-llama/CodeLlama-7b-hf")
|
||||
visualizer("""def func(a, b):
|
||||
return a + b""")
|
||||
```
|
||||
|
||||
Under the hood, the tokenizer [automatically splits by `<FILL_ME>`](https://huggingface.co/docs/transformers/main/model_doc/code_llama#transformers.CodeLlamaTokenizer.fill_token) to create a formatted input string that follows [the original training pattern](https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L402). This is more robust than preparing the pattern yourself: it avoids pitfalls, such as token glueing, that are very hard to debug. To see how much CPU and GPU memory you need for this model or others, try [this calculator](https://huggingface.co/spaces/hf-accelerate/model-memory-usage) which can help determine that value.
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/codellama-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
|
||||
|
||||
<Tip>
|
||||
|
||||
Code Llama has the same architecture as the `Llama2` models, refer to [Llama2's documentation page](llama2) for the API reference.
|
||||
Find Code Llama tokenizer reference below.
|
||||
</Tip>
|
||||
## Notes
|
||||
|
||||
- Infilling is only available in the 7B and 13B base models, and not in the Python, Instruct, 34B, or 70B models.
|
||||
- Use the `<FILL_ME>` token where you want your input to be filled. The tokenizer splits this token to create a formatted input string that follows the [original training pattern](https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L402). This is more robust than preparing the pattern yourself.
|
||||
```py
|
||||
from transformers import LlamaForCausalLM, CodeLlamaTokenizer
|
||||
|
||||
tokenizer = CodeLlamaTokenizer.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
PROMPT = '''def remove_non_ascii(s: str) -> str:
|
||||
""" <FILL_ME>
|
||||
return result
|
||||
'''
|
||||
input_ids = tokenizer(PROMPT, return_tensors="pt")["input_ids"]
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=128)
|
||||
|
||||
filling = tokenizer.batch_decode(generated_ids[:, input_ids.shape[1]:], skip_special_tokens = True)[0]
|
||||
print(PROMPT.replace("<FILL_ME>", filling))
|
||||
```
|
||||
- Use `bfloat16` for further training or fine-tuning and `float16` for inference.
|
||||
- The `BOS` character is not used for infilling when encoding the prefix or suffix, but only at the beginning of each prompt.
|
||||
- The tokenizer is a byte-pair encoding model based on [SentencePiece](https://github.com/google/sentencepiece). During decoding, if the first token is the start of the word (for example, “Banana”), the tokenizer doesn’t prepend the prefix space to the string.
|
||||
|
||||
## CodeLlamaTokenizer
|
||||
|
||||
|
||||
@ -1,124 +1,115 @@
|
||||
# Cohere
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
The Cohere Command-R model was proposed in the blogpost [Command-R: Retrieval Augmented Generation at Production Scale](https://txt.cohere.com/command-r/) by the Cohere Team.
|
||||
# Cohere
|
||||
|
||||
The abstract from the paper is the following:
|
||||
Cohere Command-R is a 35B parameter multilingual large language model designed for long context tasks like retrieval-augmented generation (RAG) and calling external APIs and tools. The model is specifically trained for grounded generation and supports both single-step and multi-step tool use. It supports a context length of 128K tokens.
|
||||
|
||||
*Command-R is a scalable generative model targeting RAG and Tool Use to enable production-scale AI for enterprise. Today, we are introducing Command-R, a new LLM aimed at large-scale production workloads. Command-R targets the emerging “scalable” category of models that balance high efficiency with strong accuracy, enabling companies to move beyond proof of concept, and into production.*
|
||||
You can find all the original Command-R checkpoints under the [Command Models](https://huggingface.co/collections/CohereForAI/command-models-67652b401665205e17b192ad) collection.
|
||||
|
||||
*Command-R is a generative model optimized for long context tasks such as retrieval augmented generation (RAG) and using external APIs and tools. It is designed to work in concert with our industry-leading Embed and Rerank models to provide best-in-class integration for RAG applications and excel at enterprise use cases. As a model built for companies to implement at scale, Command-R boasts:
|
||||
- Strong accuracy on RAG and Tool Use
|
||||
- Low latency, and high throughput
|
||||
- Longer 128k context and lower pricing
|
||||
- Strong capabilities across 10 key languages
|
||||
- Model weights available on HuggingFace for research and evaluation
|
||||
|
||||
Checkout model checkpoints [here](https://huggingface.co/CohereForAI/c4ai-command-r-v01).
|
||||
This model was contributed by [Saurabh Dash](https://huggingface.co/saurabhdash) and [Ahmet Üstün](https://huggingface.co/ahmetustun). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox).
|
||||
> [!TIP]
|
||||
> Click on the Cohere models in the right sidebar for more examples of how to apply Cohere to different language tasks.
|
||||
|
||||
## Usage tips
|
||||
The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`], and from the command line.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The checkpoints uploaded on the Hub use `torch_dtype = 'float16'`, which will be
|
||||
used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.
|
||||
|
||||
The `dtype` of the online weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online), then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`), and finally, if there is a `torch_dtype` provided in the config, it will be used.
|
||||
|
||||
Training the model in `float16` is not recommended and is known to produce `nan`; as such, the model should be trained in `bfloat16`.
|
||||
|
||||
</Tip>
|
||||
The model and tokenizer can be loaded via:
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```python
|
||||
# pip install transformers
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="CohereForAI/c4ai-command-r-v01",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_id = "CohereForAI/c4ai-command-r-v01"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
model = AutoModelForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa")
|
||||
|
||||
# Format message with the command-r chat template
|
||||
messages = [{"role": "user", "content": "Hello, how are you?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
||||
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
|
||||
gen_tokens = model.generate(
|
||||
# format message with the Command-R chat template
|
||||
messages = [{"role": "user", "content": "How do plants make energy?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
|
||||
output = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
gen_text = tokenizer.decode(gen_tokens[0])
|
||||
print(gen_text)
|
||||
cache_implementation="static",
|
||||
)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, it is simply specifying either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because the Flash Attention only support `fp16` and `bf16` data type.
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Command-R. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
|
||||
Loading FP16 model
|
||||
```python
|
||||
# pip install transformers
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_id = "CohereForAI/c4ai-command-r-v01"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
# Format message with the command-r chat template
|
||||
messages = [{"role": "user", "content": "Hello, how are you?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
||||
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
|
||||
gen_tokens = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
gen_text = tokenizer.decode(gen_tokens[0])
|
||||
print(gen_text)
|
||||
```bash
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
transformers-cli chat --model_name_or_path CohereForAI/c4ai-command-r-v01 --torch_dtype auto --attn_implementation flash_attention_2
|
||||
```
|
||||
|
||||
Loading bitsnbytes 4bit quantized model
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
# pip install transformers bitsandbytes accelerate
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
model = AutoModelForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", torch_dtype=torch.float16, device_map="auto", quantization_config=bnb_config, attn_implementation="sdpa")
|
||||
|
||||
model_id = "CohereForAI/c4ai-command-r-v01"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
|
||||
|
||||
gen_tokens = model.generate(
|
||||
# format message with the Command-R chat template
|
||||
messages = [{"role": "user", "content": "How do plants make energy?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
|
||||
output = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
gen_text = tokenizer.decode(gen_tokens[0])
|
||||
print(gen_text)
|
||||
cache_implementation="static",
|
||||
)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
visualizer = AttentionMaskVisualizer("CohereForAI/c4ai-command-r-v01")
|
||||
visualizer("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/cohere-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
|
||||
## Notes
|
||||
- Don’t use the torch_dtype parameter in [`~AutoModel.from_pretrained`] if you’re using FlashAttention-2 because it only supports fp16 or bf16. You should use [Automatic Mixed Precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html), set fp16 or bf16 to True if using [`Trainer`], or use [torch.autocast](https://pytorch.org/docs/stable/amp.html#torch.autocast).
|
||||
|
||||
## CohereConfig
|
||||
|
||||
@ -143,5 +134,3 @@ print(gen_text)
|
||||
|
||||
[[autodoc]] CohereForCausalLM
|
||||
- forward
|
||||
|
||||
|
||||
|
||||
184
docs/source/en/model_doc/deepseek_v3.md
Normal file
184
docs/source/en/model_doc/deepseek_v3.md
Normal file
@ -0,0 +1,184 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# DeepSeek-V3
|
||||
|
||||
## Overview
|
||||
|
||||
The DeepSeek-V3 model was proposed in [DeepSeek-V3 Technical Report](https://arxiv.org/abs/2412.19437) by DeepSeek-AI Team.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 671B total parameters with 37B activated for each token. To achieve efficient inference and cost-effective training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA) and DeepSeekMoE architectures, which were thoroughly validated in DeepSeek-V2. Furthermore, DeepSeek-V3 pioneers an auxiliary-loss-free strategy for load balancing and sets a multi-token prediction training objective for stronger performance. We pre-train DeepSeek-V3 on 14.8 trillion diverse and high-quality tokens, followed by Supervised Fine-Tuning and Reinforcement Learning stages to fully harness its capabilities. Comprehensive evaluations reveal that DeepSeek-V3 outperforms other open-source models and achieves performance comparable to leading closed-source models. Despite its excellent performance, DeepSeek-V3 requires only 2.788M H800 GPU hours for its full training. In addition, its training process is remarkably stable. Throughout the entire training process, we did not experience any irrecoverable loss spikes or perform any rollbacks. The model checkpoints are available at https://github.com/deepseek-ai/DeepSeek-V3.
|
||||
|
||||
## Limitations and call for contribution!
|
||||
|
||||
We are super happy to make this code community-powered, and would love to see how you can best optimize the following:
|
||||
|
||||
- current implementation uses the "naive" attention compution (so not really MLA)
|
||||
- current implementation loops through the experts. This should be replaced. Pointers to use `get_packed_weights` from `intetrations/tensor_parallel`.
|
||||
- current implementation uses the eleuther formula for ROPE, using the orginal one would be more efficient! (should still follow our API)
|
||||
- static cache is not supported (this should be just a generation config issue / config shape issues)
|
||||
|
||||
### Usage tips
|
||||
The model uses Multi-head Latent Attention (MLA) and DeepSeekMoE architectures for efficient inference and cost-effective training. It employs an auxiliary-loss-free strategy for load balancing and multi-token prediction training objective. The model can be used for various language tasks after being pre-trained on 14.8 trillion tokens and going through Supervised Fine-Tuning and Reinforcement Learning stages.
|
||||
|
||||
You can run the model in `FP8` automatically, using 2 nodes of 8 H100 should be more than enough!
|
||||
|
||||
```python
|
||||
# `run_deepseek_v1.py`
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
torch.manual_seed(30)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("deepseek-r1")
|
||||
|
||||
chat = [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||
]
|
||||
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("deepseek-r1", device_map="auto", torch_dtype=torch.bfloat16)
|
||||
inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
||||
import time
|
||||
start = time.time()
|
||||
outputs = model.generate(inputs, max_new_tokens=50)
|
||||
print(tokenizer.batch_decode(outputs))
|
||||
print(time.time()-start)
|
||||
```
|
||||
This generated:
|
||||
|
||||
``````
|
||||
<|Assistant|><think>
|
||||
Okay, the user wants to demonstrate how chat templating works. Let me break down what that means. Chat templating is about structuring the conversation data, especially for models that need specific input formats. Maybe they're referring to something like how messages are formatted with roles (user, assistant, system) in APIs like OpenAI.
|
||||
|
||||
First, I should explain what chat templating is. It's the process of formatting conversation data into a structured format that the model can understand. This usually includes roles and content. For example, user messages, assistant responses, and system messages each have their own role tags.
|
||||
|
||||
They might want an example. Let me think of a simple conversation. The user says "Hello, how are you?" and the assistant responds "I'm doing great. How can I help you today?" Then the user follows up with wanting to show off chat templating. So the example should include the history and the new message.
|
||||
|
||||
In some frameworks, like Hugging Face's Transformers, chat templates are applied using Jinja2 templates. The template might look something like combining system messages, then looping through user and assistant messages with appropriate tags. For instance, using {% for message in messages %} and assigning roles like <|user|>, <|assistant|>, etc.
|
||||
|
||||
I should structure the example with the messages array, showing each role and content. Then apply a hypothetical template to convert that into a formatted string the model uses. Also, mention that different models have different templating requirements, like using special tokens or varying role labels.
|
||||
|
||||
Wait, the user mentioned "chat templating" in the context of showing off. Maybe they want a practical example they can present. So providing a code snippet or a structured data example would be helpful. Let me outline a typical messages array and then the templated output.
|
||||
|
||||
Also, it's important to note that proper templating ensures the model knows the conversation flow, which is crucial for generating coherent responses. Maybe include a note about why it's important, like maintaining context and role-specific processing.
|
||||
|
||||
Let me check if there are any common mistakes or things to avoid. For example, not closing tags properly, or mismatching roles. But maybe that's too detailed unless the user asks. Focus on the positive example first.
|
||||
|
||||
Putting it all together, the response should have an example messages array, the applied template, and the final formatted string. Maybe use angle brackets or special tokens as placeholders. Also, mention that this helps in training or fine-tuning models with structured data.
|
||||
|
||||
I think that's a solid approach. Let me structure it step by step to make it clear.
|
||||
</think>
|
||||
|
||||
Chat templating is a way to structure conversation data (e.g., user/assistant interactions) into a format that language models understand. This is especially important for models trained to handle multi-turn dialogues, where the input must explicitly separate roles (user, assistant, system, etc.) and messages. Let’s break this down with an example!
|
||||
|
||||
---
|
||||
|
||||
### **Step 1: Raw Conversation History**
|
||||
Suppose we have this conversation:
|
||||
- **User**: "Hello, how are you?"
|
||||
- **Assistant**: "I'm doing great. How can I help you today?"
|
||||
- **User**: "I'd like to show off how chat templating works!"
|
||||
|
||||
---
|
||||
|
||||
### **Step 2: Structured Messages**
|
||||
In frameworks like Hugging Face Transformers or OpenAI, conversations are often formatted as a list of dictionaries with `role` and `content`:
|
||||
```python
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **Step 3: Apply a Chat Template**
|
||||
A **chat template** converts this structured data into a single string formatted for the model. For example, using a Jinja-style template (common in Hugging Face):
|
||||
|
||||
```jinja
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
<|user|>{{ message['content'] }}<|end|>
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
<|assistant|>{{ message['content'] }}<|end|>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
<|assistant|>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **Step 4: Final Templated Output**
|
||||
Applying the template to our `messages` list would produce:
|
||||
```text
|
||||
<|user|>Hello, how are you?<|end|>
|
||||
<|assistant|>I'm doing great. How can I help you today?<|end|>
|
||||
<|user|>I'd like to show off how chat templating works!<|end|>
|
||||
<|assistant|>
|
||||
```
|
||||
|
||||
This tells the model:
|
||||
1. The conversation history (user/assistant turns).
|
||||
2. The model’s turn to generate a response (`<|assistant|>` at the end).
|
||||
|
||||
---
|
||||
|
||||
### **Key Notes**:
|
||||
- **Role Separation**: Tags like `<|user|>` and `<|assistant|>` help the model distinguish speakers.
|
||||
- **Special Tokens**: Models often use unique tokens (e.g., `<|end|>`) to mark message boundaries.
|
||||
- **Flexibility**: Templates vary by model (e.g., OpenAI uses `{"role": "user", "content": "..."}` instead of tags).
|
||||
|
||||
---
|
||||
|
||||
### **Why This Matters**:
|
||||
- **Consistency**: Ensures the model understands dialogue structure.
|
||||
- **Context Preservation**: Maintains the flow of multi-turn conversations.
|
||||
- **Alignment**: Matches the format the model was trained on for better performance.
|
||||
|
||||
Want to dive deeper or see a specific framework’s implementation (e.g., OpenAI, Llama, Mistral)? Let me know! 😊<|end▁of▁sentence|>
|
||||
``````
|
||||
|
||||
Use the following to run it
|
||||
```bash
|
||||
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0|1 --rdzv-id an_id --rdzv-backend c10d --rdzv-endpoint master_addr:master_port run_deepseek_r1.py
|
||||
```
|
||||
|
||||
If you have:
|
||||
```bash
|
||||
[rank0]: ncclInternalError: Internal check failed.
|
||||
[rank0]: Last error:
|
||||
[rank0]: Bootstrap : no socket interface found
|
||||
```
|
||||
error, it means NCCL was probably not loaded.
|
||||
|
||||
|
||||
## DeepseekV3Config
|
||||
|
||||
[[autodoc]] DeepseekV3Config
|
||||
|
||||
## DeepseekV3Model
|
||||
|
||||
[[autodoc]] DeepseekV3Model
|
||||
- forward
|
||||
|
||||
## DeepseekV3ForCausalLM
|
||||
|
||||
[[autodoc]] DeepseekV3ForCausalLM
|
||||
- forward
|
||||
@ -19,6 +19,7 @@ rendered properly in your Markdown viewer.
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -14,101 +14,69 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Depth Anything
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Depth Anything
|
||||
|
||||
The Depth Anything model was proposed in [Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data](https://arxiv.org/abs/2401.10891) by Lihe Yang, Bingyi Kang, Zilong Huang, Xiaogang Xu, Jiashi Feng, Hengshuang Zhao. Depth Anything is based on the [DPT](dpt) architecture, trained on ~62 million images, obtaining state-of-the-art results for both relative and absolute depth estimation.
|
||||
[Depth Anything](https://huggingface.co/papers/2401.10891) is designed to be a foundation model for monocular depth estimation (MDE). It is jointly trained on labeled and ~62M unlabeled images to enhance the dataset. It uses a pretrained [DINOv2](./dinov2) model as an image encoder to inherit its existing rich semantic priors, and [DPT](./dpt) as the decoder. A teacher model is trained on unlabeled images to create pseudo-labels. The student model is trained on a combination of the pseudo-labels and labeled images. To improve the student model's performance, strong perturbations are added to the unlabeled images to challenge the student model to learn more visual knowledge from the image.
|
||||
|
||||
<Tip>
|
||||
You can find all the original Depth Anything checkpoints under the [Depth Anything](https://huggingface.co/collections/LiheYoung/depth-anything-release-65b317de04eec72abf6b55aa) collection.
|
||||
|
||||
[Depth Anything V2](depth_anything_v2) was released in June 2024. It uses the same architecture as Depth Anything and therefore it is compatible with all code examples and existing workflows. However, it leverages synthetic data and a larger capacity teacher model to achieve much finer and robust depth predictions.
|
||||
> [!TIP]
|
||||
> Click on the Depth Anything models in the right sidebar for more examples of how to apply Depth Anything to different vision tasks.
|
||||
|
||||
</Tip>
|
||||
The example below demonstrates how to obtain a depth map with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
*This work presents Depth Anything, a highly practical solution for robust monocular depth estimation. Without pursuing novel technical modules, we aim to build a simple yet powerful foundation model dealing with any images under any circumstances. To this end, we scale up the dataset by designing a data engine to collect and automatically annotate large-scale unlabeled data (~62M), which significantly enlarges the data coverage and thus is able to reduce the generalization error. We investigate two simple yet effective strategies that make data scaling-up promising. First, a more challenging optimization target is created by leveraging data augmentation tools. It compels the model to actively seek extra visual knowledge and acquire robust representations. Second, an auxiliary supervision is developed to enforce the model to inherit rich semantic priors from pre-trained encoders. We evaluate its zero-shot capabilities extensively, including six public datasets and randomly captured photos. It demonstrates impressive generalization ability. Further, through fine-tuning it with metric depth information from NYUv2 and KITTI, new SOTAs are set. Our better depth model also results in a better depth-conditioned ControlNet.*
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/depth_anything_overview.jpg"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Depth Anything overview. Taken from the <a href="https://arxiv.org/abs/2401.10891">original paper</a>.</small>
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr).
|
||||
The original code can be found [here](https://github.com/LiheYoung/Depth-Anything).
|
||||
|
||||
## Usage example
|
||||
|
||||
There are 2 main ways to use Depth Anything: either using the pipeline API, which abstracts away all the complexity for you, or by using the `DepthAnythingForDepthEstimation` class yourself.
|
||||
|
||||
### Pipeline API
|
||||
|
||||
The pipeline allows to use the model in a few lines of code:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> # load pipe
|
||||
>>> pipe = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")
|
||||
|
||||
>>> # load image
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> # inference
|
||||
>>> depth = pipe(image)["depth"]
|
||||
pipe = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-base-hf", torch_dtype=torch.bfloat16, device=0)
|
||||
pipe("http://images.cocodataset.org/val2017/000000039769.jpg")["depth"]
|
||||
```
|
||||
|
||||
### Using the model yourself
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
If you want to do the pre- and postprocessing yourself, here's how to do that:
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-base-hf")
|
||||
model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-base-hf", torch_dtype=torch.bfloat16)
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
inputs = image_processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
|
||||
>>> model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf")
|
||||
|
||||
>>> # prepare image for the model
|
||||
>>> inputs = image_processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> # interpolate to original size and visualize the prediction
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... target_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
|
||||
>>> depth = depth.detach().cpu().numpy() * 255
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
post_processed_output = image_processor.post_process_depth_estimation(
|
||||
outputs,
|
||||
target_sizes=[(image.height, image.width)],
|
||||
)
|
||||
predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
|
||||
depth = depth.detach().cpu().numpy() * 255
|
||||
Image.fromarray(depth.astype("uint8"))
|
||||
```
|
||||
|
||||
## Resources
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Depth Anything.
|
||||
## Notes
|
||||
|
||||
- [Monocular depth estimation task guide](../tasks/monocular_depth_estimation)
|
||||
- A notebook showcasing inference with [`DepthAnythingForDepthEstimation`] can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Depth%20Anything/Predicting_depth_in_an_image_with_Depth_Anything.ipynb). 🌎
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
- [DepthAnythingV2](./depth_anything_v2), released in June 2024, uses the same architecture as Depth Anything and is compatible with all code examples and existing workflows. It uses synthetic data and a larger capacity teacher model to achieve much finer and robust depth predictions.
|
||||
|
||||
## DepthAnythingConfig
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ specific language governing permissions and limitations under the License.
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -14,199 +14,91 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# DistilBERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# DistilBERT
|
||||
|
||||
The DistilBERT model was proposed in the blog post [Smaller, faster, cheaper, lighter: Introducing DistilBERT, a
|
||||
distilled version of BERT](https://medium.com/huggingface/distilbert-8cf3380435b5), and the paper [DistilBERT, a
|
||||
distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108). DistilBERT is a
|
||||
small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% less parameters than
|
||||
*google-bert/bert-base-uncased*, runs 60% faster while preserving over 95% of BERT's performances as measured on the GLUE language
|
||||
understanding benchmark.
|
||||
[DistilBERT](https://huggingface.co/papers/1910.01108) is pretrained by knowledge distillation to create a smaller model with faster inference and requires less compute to train. Through a triple loss objective during pretraining, language modeling loss, distillation loss, cosine-distance loss, DistilBERT demonstrates similar performance to a larger transformer language model.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original DistilBERT checkpoints under the [DistilBERT](https://huggingface.co/distilbert) organization.
|
||||
|
||||
*As Transfer Learning from large-scale pre-trained models becomes more prevalent in Natural Language Processing (NLP),
|
||||
operating these large models in on-the-edge and/or under constrained computational training or inference budgets
|
||||
remains challenging. In this work, we propose a method to pre-train a smaller general-purpose language representation
|
||||
model, called DistilBERT, which can then be fine-tuned with good performances on a wide range of tasks like its larger
|
||||
counterparts. While most prior work investigated the use of distillation for building task-specific models, we leverage
|
||||
knowledge distillation during the pretraining phase and show that it is possible to reduce the size of a BERT model by
|
||||
40%, while retaining 97% of its language understanding capabilities and being 60% faster. To leverage the inductive
|
||||
biases learned by larger models during pretraining, we introduce a triple loss combining language modeling,
|
||||
distillation and cosine-distance losses. Our smaller, faster and lighter model is cheaper to pre-train and we
|
||||
demonstrate its capabilities for on-device computations in a proof-of-concept experiment and a comparative on-device
|
||||
study.*
|
||||
> [!TIP]
|
||||
> Click on the DistilBERT models in the right sidebar for more examples of how to apply DistilBERT to different language tasks.
|
||||
|
||||
This model was contributed by [victorsanh](https://huggingface.co/victorsanh). This model jax version was
|
||||
contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code can be found [here](https://github.com/huggingface/transformers-research-projects/tree/main/distillation).
|
||||
The example below demonstrates how to classify text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
## Usage tips
|
||||
<hfoptions id="usage">
|
||||
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
classifier = pipeline(
|
||||
task="text-classification",
|
||||
model="distilbert-base-uncased-finetuned-sst-2-english",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
|
||||
result = classifier("I love using Hugging Face Transformers!")
|
||||
print(result)
|
||||
# Output: [{'label': 'POSITIVE', 'score': 0.9998}]
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"distilbert/distilbert-base-uncased-finetuned-sst-2-english",
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"distilbert/distilbert-base-uncased-finetuned-sst-2-english",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
inputs = tokenizer("I love using Hugging Face Transformers!", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
predicted_class_id = torch.argmax(outputs.logits, dim=-1).item()
|
||||
predicted_label = model.config.id2label[predicted_class_id]
|
||||
print(f"Predicted label: {predicted_label}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "I love using Hugging Face Transformers!" | transformers-cli run --task text-classification --model distilbert-base-uncased-finetuned-sst-2-english
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
</hfoptions>
|
||||
|
||||
## Notes
|
||||
|
||||
- DistilBERT doesn't have `token_type_ids`, you don't need to indicate which token belongs to which segment. Just
|
||||
separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`).
|
||||
- DistilBERT doesn't have options to select the input positions (`position_ids` input). This could be added if
|
||||
necessary though, just let us know if you need this option.
|
||||
- Same as BERT but smaller. Trained by distillation of the pretrained BERT model, meaning it’s been trained to predict the same probabilities as the larger model. The actual objective is a combination of:
|
||||
|
||||
* finding the same probabilities as the teacher model
|
||||
* predicting the masked tokens correctly (but no next-sentence objective)
|
||||
* a cosine similarity between the hidden states of the student and the teacher model
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```
|
||||
from transformers import DistilBertModel
|
||||
model = DistilBertModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16` and the `distilbert-base-uncased` model with
|
||||
a MaskedLM head, we saw the following speedups during training and inference.
|
||||
|
||||
#### Training
|
||||
|
||||
| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|
||||
|--------------------|------------|---------|---------|----------------------------|---------------------------|-------------|---------------------|--------------------|----------------|
|
||||
| 100 | 1 | 128 | False | 0.010 | 0.008 | 28.870 | 397.038 | 399.629 | -0.649 |
|
||||
| 100 | 1 | 256 | False | 0.011 | 0.009 | 20.681 | 412.505 | 412.606 | -0.025 |
|
||||
| 100 | 2 | 128 | False | 0.011 | 0.009 | 23.741 | 412.213 | 412.606 | -0.095 |
|
||||
| 100 | 2 | 256 | False | 0.015 | 0.013 | 16.502 | 427.491 | 425.787 | 0.400 |
|
||||
| 100 | 4 | 128 | False | 0.015 | 0.013 | 13.828 | 427.491 | 425.787 | 0.400 |
|
||||
| 100 | 4 | 256 | False | 0.025 | 0.022 | 12.882 | 594.156 | 502.745 | 18.182 |
|
||||
| 100 | 8 | 128 | False | 0.023 | 0.022 | 8.010 | 545.922 | 502.745 | 8.588 |
|
||||
| 100 | 8 | 256 | False | 0.046 | 0.041 | 12.763 | 983.450 | 798.480 | 23.165 |
|
||||
|
||||
#### Inference
|
||||
|
||||
| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|
||||
|-------------|------------|---------|---------|---------|----------|-----------------------------|-----------------------------|-------------|----------------|--------------|---------------|
|
||||
| 50 | 2 | 64 | True | True | True | 0.032 | 0.025 | 28.192 | 154.532 | 155.531 | -0.642 |
|
||||
| 50 | 2 | 128 | True | True | True | 0.033 | 0.025 | 32.636 | 157.286 | 157.482 | -0.125 |
|
||||
| 50 | 4 | 64 | True | True | True | 0.032 | 0.026 | 24.783 | 157.023 | 157.449 | -0.271 |
|
||||
| 50 | 4 | 128 | True | True | True | 0.034 | 0.028 | 19.299 | 162.794 | 162.269 | 0.323 |
|
||||
| 50 | 8 | 64 | True | True | True | 0.035 | 0.028 | 25.105 | 160.958 | 162.204 | -0.768 |
|
||||
| 50 | 8 | 128 | True | True | True | 0.052 | 0.046 | 12.375 | 173.155 | 171.844 | 0.763 |
|
||||
| 50 | 16 | 64 | True | True | True | 0.051 | 0.045 | 12.882 | 172.106 | 171.713 | 0.229 |
|
||||
| 50 | 16 | 128 | True | True | True | 0.096 | 0.081 | 18.524 | 191.257 | 191.517 | -0.136 |
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DistilBERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
|
||||
- A blog post on [Getting Started with Sentiment Analysis using Python](https://huggingface.co/blog/sentiment-analysis-python) with DistilBERT.
|
||||
- A blog post on how to [train DistilBERT with Blurr for sequence classification](https://huggingface.co/blog/fastai).
|
||||
- A blog post on how to use [Ray to tune DistilBERT hyperparameters](https://huggingface.co/blog/ray-tune).
|
||||
- A blog post on how to [train DistilBERT with Hugging Face and Amazon SageMaker](https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face).
|
||||
- A notebook on how to [finetune DistilBERT for multi-label classification](https://colab.research.google.com/github/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb). 🌎
|
||||
- A notebook on how to [finetune DistilBERT for multiclass classification with PyTorch](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb). 🌎
|
||||
- A notebook on how to [finetune DistilBERT for text classification in TensorFlow](https://colab.research.google.com/github/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb). 🌎
|
||||
- [`DistilBertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb).
|
||||
- [`TFDistilBertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb).
|
||||
- [`FlaxDistilBertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification_flax.ipynb).
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
|
||||
- [`DistilBertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/token-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb).
|
||||
- [`TFDistilBertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/token-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification-tf.ipynb).
|
||||
- [`FlaxDistilBertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/token-classification).
|
||||
- [Token classification](https://huggingface.co/course/chapter7/2?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
|
||||
|
||||
<PipelineTag pipeline="fill-mask"/>
|
||||
|
||||
- [`DistilBertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#robertabertdistilbert-and-masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
|
||||
- [`TFDistilBertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_mlmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
|
||||
- [`FlaxDistilBertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb).
|
||||
- [Masked language modeling](https://huggingface.co/course/chapter7/3?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
|
||||
<PipelineTag pipeline="question-answering"/>
|
||||
|
||||
- [`DistilBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb).
|
||||
- [`TFDistilBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/question-answering) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering-tf.ipynb).
|
||||
- [`FlaxDistilBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/question-answering).
|
||||
- [Question answering](https://huggingface.co/course/chapter7/7?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
|
||||
**Multiple choice**
|
||||
- [`DistilBertForMultipleChoice`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/multiple-choice) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice.ipynb).
|
||||
- [`TFDistilBertForMultipleChoice`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/multiple-choice) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice-tf.ipynb).
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
|
||||
⚗️ Optimization
|
||||
|
||||
- A blog post on how to [quantize DistilBERT with 🤗 Optimum and Intel](https://huggingface.co/blog/intel).
|
||||
- A blog post on how [Optimizing Transformers for GPUs with 🤗 Optimum](https://www.philschmid.de/optimizing-transformers-with-optimum-gpu).
|
||||
- A blog post on [Optimizing Transformers with Hugging Face Optimum](https://www.philschmid.de/optimizing-transformers-with-optimum).
|
||||
|
||||
⚡️ Inference
|
||||
|
||||
- A blog post on how to [Accelerate BERT inference with Hugging Face Transformers and AWS Inferentia](https://huggingface.co/blog/bert-inferentia-sagemaker) with DistilBERT.
|
||||
- A blog post on [Serverless Inference with Hugging Face's Transformers, DistilBERT and Amazon SageMaker](https://www.philschmid.de/sagemaker-serverless-huggingface-distilbert).
|
||||
|
||||
🚀 Deploy
|
||||
|
||||
- A blog post on how to [deploy DistilBERT on Google Cloud](https://huggingface.co/blog/how-to-deploy-a-pipeline-to-google-clouds).
|
||||
- A blog post on how to [deploy DistilBERT with Amazon SageMaker](https://huggingface.co/blog/deploy-hugging-face-models-easily-with-amazon-sagemaker).
|
||||
- A blog post on how to [Deploy BERT with Hugging Face Transformers, Amazon SageMaker and Terraform module](https://www.philschmid.de/terraform-huggingface-amazon-sagemaker).
|
||||
|
||||
|
||||
## Combining DistilBERT and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
>>> device = "cuda" # the device to load the model onto
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-uncased')
|
||||
>>> model = AutoModel.from_pretrained("distilbert/distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
|
||||
|
||||
>>> text = "Replace me by any text you'd like."
|
||||
|
||||
>>> encoded_input = tokenizer(text, return_tensors='pt').to(device)
|
||||
>>> model.to(device)
|
||||
|
||||
>>> output = model(**encoded_input)
|
||||
```
|
||||
|
||||
|
||||
## DistilBertConfig
|
||||
|
||||
|
||||
@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
@ -14,66 +14,95 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# ELECTRA
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# ELECTRA
|
||||
|
||||
The ELECTRA model was proposed in the paper [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than
|
||||
Generators](https://openreview.net/pdf?id=r1xMH1BtvB). ELECTRA is a new pretraining approach which trains two
|
||||
transformer models: the generator and the discriminator. The generator's role is to replace tokens in a sequence, and
|
||||
is therefore trained as a masked language model. The discriminator, which is the model we're interested in, tries to
|
||||
identify which tokens were replaced by the generator in the sequence.
|
||||
[ELECTRA](https://huggingface.co/papers/2003.10555) modifies the pretraining objective of traditional masked language models like BERT. Instead of just masking tokens and asking the model to predict them, ELECTRA trains two models, a generator and a discriminator. The generator replaces some tokens with plausible alternatives and the discriminator (the model you'll actually use) learns to detect which tokens are original and which were replaced. This training approach is very efficient and scales to larger models while using considerably less compute.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
This approach is super efficient because ELECTRA learns from every single token in the input, not just the masked ones. That's why even the small ELECTRA models can match or outperform much larger models while using way less computing resources.
|
||||
|
||||
*Masked language modeling (MLM) pretraining methods such as BERT corrupt the input by replacing some tokens with [MASK]
|
||||
and then train a model to reconstruct the original tokens. While they produce good results when transferred to
|
||||
downstream NLP tasks, they generally require large amounts of compute to be effective. As an alternative, we propose a
|
||||
more sample-efficient pretraining task called replaced token detection. Instead of masking the input, our approach
|
||||
corrupts it by replacing some tokens with plausible alternatives sampled from a small generator network. Then, instead
|
||||
of training a model that predicts the original identities of the corrupted tokens, we train a discriminative model that
|
||||
predicts whether each token in the corrupted input was replaced by a generator sample or not. Thorough experiments
|
||||
demonstrate this new pretraining task is more efficient than MLM because the task is defined over all input tokens
|
||||
rather than just the small subset that was masked out. As a result, the contextual representations learned by our
|
||||
approach substantially outperform the ones learned by BERT given the same model size, data, and compute. The gains are
|
||||
particularly strong for small models; for example, we train a model on one GPU for 4 days that outperforms GPT (trained
|
||||
using 30x more compute) on the GLUE natural language understanding benchmark. Our approach also works well at scale,
|
||||
where it performs comparably to RoBERTa and XLNet while using less than 1/4 of their compute and outperforms them when
|
||||
using the same amount of compute.*
|
||||
You can find all the original ELECTRA checkpoints under the [ELECTRA](https://huggingface.co/collections/google/electra-release-64ff6e8b18830fabea30a1ab) release.
|
||||
|
||||
This model was contributed by [lysandre](https://huggingface.co/lysandre). The original code can be found [here](https://github.com/google-research/electra).
|
||||
> [!TIP]
|
||||
> Click on the right sidebar for more examples of how to use ELECTRA for different language tasks like sequence classification, token classification, and question answering.
|
||||
|
||||
## Usage tips
|
||||
The example below demonstrates how to classify text with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
- ELECTRA is the pretraining approach, therefore there is nearly no changes done to the underlying model: BERT. The
|
||||
only change is the separation of the embedding size and the hidden size: the embedding size is generally smaller,
|
||||
while the hidden size is larger. An additional projection layer (linear) is used to project the embeddings from their
|
||||
embedding size to the hidden size. In the case where the embedding size is the same as the hidden size, no projection
|
||||
layer is used.
|
||||
- ELECTRA is a transformer model pretrained with the use of another (small) masked language model. The inputs are corrupted by that language model, which takes an input text that is randomly masked and outputs a text in which ELECTRA has to predict which token is an original and which one has been replaced. Like for GAN training, the small language model is trained for a few steps (but with the original texts as objective, not to fool the ELECTRA model like in a traditional GAN setting) then the ELECTRA model is trained for a few steps.
|
||||
- The ELECTRA checkpoints saved using [Google Research's implementation](https://github.com/google-research/electra)
|
||||
contain both the generator and discriminator. The conversion script requires the user to name which model to export
|
||||
into the correct architecture. Once converted to the HuggingFace format, these checkpoints may be loaded into all
|
||||
available ELECTRA models, however. This means that the discriminator may be loaded in the
|
||||
[`ElectraForMaskedLM`] model, and the generator may be loaded in the
|
||||
[`ElectraForPreTraining`] model (the classification head will be randomly initialized as it
|
||||
doesn't exist in the generator).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Resources
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
classifier = pipeline(
|
||||
task="text-classification",
|
||||
model="bhadresh-savani/electra-base-emotion",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
classifier("This restaurant has amazing food!")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"bhadresh-savani/electra-base-emotion",
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bhadresh-savani/electra-base-emotion",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer("ELECTRA is more efficient than BERT", return_tensors="pt")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
predicted_class_id = logits.argmax(dim=-1).item()
|
||||
predicted_label = model.config.id2label[predicted_class_id]
|
||||
print(f"Predicted label: {predicted_label}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "This restaurant has amazing food." | transformers-cli run --task text-classification --model bhadresh-savani/electra-base-emotion --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Notes
|
||||
|
||||
- ELECTRA consists of two transformer models, a generator (G) and a discriminator (D). For most downstream tasks, use the discriminator model (as indicated by `*-discriminator` in the name) rather than the generator.
|
||||
- ELECTRA comes in three sizes: small (14M parameters), base (110M parameters), and large (335M parameters).
|
||||
- ELECTRA can use a smaller embedding size than the hidden size for efficiency. When `embedding_size` is smaller than `hidden_size` in the configuration, a projection layer connects them.
|
||||
- When using batched inputs with padding, make sure to use attention masks to prevent the model from attending to padding tokens.
|
||||
|
||||
```py
|
||||
# Example of properly handling padding with attention masks
|
||||
inputs = tokenizer(["Short text", "This is a much longer text that needs padding"],
|
||||
padding=True,
|
||||
return_tensors="pt")
|
||||
outputs = model(**inputs) # automatically uses the attention_mask
|
||||
```
|
||||
|
||||
- When using the discriminator for a downstream task, you can load it into any of the ELECTRA model classes ([`ElectraForSequenceClassification`], [`ElectraForTokenClassification`], etc.).
|
||||
|
||||
## ElectraConfig
|
||||
|
||||
|
||||
@ -14,48 +14,113 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Falcon
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Falcon
|
||||
|
||||
Falcon is a class of causal decoder-only models built by [TII](https://www.tii.ae/). The largest Falcon checkpoints
|
||||
have been trained on >=1T tokens of text, with a particular emphasis on the [RefinedWeb](https://arxiv.org/abs/2306.01116)
|
||||
corpus. They are made available under the Apache 2.0 license.
|
||||
[Falcon](https://huggingface.co/papers/2311.16867) is a family of large language models, available in 7B, 40B, and 180B parameters, as pretrained and instruction tuned variants. This model focuses on scaling pretraining over three categories, performance, data, and hardware. Falcon uses multigroup attention to significantly reduce inference memory requirements and rotary positional embeddings (RoPE). These models are pretrained on [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb), a high-quality and deduplicated 5T token dataset.
|
||||
|
||||
You can find all the original Falcon checkpoints under the [Falcon](https://huggingface.co/collections/tiiuae/falcon-64fb432660017eeec9837b5a) collection.
|
||||
|
||||
Falcon's architecture is modern and optimized for inference, with multi-query attention and support for efficient
|
||||
attention variants like `FlashAttention`. Both 'base' models trained only as causal language models as well as
|
||||
'instruct' models that have received further fine-tuning are available.
|
||||
> [!TIP]
|
||||
> Click on the Falcon models in the right sidebar for more examples of how to apply Falcon to different language tasks.
|
||||
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
Falcon models are (as of 2023) some of the largest and most powerful open-source language models,
|
||||
and consistently rank highly in the [OpenLLM leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Converting custom checkpoints
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
<Tip>
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="tiiuae/falcon-7b-instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=0
|
||||
)
|
||||
pipeline(
|
||||
"Write a short poem about coding",
|
||||
max_length=100,
|
||||
do_sample=True,
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
Falcon models were initially added to the Hugging Face Hub as custom code checkpoints. However, Falcon is now fully
|
||||
supported in the Transformers library. If you fine-tuned a model from a custom code checkpoint, we recommend converting
|
||||
your checkpoint to the new in-library format, as this should give significant improvements to stability and
|
||||
performance, especially for generation, as well as removing the need to use `trust_remote_code=True`!
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
</Tip>
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
You can convert custom code checkpoints to full Transformers checkpoints using the `convert_custom_code_checkpoint.py`
|
||||
script located in the
|
||||
[Falcon model directory](https://github.com/huggingface/transformers/tree/main/src/transformers/models/falcon)
|
||||
of the Transformers library. To use this script, simply call it with
|
||||
`python convert_custom_code_checkpoint.py --checkpoint_dir my_model`. This will convert your checkpoint in-place, and
|
||||
you can immediately load it from the directory afterwards with e.g. `from_pretrained()`. If your model hasn't been
|
||||
uploaded to the Hub, we recommend making a backup before attempting the conversion, just in case!
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"tiiuae/falcon-7b-instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
|
||||
input_ids = tokenizer("Write a short poem about coding", return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
transformers-cli chat --model_name_or_path tiiuae/falcon-7b-instruct --torch_dtype auto --attn_implementation flash_attention_2 --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"tiiuae/falcon-7b",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
inputs = tokenizer("In quantum physics, entanglement means", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- If you're upgrading from an older custom code checkpoint, remember to convert it to the official Transformers format for better stability and performance using the conversion script located in the [Falcon model directory](https://github.com/huggingface/transformers/tree/main/src/transformers/models/falcon).
|
||||
|
||||
```bash
|
||||
python convert_custom_code_checkpoint.py --checkpoint_dir my_model
|
||||
```
|
||||
|
||||
## FalconConfig
|
||||
|
||||
@ -85,6 +150,4 @@ uploaded to the Hub, we recommend making a backup before attempting the conversi
|
||||
## FalconForQuestionAnswering
|
||||
|
||||
[[autodoc]] FalconForQuestionAnswering
|
||||
- forward
|
||||
|
||||
|
||||
- forward
|
||||
@ -15,74 +15,63 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Gemma3
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Gemma 3
|
||||
|
||||
The Gemma 3 model was proposed in the [Gemma 3 Techncial Report](https://goo.gle/Gemma3Report) by Google. It is a vision-language model composed by a [SigLIP](siglip) vision encoder and a [Gemma 2](gemma_2) language decoder, linked by a multimodal linear projection. It cuts an image into a fixed number of tokens, in the same way as SigLIP, as long as the image does not exceed certain aspect ratio. For images that exceed the given aspect ratio, it crops the image into multiple smaller patches and concatenates them with the base image embedding. One particularity is that the model uses bidirectional attention on all the image tokens. In addition, the model interleaves sliding window local attention with full causal attention in the language backbone, where each sixth layer is a full causal attention layer.
|
||||
[Gemma 3](https://goo.gle/Gemma3Report) is a multimodal model with pretrained and instruction-tuned variants, available in 1B, 13B, and 27B parameters. The architecture is mostly the same as the previous Gemma versions. The key differences are alternating 5 local sliding window self-attention layers for every global self-attention layer, support for a longer context length of 128K tokens, and a [SigLip](./siglip) encoder that can "pan & scan" high-resolution images to prevent information from disappearing in high resolution images or images with non-square aspect ratios.
|
||||
|
||||
This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins), [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) [Arthur Zucker](https://huggingface.co/ArthurZ), and [Pedro Cuenca](https://huggingface.co/pcuenq).
|
||||
The instruction-tuned variant was post-trained with knowledge distillation and reinforcement learning.
|
||||
|
||||
You can find all the original Gemma 3 checkpoints under the [Gemma 3](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b) release.
|
||||
|
||||
## Usage tips
|
||||
> [!TIP]
|
||||
> Click on the Gemma 3 models in the right sidebar for more examples of how to apply Gemma to different vision and language tasks.
|
||||
|
||||
The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
- For image+text and image-only inputs use `Gemma3ForConditionalGeneration`.
|
||||
- For text-only inputs use `Gemma3ForCausalLM` for generation to avoid loading the vision tower.
|
||||
- Each sample can contain multiple images, and the number of images can vary between samples. However, make sure to pass correctly batched images to the processor, where each batch is a list of one or more images.
|
||||
- The text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
|
||||
- The processor has its own `apply_chat_template` method to convert chat messages to model inputs. See the examples below for more details on how to use it.
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
### Image cropping for high resolution images
|
||||
|
||||
The model supports cropping images into smaller patches when the image aspect ratio exceeds a certain value. By default the images are not cropped and only the base image is forwarded to the model. Users can set `do_pan_and_scan=True` to obtain several crops per image along with the base image to improve the quality in DocVQA or similar tasks requiring higher resolution images.
|
||||
|
||||
Pan and scan is an inference time optimization to handle images with skewed aspect ratios. When enabled, it improves performance on tasks related to document understanding, infographics, OCR, etc.
|
||||
|
||||
```python
|
||||
|
||||
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
|
||||
|
||||
url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": url},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
do_pan_and_scan=True,
|
||||
).to(model.device)
|
||||
|
||||
pipeline = pipeline(
|
||||
task="image-text-to-text",
|
||||
model="google/gemma-3-4b-pt",
|
||||
device=0,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
|
||||
text="<start_of_image> What is shown in this image?"
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
## Usage Example
|
||||
|
||||
### Single-image Inference
|
||||
|
||||
```python
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||
"google/gemma-3-4b-it",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"google/gemma-3-4b-it",
|
||||
padding_side="left"
|
||||
)
|
||||
|
||||
url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
@ -92,7 +81,7 @@ messages = [
|
||||
},
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": url},
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
]
|
||||
},
|
||||
@ -103,21 +92,43 @@ inputs = processor.apply_chat_template(
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
).to("cuda")
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
### Multi-image Inference
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```python
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
|
||||
```bash
|
||||
echo -e "Plants create energy through a process known as" | transformers-cli run --task text-generation --model google/gemma-3-1b-pt --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
```py
|
||||
# pip install torchao
|
||||
import torch
|
||||
from transformers import TorchAoConfig, Gemma3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||
"google/gemma-3-27b-it",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"google/gemma-3-27b-it",
|
||||
padding_side="left"
|
||||
)
|
||||
|
||||
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
@ -127,9 +138,8 @@ messages = [
|
||||
},
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": url_cow},
|
||||
{"type": "image", "url": url_stop},
|
||||
{"type": "text", "text": "Are these two images identical?"},
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
@ -139,33 +149,85 @@ inputs = processor.apply_chat_template(
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
).to("cuda")
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
### Text-only inference
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This will skip loading the vision tower and will save resources when you just need the LLM capabilities.
|
||||
```python
|
||||
from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||
|
||||
model_id = "google/gemma-3-1b-it"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device)
|
||||
|
||||
outputs = model.generate(**input_ids, max_new_tokens=100)
|
||||
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
print(text)
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
visualizer = AttentionMaskVisualizer("google/gemma-3-4b-it")
|
||||
visualizer("<img>What is shown in this image?")
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/gemma-3-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
## Notes
|
||||
|
||||
- Use [`Gemma3ForConditionalGeneration`] for image-and-text and image-only inputs.
|
||||
- Gemma 3 supports multiple input images, but make sure the images are correctly batched before passing them to the processor. Each batch should be a list of one or more images.
|
||||
|
||||
```py
|
||||
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
|
||||
messages =[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": url_cow},
|
||||
{"type": "image", "url": url_cat},
|
||||
{"type": "text", "text": "Which image is cuter?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
```
|
||||
- Text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
|
||||
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
|
||||
- By default, images aren't cropped and only the base image is forwarded to the model. In high resolution images or images with non-square aspect ratios, artifacts can result because the vision encoder uses a fixed resolution of 896x896. To prevent these artifacts and improve performance during inference, set `do_pan_and_scan=True` to crop the image into multiple smaller patches and concatenate them with the base image embedding. You can disable pan and scan for faster inference.
|
||||
|
||||
```diff
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
+ do_pan_and_scan=True,
|
||||
).to("cuda")
|
||||
```
|
||||
- For Gemma-3 1B checkpoint trained in text-only mode, use [`AutoModelForCausalLM`] instead.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/gemma-3-1b-pt",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-3-1b-pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Gemma3ImageProcessor
|
||||
|
||||
|
||||
@ -71,9 +71,10 @@ pip install -U flash-attn --no-build-isolation
|
||||
Below is an expected speedup diagram comparing the pure inference time between the native implementation in transformers of `facebook/hubert-large-ls960-ft`, the flash-attention-2 and the sdpa (scale-dot-product-attention) version. We show the average speedup obtained on the `librispeech_asr` `clean` validation split:
|
||||
|
||||
```python
|
||||
>>> from transformers import Wav2Vec2Model
|
||||
>>> from transformers import HubertModel
|
||||
>>> import torch
|
||||
|
||||
model = Wav2Vec2Model.from_pretrained("facebook/hubert-large-ls960-ft", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
|
||||
>>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -14,79 +14,115 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# LLaMA
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Llama
|
||||
|
||||
The LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. It is a collection of foundation language models ranging from 7B to 65B parameters.
|
||||
[Llama](https://huggingface.co/papers/2302.13971) is a family of large language models ranging from 7B to 65B parameters. These models are focused on efficient inference (important for serving language models) by training a smaller model on more tokens rather than training a larger model on fewer tokens. The Llama model is based on the GPT architecture, but it uses pre-normalization to improve training stability, replaces ReLU with SwiGLU to improve performance, and replaces absolute positional embeddings with rotary positional embeddings (RoPE) to better handle longer sequence lengths.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original Llama checkpoints under the [Huggy Llama](https://huggingface.co/huggyllama) organization.
|
||||
|
||||
*We introduce LLaMA, a collection of foundation language models ranging from 7B to 65B parameters. We train our models on trillions of tokens, and show that it is possible to train state-of-the-art models using publicly available datasets exclusively, without resorting to proprietary and inaccessible datasets. In particular, LLaMA-13B outperforms GPT-3 (175B) on most benchmarks, and LLaMA-65B is competitive with the best models, Chinchilla-70B and PaLM-540B. We release all our models to the research community. *
|
||||
> [!TIP]
|
||||
> Click on the Llama models in the right sidebar for more examples of how to apply Llama to different language tasks.
|
||||
|
||||
This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama).
|
||||
The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`], and from the command line.
|
||||
|
||||
## Usage tips
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
- Weights for the LLaMA models can be obtained from by filling out [this form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform?usp=send_form)
|
||||
- After downloading the weights, they will need to be converted to the Hugging Face Transformers format using the [conversion script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py). The script can be called with the following (example) command:
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="huggyllama/llama-7b",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"huggyllama/llama-7b",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"huggyllama/llama-7b",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
echo -e "Plants create energy through a process known as" | transformers-cli run --task text-generation --model huggyllama/llama-7b --device 0
|
||||
```
|
||||
|
||||
- After conversion, the model and tokenizer can be loaded via:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
```python
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
||||
model = LlamaForCausalLM.from_pretrained("/output/path")
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
```py
|
||||
# pip install torchao
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"huggyllama/llama-30b",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-30b")
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). For the 65B model, it's thus 130GB of RAM needed.
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
This model was contributed by [zphang](https://huggingface.co/zphang) with contributions from [BlackSamorez](https://huggingface.co/BlackSamorez). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama). The Flax version of the implementation was contributed by [afmck](https://huggingface.co/afmck) with the code in the implementation based on Hugging Face's Flax GPT-Neo.
|
||||
visualizer = AttentionMaskVisualizer("huggyllama/llama-7b")
|
||||
visualizer("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/llama-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
Based on the original LLaMA model, Meta AI has released some follow-up works:
|
||||
## Notes
|
||||
|
||||
- **Llama2**: Llama2 is an improved version of Llama with some architectural tweaks (Grouped Query Attention), and is pre-trained on 2Trillion tokens. Refer to the documentation of Llama2 which can be found [here](llama2).
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with LLaMA. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
|
||||
- A [notebook](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb#scrollTo=f04ba4d2) on how to use prompt tuning to adapt the LLaMA model for text classification task. 🌎
|
||||
|
||||
<PipelineTag pipeline="question-answering"/>
|
||||
|
||||
- [StackLLaMA: A hands-on guide to train LLaMA with RLHF](https://huggingface.co/blog/stackllama#stackllama-a-hands-on-guide-to-train-llama-with-rlhf), a blog post about how to train LLaMA to answer questions on [Stack Exchange](https://stackexchange.com/) with RLHF.
|
||||
|
||||
⚗️ Optimization
|
||||
- A [notebook](https://colab.research.google.com/drive/1SQUXq1AMZPSLD4mk3A3swUIc6Y2dclme?usp=sharing) on how to fine-tune LLaMA model using xturing library on GPU which has limited memory. 🌎
|
||||
|
||||
⚡️ Inference
|
||||
- A [notebook](https://colab.research.google.com/github/DominguesM/alpaca-lora-ptbr-7b/blob/main/notebooks/02%20-%20Evaluate.ipynb) on how to run the LLaMA Model using PeftModel from the 🤗 PEFT library. 🌎
|
||||
- A [notebook](https://colab.research.google.com/drive/1l2GiSSPbajVyp2Nk3CFT4t3uH6-5TiBe?usp=sharing) on how to load a PEFT adapter LLaMA model with LangChain. 🌎
|
||||
|
||||
🚀 Deploy
|
||||
- A [notebook](https://colab.research.google.com/github/lxe/simple-llama-finetuner/blob/master/Simple_LLaMA_FineTuner.ipynb#scrollTo=3PM_DilAZD8T) on how to fine-tune LLaMA model using LoRA method via the 🤗 PEFT library with intuitive UI. 🌎
|
||||
- A [notebook](https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text-generation-open-llama.ipynb) on how to deploy Open-LLaMA model for text generation on Amazon SageMaker. 🌎
|
||||
- The tokenizer is a byte-pair encoding model based on [SentencePiece](https://github.com/google/sentencepiece). During decoding, if the first token is the start of the word (for example, "Banana"), the tokenizer doesn't prepend the prefix space to the string.
|
||||
|
||||
## LlamaConfig
|
||||
|
||||
|
||||
@ -14,97 +14,129 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Llama2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Llama 2
|
||||
|
||||
The Llama2 model was proposed in [LLaMA: Open Foundation and Fine-Tuned Chat Models](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) by Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushka rMishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing EllenTan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom. It is a collection of foundation language models ranging from 7B to 70B parameters, with checkpoints finetuned for chat application!
|
||||
[Llama 2](https://huggingface.co/papers/2307.09288) is a family of large language models, Llama 2 and Llama 2-Chat, available in 7B, 13B, and 70B parameters. The Llama 2 model mostly keeps the same architecture as [Llama](./llama), but it is pretrained on more tokens, doubles the context length, and uses grouped-query attention (GQA) in the 70B model to improve inference.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
Llama 2-Chat is trained with supervised fine-tuning (SFT), and reinforcement learning with human feedback (RLHF) - rejection sampling and proximal policy optimization (PPO) - is applied to the fine-tuned model to align the chat model with human preferences.
|
||||
|
||||
*In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Our fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Our models outperform open-source chat models on most benchmarks we tested, and based on our human evaluations for helpfulness and safety, may be a suitable substitute for closed-source models. We provide a detailed description of our approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on our work and contribute to the responsible development of LLMs.*
|
||||
You can find all the original Llama 2 checkpoints under the [Llama 2 Family](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b) collection.
|
||||
|
||||
Checkout all Llama2 model checkpoints [here](https://huggingface.co/models?search=llama2).
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ) with contributions from [Lysandre Debut](https://huggingface.co/lysandre). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox). The original code of the authors can be found [here](https://github.com/facebookresearch/llama).
|
||||
> [!TIP]
|
||||
> Click on the Llama 2 models in the right sidebar for more examples of how to apply Llama to different language tasks.
|
||||
|
||||
## Usage tips
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and how to chat with Llama 2-Chat from the command line.
|
||||
|
||||
<Tip warning={true}>
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
The `Llama2` models were trained using `bfloat16`, but the original inference uses `float16`. The checkpoints uploaded on the Hub use `torch_dtype = 'float16'`, which will be
|
||||
used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
The `dtype` of the online weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online), then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`), and finally, if there is a `torch_dtype` provided in the config, it will be used.
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
Training the model in `float16` is not recommended and is known to produce `nan`; as such, the model should be trained in `bfloat16`.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
</Tip>
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
Tips:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
- Weights for the Llama2 models can be obtained by filling out [this form](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
|
||||
- The architecture is very similar to the first Llama, with the addition of Grouped Query Attention (GQA) following this [paper](https://arxiv.org/pdf/2305.13245.pdf)
|
||||
- Setting `config.pretraining_tp` to a value different than 1 will activate the more accurate but slower computation of the linear layers, which should better match the original logits.
|
||||
- The original model uses `pad_id = -1` which means that there is no padding token. We can't have the same logic, make sure to add a padding token using `tokenizer.add_special_tokens({"pad_token":"<pad>"})` and resize the token embedding accordingly. You should also set the `model.config.pad_token_id`. The `embed_tokens` layer of the model is initialized with `self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.config.padding_idx)`, which makes sure that encoding the padding token will output zeros, so passing it when initializing is recommended.
|
||||
- After filling out the form and gaining access to the model checkpoints, you should be able to use the already converted checkpoints. Otherwise, if you are converting your own model, feel free to use the [conversion script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py). The script can be called with the following (example) command:
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
transformers-cli chat --model_name_or_path meta-llama/Llama-2-7b-chat-hf --torch_dtype auto --attn_implementation flash_attention_2
|
||||
```
|
||||
|
||||
- After conversion, the model and tokenizer can be loaded via:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
```python
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
||||
model = LlamaForCausalLM.from_pretrained("/output/path")
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
```py
|
||||
# pip install torchao
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf")
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). For the 75B model, it's thus 145GB of RAM needed.
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
- The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, it is simply specifying either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because the Flash Attention only support `fp16` and `bf16` data type.
|
||||
visualizer = AttentionMaskVisualizer("meta-llama/Llama-2-7b-hf")
|
||||
visualizer("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/llama-2-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
## Resources
|
||||
## Notes
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with LLaMA2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
- Setting `config.pretraining_tp` to a value besides `1` activates a more accurate but slower computation of the linear layers. This matches the original logits better.
|
||||
- The original model uses `pad_id = -1` to indicate a padding token. The Transformers implementation requires adding a padding token and resizing the token embedding accordingly.
|
||||
|
||||
- [Llama 2 is here - get it on Hugging Face](https://huggingface.co/blog/llama2), a blog post about Llama 2 and how to use it with 🤗 Transformers and 🤗 PEFT.
|
||||
- [LLaMA 2 - Every Resource you need](https://www.philschmid.de/llama-2), a compilation of relevant resources to learn about LLaMA 2 and how to get started quickly.
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
|
||||
- A [notebook](https://colab.research.google.com/drive/1PEQyJO1-f6j0S_XJ8DV50NkpzasXkrzd?usp=sharing) on how to fine-tune Llama 2 in Google Colab using QLoRA and 4-bit precision. 🌎
|
||||
- A [notebook](https://colab.research.google.com/drive/134o_cXcMe_lsvl15ZE_4Y75Kstepsntu?usp=sharing) on how to fine-tune the "Llama-v2-7b-guanaco" model with 4-bit QLoRA and generate Q&A datasets from PDFs. 🌎
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
|
||||
- A [notebook](https://colab.research.google.com/drive/1ggaa2oRFphdBmqIjSEbnb_HGkcIRC2ZB?usp=sharing) on how to fine-tune the Llama 2 model with QLoRa, TRL, and Korean text classification dataset. 🌎🇰🇷
|
||||
|
||||
⚗️ Optimization
|
||||
- [Fine-tune Llama 2 with DPO](https://huggingface.co/blog/dpo-trl), a guide to using the TRL library's DPO method to fine tune Llama 2 on a specific dataset.
|
||||
- [Extended Guide: Instruction-tune Llama 2](https://www.philschmid.de/instruction-tune-llama-2), a guide to training Llama 2 to generate instructions from inputs, transforming the model from instruction-following to instruction-giving.
|
||||
- A [notebook](https://colab.research.google.com/drive/1SYpgFpcmtIUzdE7pxqknrM4ArCASfkFQ?usp=sharing) on how to fine-tune the Llama 2 model on a personal computer using QLoRa and TRL. 🌎
|
||||
|
||||
⚡️ Inference
|
||||
- A [notebook](https://colab.research.google.com/drive/1TC56ArKerXUpbgRy5vM3woRsbTEVNq7h?usp=sharing) on how to quantize the Llama 2 model using GPTQ from the AutoGPTQ library. 🌎
|
||||
- A [notebook](https://colab.research.google.com/drive/1X1z9Q6domMKl2CnEM0QGHNwidLfR4dW2?usp=sharing) on how to run the Llama 2 Chat Model with 4-bit quantization on a local computer or Google Colab. 🌎
|
||||
|
||||
🚀 Deploy
|
||||
- [Fine-tune LLaMA 2 (7-70B) on Amazon SageMaker](https://www.philschmid.de/sagemaker-llama2-qlora), a complete guide from setup to QLoRA fine-tuning and deployment on Amazon SageMaker.
|
||||
- [Deploy Llama 2 7B/13B/70B on Amazon SageMaker](https://www.philschmid.de/sagemaker-llama-llm), a guide on using Hugging Face's LLM DLC container for secure and scalable deployment.
|
||||
```py
|
||||
tokenizer.add_special_tokens({"pad_token":"<pad>"})
|
||||
# update model config with padding token
|
||||
model.config.pad_token_id
|
||||
```
|
||||
- It is recommended to initialize the `embed_tokens` layer with the following code to ensure encoding the padding token outputs zeros.
|
||||
|
||||
```py
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.config.padding_idx)
|
||||
```
|
||||
- The tokenizer is a byte-pair encoding model based on [SentencePiece](https://github.com/google/sentencepiece). During decoding, if the first token is the start of the word (for example, "Banana"), the tokenizer doesn't prepend the prefix space to the string.
|
||||
- Don't use the `torch_dtype` parameter in [`~AutoModel.from_pretrained`] if you're using FlashAttention-2 because it only supports fp16 or bf16. You should use [Automatic Mixed Precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html), set fp16 or bf16 to `True` if using [`Trainer`], or use [torch.autocast](https://pytorch.org/docs/stable/amp.html#torch.autocast).
|
||||
|
||||
## LlamaConfig
|
||||
|
||||
|
||||
442
docs/source/en/model_doc/llama4.md
Normal file
442
docs/source/en/model_doc/llama4.md
Normal file
@ -0,0 +1,442 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Llama4
|
||||
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
|
||||
This generation includes two models:
|
||||
- The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, with 128 experts.
|
||||
- The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts.
|
||||
-
|
||||
Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
|
||||
Maverick and Scout are both trained on up to 40 trillion tokens on data encompassing 200 languages
|
||||
(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).
|
||||
|
||||
For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
|
||||
on-the-fly 4-bit or 8-bitint4 quantization, while Maverick is available in BF16 and FP8 formats.
|
||||
These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.
|
||||
|
||||
You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.
|
||||
|
||||
> [!TIP]
|
||||
> The Llama 4 family of models comes in two flavors: 109B, and 402B parameters. Both of these flavors are extremely
|
||||
> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
|
||||
> model.
|
||||
>
|
||||
> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as followed:
|
||||
> `pip install transformers[hf_xet]`
|
||||
|
||||
The examples below demonstrates how to generate with [`Pipeline`] or the [`AutoModel`]. We additionally add an example
|
||||
showcasing how to toggle the right attributes to enable very long-context generations, as some flavors of Llama 4
|
||||
have context lengths going up to 10 million tokens.
|
||||
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "what is the recipe of mayonnaise?"},
|
||||
]
|
||||
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model=model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
output = pipe(messages, do_sample=False, max_new_tokens=200)
|
||||
print(output[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Text only">
|
||||
|
||||
```py
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Multimodal">
|
||||
|
||||
```py
|
||||
from transformers import AutoProcessor, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": img_url},
|
||||
{"type": "text", "text": "Describe this image in two sentences."},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Multimodal with multiple images">
|
||||
|
||||
```py
|
||||
from transformers import AutoProcessor, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
|
||||
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": url1},
|
||||
{"type": "image", "url": url2},
|
||||
{"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Long context">
|
||||
|
||||
Beware: the example below uses both `device_map="auto"` and flex-attention.
|
||||
Please use `torchrun` to run this example in tensor-parallel mode.
|
||||
|
||||
We will work to enable running with `device_map="auto"` and flex-attention without
|
||||
tensor-parallel in the future.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration, AutoTokenizer
|
||||
import torch
|
||||
import time
|
||||
|
||||
file = "very_long_context_prompt.txt"
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
with open(file, "r") as f:
|
||||
very_long_text = "\n".join(f.readlines())
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
attn_implementation="flex_attention",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
|
||||
]
|
||||
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out = model.generate(
|
||||
input_ids.to(model.device),
|
||||
prefill_chunk_size=2048*8,
|
||||
max_new_tokens=300,
|
||||
cache_implementation="hybrid",
|
||||
)
|
||||
print(time.time()-start)
|
||||
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
|
||||
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Efficiency; how to get the best out of llama 4
|
||||
|
||||
### The Attention methods
|
||||
|
||||
Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.
|
||||
|
||||
As of release, the Llama 4 model supports the following attention methods: `eager`, `flex_attention`, `sdpa`. We recommend using `flex_attention` for best results.
|
||||
Switching attention mechanism is done at the model initialization step:
|
||||
|
||||
|
||||
<hfoptions id="Attention">
|
||||
<hfoption id="Flex Attention">
|
||||
|
||||
Setting Flex Attention ensures the best results with the very long context the model can handle.
|
||||
|
||||
> [!TIP] Beware: the example below uses both `device_map="auto"` and flex-attention.
|
||||
> Please use `torchrun` to run this example in tensor-parallel mode.
|
||||
>
|
||||
> We will work to enable running with `device_map="auto"` and flex-attention without
|
||||
> tensor-parallel in the future.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="flex_attention",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="SDPA">
|
||||
The `sdpa` attention method is generally more compute-efficient than the `eager` method.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="sdpa",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Eager">
|
||||
The `eager` attention method is set by default, so no need for anything different when loading the model:
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Quantization
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
|
||||
At time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be supported in the days that follow the release.
|
||||
|
||||
See below for examples using both:
|
||||
|
||||
|
||||
|
||||
Here is an example loading an BF16 model in FP8 using the FBGEMM approach:
|
||||
|
||||
<hfoptions id="Quantization">
|
||||
<hfoption id="FBGEMM">
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=FbgemmFp8Config()
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="LLM-Compressor">
|
||||
|
||||
To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
tp_plan="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Offloading
|
||||
|
||||
Enabling CPU-offloading means that components of the model might be moved to CPU instead of GPU in case the GPU-memory available isn't sufficient to load the entire model.
|
||||
At inference, different components will be loaded/unloaded from/to the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as the CPU-memory is sufficient.
|
||||
However, this also slows down inference as it adds communication overhead.
|
||||
|
||||
In order to enable CPU-offloading, you simply need to specify the `device_map` to `auto` at model load:
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
|
||||
## Llama4Config
|
||||
|
||||
[[autodoc]] Llama4Config
|
||||
|
||||
## Llama4TextConfig
|
||||
|
||||
[[autodoc]] Llama4TextConfig
|
||||
|
||||
## Llama4VisionConfig
|
||||
|
||||
[[autodoc]] Llama4VisionConfig
|
||||
|
||||
## Llama4Processor
|
||||
|
||||
[[autodoc]] Llama4Processor
|
||||
|
||||
## Llama4ImageProcessorFast
|
||||
|
||||
[[autodoc]] Llama4ImageProcessorFast
|
||||
|
||||
## Llama4ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Llama4ForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## Llama4ForCausalLM
|
||||
|
||||
[[autodoc]] Llama4ForCausalLM
|
||||
- forward
|
||||
|
||||
## Llama4TextModel
|
||||
|
||||
[[autodoc]] Llama4TextModel
|
||||
- forward
|
||||
|
||||
## Llama4ForCausalLM
|
||||
|
||||
[[autodoc]] Llama4ForCausalLM
|
||||
- forward
|
||||
|
||||
## Llama4VisionModel
|
||||
|
||||
[[autodoc]] Llama4VisionModel
|
||||
- forward
|
||||
234
docs/source/en/model_doc/mistral3.md
Normal file
234
docs/source/en/model_doc/mistral3.md
Normal file
@ -0,0 +1,234 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Mistral3
|
||||
|
||||
## Overview
|
||||
|
||||
Building upon Mistral Small 3 (2501), Mistral Small 3.1 (2503) adds state-of-the-art vision understanding and enhances long context capabilities up to 128k tokens without compromising text performance. With 24 billion parameters, this model achieves top-tier capabilities in both text and vision tasks.
|
||||
|
||||
It is ideal for:
|
||||
- Fast-response conversational agents.
|
||||
- Low-latency function calling.
|
||||
- Subject matter experts via fine-tuning.
|
||||
- Local inference for hobbyists and organizations handling sensitive data.
|
||||
- Programming and math reasoning.
|
||||
- Long document understanding.
|
||||
- Visual understanding.
|
||||
|
||||
This model was contributed by [cyrilvallez](https://huggingface.co/cyrilvallez) and [yonigozlan](https://huggingface.co/yonigozlan).
|
||||
|
||||
The original code can be found [here](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/pixtral.py) and [here](https://github.com/mistralai/mistral-common).
|
||||
|
||||
## Usage example
|
||||
|
||||
### Inference with Pipeline
|
||||
|
||||
Here is how you can use the `image-text-to-text` pipeline to perform inference with the `Mistral3` models in just a few lines of code:
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {
|
||||
... "type": "image",
|
||||
... "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
|
||||
... },
|
||||
... {"type": "text", "text": "Describe this image."},
|
||||
... ],
|
||||
... },
|
||||
... ]
|
||||
|
||||
>>> pipe = pipeline("image-text-to-text", model="mistralai/Mistral-Small-3.1-24B-Instruct-2503", torch_dtype=torch.bfloat16)
|
||||
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
|
||||
>>> outputs[0]["generated_text"]
|
||||
'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a'
|
||||
```
|
||||
### Inference on a single image
|
||||
|
||||
This example demonstrates how to perform inference on a single image with the Mistral3 models using chat templates.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
|
||||
... {"type": "text", "text": "Describe this image"},
|
||||
... ],
|
||||
... }
|
||||
... ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
|
||||
|
||||
>>> decoded_output
|
||||
"The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"...
|
||||
```
|
||||
|
||||
### Text-only generation
|
||||
This example shows how to generate text using the Mistral3 model without providing any image input.
|
||||
|
||||
|
||||
````python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = ".mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, always end your accurate response with an ASCII drawing of a cat."
|
||||
>>> user_prompt = "Give me 5 non-formal ways to say 'See you later' in French."
|
||||
|
||||
>>> messages = [
|
||||
... {"role": "system", "content": SYSTEM_PROMPT},
|
||||
... {"role": "user", "content": user_prompt},
|
||||
... ]
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=text, return_tensors="pt").to(0, dtype=torch.float16)
|
||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False)
|
||||
>>> decoded_output = processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0]
|
||||
|
||||
>>> print(decoded_output)
|
||||
"1. À plus tard!
|
||||
2. Salut, à plus!
|
||||
3. À toute!
|
||||
4. À la prochaine!
|
||||
5. Je me casse, à plus!
|
||||
|
||||
```
|
||||
/\_/\
|
||||
( o.o )
|
||||
> ^ <
|
||||
```"
|
||||
````
|
||||
|
||||
### Batched image and text inputs
|
||||
Mistral3 models also support batched image and text inputs.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)
|
||||
|
||||
>>> messages = [
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
... {"type": "text", "text": "Write a haiku for this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
... {"type": "text", "text": "Describe this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... ]
|
||||
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
|
||||
>>> decoded_outputs
|
||||
["Write a haiku for this imageCalm waters reflect\nWhispers of the forest's breath\nPeace on wooden path"
|
||||
, "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"]
|
||||
```
|
||||
|
||||
### Batched multi-image input and quantization with BitsAndBytes
|
||||
This implementation of the Mistral3 models supports batched text-images inputs with different number of images for each text.
|
||||
This example also how to use `BitsAndBytes` to load the model in 4bit quantization.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
|
||||
>>> import torch
|
||||
|
||||
>>> torch_device = "cuda"
|
||||
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
|
||||
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
>>> model = AutoModelForImageTextToText.from_pretrained(
|
||||
... model_checkpoint, quantization_config=quantization_config
|
||||
... )
|
||||
|
||||
>>> messages = [
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
|
||||
... {"type": "text", "text": "Write a haiku for this image"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
... [
|
||||
... {
|
||||
... "role": "user",
|
||||
... "content": [
|
||||
... {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
|
||||
... {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
|
||||
... {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
|
||||
... ],
|
||||
... },
|
||||
... ],
|
||||
>>> ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
|
||||
>>> output = model.generate(**inputs, max_new_tokens=25)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
|
||||
>>> decoded_outputs
|
||||
["Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n", "These images depict two different landmarks. Can you identify them? Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."]
|
||||
```
|
||||
|
||||
|
||||
## Mistral3Config
|
||||
|
||||
[[autodoc]] Mistral3Config
|
||||
|
||||
|
||||
## Mistral3ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Mistral3ForConditionalGeneration
|
||||
- forward
|
||||
@ -14,52 +14,81 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# MobileBERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# MobileBERT
|
||||
|
||||
The MobileBERT model was proposed in [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny
|
||||
Zhou. It's a bidirectional transformer based on the BERT model, which is compressed and accelerated using several
|
||||
approaches.
|
||||
[MobileBERT](https://huggingface.co/papers/2004.02984) is a lightweight and efficient variant of BERT, specifically designed for resource-limited devices such as mobile phones. It retains BERT's architecture but significantly reduces model size and inference latency while maintaining strong performance on NLP tasks. MobileBERT achieves this through a bottleneck structure and carefully balanced self-attention and feedforward networks. The model is trained by knowledge transfer from a large BERT model with an inverted bottleneck structure.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find the original MobileBERT checkpoint under the [Google](https://huggingface.co/google/mobilebert-uncased) organization.
|
||||
> [!TIP]
|
||||
> Click on the MobileBERT models in the right sidebar for more examples of how to apply MobileBERT to different language tasks.
|
||||
|
||||
*Natural Language Processing (NLP) has recently achieved great success by using huge pre-trained models with hundreds
|
||||
of millions of parameters. However, these models suffer from heavy model sizes and high latency such that they cannot
|
||||
be deployed to resource-limited mobile devices. In this paper, we propose MobileBERT for compressing and accelerating
|
||||
the popular BERT model. Like the original BERT, MobileBERT is task-agnostic, that is, it can be generically applied to
|
||||
various downstream NLP tasks via simple fine-tuning. Basically, MobileBERT is a thin version of BERT_LARGE, while
|
||||
equipped with bottleneck structures and a carefully designed balance between self-attentions and feed-forward networks.
|
||||
To train MobileBERT, we first train a specially designed teacher model, an inverted-bottleneck incorporated BERT_LARGE
|
||||
model. Then, we conduct knowledge transfer from this teacher to MobileBERT. Empirical studies show that MobileBERT is
|
||||
4.3x smaller and 5.5x faster than BERT_BASE while achieving competitive results on well-known benchmarks. On the
|
||||
natural language inference tasks of GLUE, MobileBERT achieves a GLUEscore o 77.7 (0.6 lower than BERT_BASE), and 62 ms
|
||||
latency on a Pixel 4 phone. On the SQuAD v1.1/v2.0 question answering task, MobileBERT achieves a dev F1 score of
|
||||
90.0/79.2 (1.5/2.1 higher than BERT_BASE).*
|
||||
The example below demonstrates how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
This model was contributed by [vshampor](https://huggingface.co/vshampor). The original code can be found [here](https://github.com/google-research/google-research/tree/master/mobilebert).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Usage tips
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
- MobileBERT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather
|
||||
than the left.
|
||||
- MobileBERT is similar to BERT and therefore relies on the masked language modeling (MLM) objective. It is therefore
|
||||
efficient at predicting masked tokens and at NLU in general, but is not optimal for text generation. Models trained
|
||||
with a causal language modeling (CLM) objective are better in that regard.
|
||||
pipeline = pipeline(
|
||||
task="fill-mask",
|
||||
model="google/mobilebert-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("The capital of France is [MASK].")
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/mobilebert-uncased",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"google/mobilebert-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits
|
||||
|
||||
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
|
||||
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
|
||||
predicted_token = tokenizer.decode(predicted_token_id)
|
||||
|
||||
print(f"The predicted token is: {predicted_token}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "The capital of France is [MASK]." | transformers-cli run --task fill-mask --model google/mobilebert-uncased --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
## Resources
|
||||
## Notes
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
- Inputs should be padded on the right because BERT uses absolute position embeddings.
|
||||
|
||||
## MobileBertConfig
|
||||
|
||||
|
||||
@ -14,52 +14,79 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# ModernBERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# ModernBERT
|
||||
|
||||
The ModernBERT model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli.
|
||||
[ModernBERT](https://huggingface.co/papers/2412.13663) is a modernized version of [`BERT`] trained on 2T tokens. It brings many improvements to the original architecture such as rotary positional embeddings to support sequences of up to 8192 tokens, unpadding to avoid wasting compute on padding tokens, GeGLU layers, and alternating attention.
|
||||
|
||||
It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
|
||||
You can find all the original ModernBERT checkpoints under the [ModernBERT](https://huggingface.co/collections/answerdotai/modernbert-67627ad707a4acbf33c41deb) collection.
|
||||
|
||||
It builds on BERT and implements many modern architectural improvements which have been developed since its original release, such as:
|
||||
- [Rotary Positional Embeddings](https://huggingface.co/blog/designing-positional-encoding) to support sequences of up to 8192 tokens.
|
||||
- [Unpadding](https://arxiv.org/abs/2208.08124) to ensure no compute is wasted on padding tokens, speeding up processing time for batches with mixed-length sequences.
|
||||
- [GeGLU](https://arxiv.org/abs/2002.05202) Replacing the original MLP layers with GeGLU layers, shown to improve performance.
|
||||
- [Alternating Attention](https://arxiv.org/abs/2004.05150v2) where most attention layers employ a sliding window of 128 tokens, with Global Attention only used every 3 layers.
|
||||
- [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up processing.
|
||||
- A model designed following recent [The Case for Co-Designing Model Architectures with Hardware](https://arxiv.org/abs/2401.14489), ensuring maximum efficiency across inference GPUs.
|
||||
- Modern training data scales (2 trillion tokens) and mixtures (including code ande math data)
|
||||
> [!TIP]
|
||||
> Click on the ModernBERT models in the right sidebar for more examples of how to apply ModernBERT to different language tasks.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
The example below demonstrates how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
*Encoder-only transformer models such as BERT offer a great performance-size tradeoff for retrieval and classification tasks with respect to larger decoder-only models. Despite being the workhorse of numerous production pipelines, there have been limited Pareto improvements to BERT since its release. In this paper, we introduce ModernBERT, bringing modern model optimizations to encoder-only models and representing a major Pareto improvement over older encoders. Trained on 2 trillion tokens with a native 8192 sequence length, ModernBERT models exhibit state-of-the-art results on a large pool of evaluations encompassing diverse classification tasks and both single and multi-vector retrieval on different domains (including code). In addition to strong downstream performance, ModernBERT is also the most speed and memory efficient encoder and is designed for inference on common GPUs.*
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
The original code can be found [here](https://github.com/answerdotai/modernbert).
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
## Resources
|
||||
pipeline = pipeline(
|
||||
task="fill-mask",
|
||||
model="answerdotai/ModernBERT-base",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create [MASK] through a process known as photosynthesis.")
|
||||
```
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ModernBert.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
- A notebook on how to [finetune for General Language Understanding Evaluation (GLUE) with Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb), also available as a Google Colab [notebook](https://colab.research.google.com/github/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb). 🌎
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"answerdotai/ModernBERT-base",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"answerdotai/ModernBERT-base",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to("cuda")
|
||||
|
||||
<PipelineTag pipeline="sentence-similarity"/>
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits
|
||||
|
||||
- A script on how to [finetune for text similarity or information retrieval with Sentence Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_st.py). 🌎
|
||||
- A script on how to [finetune for information retrieval with PyLate](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_pylate.py). 🌎
|
||||
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
|
||||
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
|
||||
predicted_token = tokenizer.decode(predicted_token_id)
|
||||
|
||||
<PipelineTag pipeline="fill-mask"/>
|
||||
print(f"The predicted token is: {predicted_token}")
|
||||
```
|
||||
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "Plants create [MASK] through a process known as photosynthesis." | transformers-cli run --task fill-mask --model answerdotai/ModernBERT-base --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## ModernBertConfig
|
||||
|
||||
@ -88,5 +115,15 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] ModernBertForTokenClassification
|
||||
- forward
|
||||
|
||||
## ModernBertForQuestionAnswering
|
||||
|
||||
[[autodoc]] ModernBertForQuestionAnswering
|
||||
- forward
|
||||
|
||||
### Usage tips
|
||||
|
||||
The ModernBert model can be fine-tuned using the HuggingFace Transformers library with its [official script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/question-answering/run_qa.py) for question-answering tasks.
|
||||
|
||||
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
@ -14,154 +14,123 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# OpenAI GPT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,...">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
OpenAI GPT model was proposed in [Improving Language Understanding by Generative Pre-Training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
|
||||
by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. It's a causal (unidirectional) transformer
|
||||
pre-trained using language modeling on a large corpus with long range dependencies, the Toronto Book Corpus.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Natural language understanding comprises a wide range of diverse tasks such as textual entailment, question answering,
|
||||
semantic similarity assessment, and document classification. Although large unlabeled text corpora are abundant,
|
||||
labeled data for learning these specific tasks is scarce, making it challenging for discriminatively trained models to
|
||||
perform adequately. We demonstrate that large gains on these tasks can be realized by generative pretraining of a
|
||||
language model on a diverse corpus of unlabeled text, followed by discriminative fine-tuning on each specific task. In
|
||||
contrast to previous approaches, we make use of task-aware input transformations during fine-tuning to achieve
|
||||
effective transfer while requiring minimal changes to the model architecture. We demonstrate the effectiveness of our
|
||||
approach on a wide range of benchmarks for natural language understanding. Our general task-agnostic model outperforms
|
||||
discriminatively trained models that use architectures specifically crafted for each task, significantly improving upon
|
||||
the state of the art in 9 out of the 12 tasks studied.*
|
||||
|
||||
[Write With Transformer](https://transformer.huggingface.co/doc/gpt) is a webapp created and hosted by Hugging Face
|
||||
showcasing the generative capabilities of several models. GPT is one of them.
|
||||
|
||||
This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The original code can be found [here](https://github.com/openai/finetune-transformer-lm).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
|
||||
the left.
|
||||
- GPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next
|
||||
token in a sequence. Leveraging this feature allows GPT-2 to generate syntactically coherent text as it can be
|
||||
observed in the *run_generation.py* example script.
|
||||
|
||||
|
||||
Note:
|
||||
# GPT
|
||||
|
||||
If you want to reproduce the original tokenization process of the *OpenAI GPT* paper, you will need to install `ftfy`
|
||||
and `SpaCy`:
|
||||
[GPT (Generative Pre-trained Transformer)](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf) focuses on effectively learning text representations and transferring them to tasks. This model trains the Transformer decoder to predict the next word, and then fine-tuned on labeled data.
|
||||
|
||||
```bash
|
||||
pip install spacy ftfy==4.4.3
|
||||
python -m spacy download en
|
||||
GPT can generate high-quality text, making it well-suited for a variety of natural language understanding tasks such as textual entailment, question answering, semantic similarity, and document classification.
|
||||
|
||||
You can find all the original GPT checkpoints under the [OpenAI community](https://huggingface.co/openai-community/openai-gpt) organization.
|
||||
|
||||
> [!TIP]
|
||||
> Click on the GPT models in the right sidebar for more examples of how to apply GPT to different language tasks.
|
||||
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline(task="text-generation", model="openai-community/gpt", torch_dtype=torch.float16, device=0)
|
||||
output = generator("The future of AI is", max_length=50, do_sample=True)
|
||||
print(output[0]["generated_text"])
|
||||
```
|
||||
|
||||
If you don't install `ftfy` and `SpaCy`, the [`OpenAIGPTTokenizer`] will default to tokenize
|
||||
using BERT's `BasicTokenizer` followed by Byte-Pair Encoding (which should be fine for most usage, don't worry).
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
## Resources
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with OpenAI GPT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/openai-gpt", torch_dtype=torch.float16)
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
inputs = tokenizer("The future of AI is", return_tensors="pt")
|
||||
outputs = model.generate(**inputs, max_length=50)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
- A blog post on [outperforming OpenAI GPT-3 with SetFit for text-classification](https://www.philschmid.de/getting-started-setfit).
|
||||
- See also: [Text classification task guide](../tasks/sequence_classification)
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
```bash
|
||||
echo -e "The future of AI is" | transformers-cli run --task text-generation --model openai-community/openai-gpt --device 0
|
||||
|
||||
- A blog on how to [Finetune a non-English GPT-2 Model with Hugging Face](https://www.philschmid.de/fine-tune-a-non-english-gpt-2-model-with-huggingface).
|
||||
- A blog on [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate) with GPT-2.
|
||||
- A blog on [Training CodeParrot 🦜 from Scratch](https://huggingface.co/blog/codeparrot), a large GPT-2 model.
|
||||
- A blog on [Faster Text Generation with TensorFlow and XLA](https://huggingface.co/blog/tf-xla-generate) with GPT-2.
|
||||
- A blog on [How to train a Language Model with Megatron-LM](https://huggingface.co/blog/megatron-training) with a GPT-2 model.
|
||||
- A notebook on how to [finetune GPT2 to generate lyrics in the style of your favorite artist](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb). 🌎
|
||||
- A notebook on how to [finetune GPT2 to generate tweets in the style of your favorite Twitter user](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb). 🌎
|
||||
- [Causal language modeling](https://huggingface.co/course/en/chapter7/6?fw=pt#training-a-causal-language-model-from-scratch) chapter of the 🤗 Hugging Face Course.
|
||||
- [`OpenAIGPTLMHeadModel`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling), [text generation example script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
|
||||
- [`TFOpenAIGPTLMHeadModel`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_clmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
|
||||
- See also: [Causal language modeling task guide](../tasks/language_modeling)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
## Notes
|
||||
|
||||
- A course material on [Byte-Pair Encoding tokenization](https://huggingface.co/course/en/chapter6/5).
|
||||
- Inputs should be padded on the right because GPT uses absolute position embeddings.
|
||||
|
||||
## OpenAIGPTConfig
|
||||
|
||||
[[autodoc]] OpenAIGPTConfig
|
||||
|
||||
## OpenAIGPTModel
|
||||
|
||||
[[autodoc]] OpenAIGPTModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTLMHeadModel
|
||||
|
||||
[[autodoc]] OpenAIGPTLMHeadModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTDoubleHeadsModel
|
||||
|
||||
[[autodoc]] OpenAIGPTDoubleHeadsModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTForSequenceClassification
|
||||
|
||||
[[autodoc]] OpenAIGPTForSequenceClassification
|
||||
- forward
|
||||
|
||||
## OpenAIGPTTokenizer
|
||||
|
||||
[[autodoc]] OpenAIGPTTokenizer
|
||||
- save_vocabulary
|
||||
|
||||
## OpenAIGPTTokenizerFast
|
||||
|
||||
[[autodoc]] OpenAIGPTTokenizerFast
|
||||
|
||||
## OpenAI specific outputs
|
||||
|
||||
[[autodoc]] models.openai.modeling_openai.OpenAIGPTDoubleHeadsModelOutput
|
||||
|
||||
[[autodoc]] models.openai.modeling_tf_openai.TFOpenAIGPTDoubleHeadsModelOutput
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
## OpenAIGPTModel
|
||||
|
||||
[[autodoc]] OpenAIGPTModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTLMHeadModel
|
||||
|
||||
[[autodoc]] OpenAIGPTLMHeadModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTDoubleHeadsModel
|
||||
|
||||
[[autodoc]] OpenAIGPTDoubleHeadsModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTForSequenceClassification
|
||||
|
||||
[[autodoc]] OpenAIGPTForSequenceClassification
|
||||
- forward
|
||||
|
||||
</pt>
|
||||
<tf>
|
||||
|
||||
## TFOpenAIGPTModel
|
||||
|
||||
[[autodoc]] TFOpenAIGPTModel
|
||||
- call
|
||||
- call
|
||||
|
||||
## TFOpenAIGPTLMHeadModel
|
||||
|
||||
[[autodoc]] TFOpenAIGPTLMHeadModel
|
||||
- call
|
||||
- call
|
||||
|
||||
## TFOpenAIGPTDoubleHeadsModel
|
||||
|
||||
[[autodoc]] TFOpenAIGPTDoubleHeadsModel
|
||||
- call
|
||||
- call
|
||||
|
||||
## TFOpenAIGPTForSequenceClassification
|
||||
|
||||
[[autodoc]] TFOpenAIGPTForSequenceClassification
|
||||
- call
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
- call
|
||||
|
||||
@ -14,89 +14,157 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# PaliGemma
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# PaliGemma
|
||||
|
||||
The PaliGemma model was proposed in [PaliGemma – Google's Cutting-Edge Open Vision Language Model](https://huggingface.co/blog/paligemma) by Google. It is a 3B vision-language model composed by a [SigLIP](siglip) vision encoder and a [Gemma](gemma) language decoder linked by a multimodal linear projection. It cuts an image into a fixed number of VIT tokens and prepends it to an optional prompt. One particularity is that the model uses full block attention on all the image tokens plus the input text tokens. It comes in 3 resolutions, 224x224, 448x448 and 896x896 with 3 base models, with 55 fine-tuned versions for different tasks, and 2 mix models.
|
||||
[PaliGemma](https://huggingface.co/papers/2407.07726) is a family of vision-language models (VLMs), combining [SigLIP](./siglip) with the [Gemma](./gemma) 2B model. PaliGemma is available in 3B, 10B, and 28B parameters. The main purpose of PaliGemma is to provide an adaptable base VLM that is easy to transfer to other tasks. The SigLIP vision encoder is a "shape optimized" contrastively pretrained [ViT](./vit) that converts an image into a sequence of tokens and prepended to an optional prompt. The Gemma 2B model is used as the decoder. PaliGemma uses full attention on all image and text tokens to maximize its capacity.
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/paligemma/paligemma_arch.png"
|
||||
alt="drawing" width="600"/>
|
||||
[PaliGemma 2](https://huggingface.co/papers/2412.03555) improves on the first model by using Gemma 2 (2B, 9B, and 27B parameter variants) as the decoder. These are available as **pt** or **mix** variants. The **pt** checkpoints are intended for further fine-tuning and the **mix** checkpoints are ready for use out of the box.
|
||||
|
||||
<small> PaliGemma architecture. Taken from the <a href="https://huggingface.co/blog/paligemma">blog post.</a> </small>
|
||||
You can find all the original PaliGemma checkpoints under the [PaliGemma](https://huggingface.co/collections/google/paligemma-release-6643a9ffbf57de2ae0448dda), [PaliGemma 2](https://huggingface.co/collections/google/paligemma-2-release-67500e1e1dbfdd4dee27ba48), and [PaliGemma 2 Mix](https://huggingface.co/collections/google/paligemma-2-mix-67ac6a251aaf3ee73679dcc4) collections.
|
||||
|
||||
This model was contributed by [Molbap](https://huggingface.co/Molbap).
|
||||
> [!TIP]
|
||||
> Click on the PaliGemma models in the right sidebar for more examples of how to apply PaliGemma to different vision and language tasks.
|
||||
|
||||
## Usage tips
|
||||
The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
- PaliGemma is not meant for conversational use, and it works best when fine-tuning to a specific use case. Some downstream tasks on which PaliGemma can be fine-tuned include image captioning, visual question answering (VQA), object detection, referring expression segmentation and document understanding.
|
||||
- One can use `PaliGemmaProcessor` to prepare images, text and optional labels for the model. When fine-tuning a PaliGemma model, the `suffix` argument can be passed to the processor which creates the `labels` for the model:
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```python
|
||||
prompt = "What is on the flower?"
|
||||
answer = "a bee"
|
||||
inputs = processor(images=raw_image, text=prompt, suffix=answer, return_tensors="pt")
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="image-text-to-text",
|
||||
model="google/paligemma2-3b-mix-224",
|
||||
device=0,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
|
||||
text="What is in this image?"
|
||||
)
|
||||
```
|
||||
|
||||
## Usage Example
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
The model can accept a single or multiple images. According to the [paper](https://arxiv.org/abs/2407.07726v1), the checkpoint PaliGemma can transfer to tasks which take multiple images as input. NLVR2 is one such task, which asks one question about two images, and requires looking at both to give the correct answer. Here's an example code for single and multi image inference.
|
||||
|
||||
### Single-image Inference
|
||||
|
||||
```python
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
model_id = "google/paligemma-3b-mix-224"
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "What is on the flower?"
|
||||
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
|
||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
inputs = processor(raw_image, prompt, return_tensors="pt")
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
```
|
||||
|
||||
### Multi-image Inference
|
||||
|
||||
```python
|
||||
model_id = "google/paligemma-3b-ft-nlvr2-448" # checkpoint tuned for multiple images
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
|
||||
processor = PaliGemmaProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "answer en Which of the two pictures shows a snowman, first or second?"
|
||||
stop_sign_image = Image.open(
|
||||
requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
"google/paligemma2-3b-mix-224",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
snow_image = Image.open(
|
||||
requests.get(
|
||||
"https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", stream=True
|
||||
).raw
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"google/paligemma2-3b-mix-224",
|
||||
)
|
||||
|
||||
inputs = processor(images=[[snow_image, stop_sign_image]], text=prompt, return_tensors="pt")
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
prompt = "What is in this image?"
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
inputs = processor(image, prompt, return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Resources
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with PaliGemma. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
- A blog post introducing all the features of PaliGemma can be found [here](https://huggingface.co/blog/paligemma).
|
||||
- Demo notebooks on how to fine-tune PaliGemma for VQA with the Trainer API along with inference can be found [here](https://github.com/huggingface/notebooks/tree/main/examples/paligemma).
|
||||
- Demo notebooks on how to fine-tune PaliGemma on a custom dataset (receipt image -> JSON) along with inference can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/PaliGemma). 🌎
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
```py
|
||||
# pip install torchao
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import TorchAoConfig, AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
"google/paligemma2-28b-mix-224",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"google/paligemma2-28b-mix-224",
|
||||
)
|
||||
|
||||
prompt = "What is in this image?"
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
inputs = processor(image, prompt, return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
|
||||
visualizer("<img> What is in this image?")
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/paligemma2-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
## Notes
|
||||
|
||||
- PaliGemma is not a conversational model and works best when fine-tuned for specific downstream tasks such as image captioning, visual question answering (VQA), object detection, and document understanding.
|
||||
- [`PaliGemmaProcessor`] can prepare images, text, and optional labels for the model. Pass the `suffix` parameter to the processor to create labels for the model during fine-tuning.
|
||||
|
||||
```py
|
||||
prompt = "What is in this image?"
|
||||
answer = "a pallas cat"
|
||||
inputs = processor(images=image, text=prompt, suffix=answer, return_tensors="pt")
|
||||
```
|
||||
- PaliGemma can support multiple input images if it is fine-tuned to accept multiple images. For example, the [NLVR2](https://huggingface.co/google/paligemma-3b-ft-nlvr2-448) checkpoint supports multiple images. Pass the images as a list to the processor.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import TorchAoConfig, AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-nlvr2-448")
|
||||
processor = AutoProcessor.from_pretrained("google/paligemma-3b-ft-nlvr2-448")
|
||||
|
||||
prompt = "Are these two images the same?"
|
||||
cat_image = Image.open(
|
||||
requests.get("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", stream=True).raw
|
||||
)
|
||||
cow_image = Image.open(
|
||||
requests.get(
|
||||
"https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4=", stream=True
|
||||
).raw
|
||||
)
|
||||
|
||||
inputs = processor(images=[[cat_image, cow_image]], text=prompt, return_tensors="pt")
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## PaliGemmaConfig
|
||||
|
||||
|
||||
156
docs/source/en/model_doc/phi4_multimodal.md
Normal file
156
docs/source/en/model_doc/phi4_multimodal.md
Normal file
@ -0,0 +1,156 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
-->
|
||||
|
||||
# Phi4 Multimodal
|
||||
|
||||
## Overview
|
||||
|
||||
Phi4 Multimodal is a lightweight open multimodal foundation model that leverages the language, vision, and speech research and datasets used for Phi-3.5 and 4.0 models. The model processes text, image, and audio inputs, generating text outputs, and comes with 128K token context length. The model underwent an enhancement process, incorporating both supervised fine-tuning, direct preference optimization and RLHF (Reinforcement Learning from Human Feedback) to support precise instruction adherence and safety measures. The languages that each modal supports are the following:
|
||||
|
||||
- Text: Arabic, Chinese, Czech, Danish, Dutch, English, Finnish, French, German, Hebrew, Hungarian, Italian, Japanese, Korean, Norwegian, Polish, Portuguese, Russian, Spanish, Swedish, Thai, Turkish, Ukrainian
|
||||
- Vision: English
|
||||
- Audio: English, Chinese, German, French, Italian, Japanese, Spanish, Portuguese
|
||||
|
||||
This model was contributed by [Cyril Vallez](https://huggingface.co/cyrilvallez). The most recent code can be
|
||||
found [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py).
|
||||
|
||||
|
||||
## Usage tips
|
||||
|
||||
`Phi4-multimodal-instruct` can be found on the [Huggingface Hub](https://huggingface.co/microsoft/Phi-4-multimodal-instruct)
|
||||
|
||||
In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio).
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
|
||||
|
||||
|
||||
# Define model path
|
||||
model_path = "microsoft/Phi-4-multimodal-instruct"
|
||||
device = "cuda:0"
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained(model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, torch_dtype=torch.float16)
|
||||
|
||||
# Optional: load the adapters (note that without them, the base model will very likely not work well)
|
||||
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
|
||||
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})
|
||||
|
||||
# Part : Image Processing
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
model.set_adapter("vision") # if loaded, activate the vision adapter
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(device, torch.float16)
|
||||
|
||||
# Generate response
|
||||
generate_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
|
||||
response = processor.batch_decode(
|
||||
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
print(f'>>> Response\n{response}')
|
||||
|
||||
|
||||
# Part 2: Audio Processing
|
||||
model.set_adapter("speech") # if loaded, activate the speech adapter
|
||||
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "url": audio_url},
|
||||
{"type": "text", "text": "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the origina transcript and the translation."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
sample_rate=sample_rate,
|
||||
).to(device, torch.float16)
|
||||
|
||||
generate_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
|
||||
response = processor.batch_decode(
|
||||
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
print(f'>>> Response\n{response}')
|
||||
```
|
||||
|
||||
## Phi4MultimodalFeatureExtractor
|
||||
|
||||
[[autodoc]] Phi4MultimodalFeatureExtractor
|
||||
|
||||
## Phi4MultimodalImageProcessorFast
|
||||
|
||||
[[autodoc]] Phi4MultimodalImageProcessorFast
|
||||
|
||||
## Phi4MultimodalProcessor
|
||||
|
||||
[[autodoc]] Phi4MultimodalProcessor
|
||||
|
||||
## Phi4MultimodalAudioConfig
|
||||
|
||||
[[autodoc]] Phi4MultimodalAudioConfig
|
||||
|
||||
## Phi4MultimodalVisionConfig
|
||||
|
||||
[[autodoc]] Phi4MultimodalVisionConfig
|
||||
|
||||
## Phi4MultimodalConfig
|
||||
|
||||
[[autodoc]] Phi4MultimodalConfig
|
||||
|
||||
## Phi4MultimodalAudioModel
|
||||
|
||||
[[autodoc]] Phi4MultimodalAudioModel
|
||||
|
||||
## Phi4MultimodalVisionModel
|
||||
|
||||
[[autodoc]] Phi4MultimodalVisionModel
|
||||
|
||||
## Phi4MultimodalModel
|
||||
|
||||
[[autodoc]] Phi4MultimodalModel
|
||||
- forward
|
||||
|
||||
## Phi4MultimodalForCausalLM
|
||||
|
||||
[[autodoc]] Phi4MultimodalForCausalLM
|
||||
- forward
|
||||
96
docs/source/en/model_doc/prompt_depth_anything.md
Normal file
96
docs/source/en/model_doc/prompt_depth_anything.md
Normal file
@ -0,0 +1,96 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Prompt Depth Anything
|
||||
|
||||
## Overview
|
||||
|
||||
The Prompt Depth Anything model was introduced in [Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation](https://arxiv.org/abs/2412.14015) by Haotong Lin, Sida Peng, Jingxiao Chen, Songyou Peng, Jiaming Sun, Minghuan Liu, Hujun Bao, Jiashi Feng, Xiaowei Zhou, Bingyi Kang.
|
||||
|
||||
|
||||
The abstract from the paper is as follows:
|
||||
|
||||
*Prompts play a critical role in unleashing the power of language and vision foundation models for specific tasks. For the first time, we introduce prompting into depth foundation models, creating a new paradigm for metric depth estimation termed Prompt Depth Anything. Specifically, we use a low-cost LiDAR as the prompt to guide the Depth Anything model for accurate metric depth output, achieving up to 4K resolution. Our approach centers on a concise prompt fusion design that integrates the LiDAR at multiple scales within the depth decoder. To address training challenges posed by limited datasets containing both LiDAR depth and precise GT depth, we propose a scalable data pipeline that includes synthetic data LiDAR simulation and real data pseudo GT depth generation. Our approach sets new state-of-the-arts on the ARKitScenes and ScanNet++ datasets and benefits downstream applications, including 3D reconstruction and generalized robotic grasping.*
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/prompt_depth_anything_architecture.jpg"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Prompt Depth Anything overview. Taken from the <a href="https://arxiv.org/pdf/2412.14015">original paper</a>.</small>
|
||||
|
||||
## Usage example
|
||||
|
||||
The Transformers library allows you to use the model with just a few lines of code:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> import numpy as np
|
||||
|
||||
>>> from PIL import Image
|
||||
>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
|
||||
|
||||
>>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
|
||||
>>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
|
||||
|
||||
>>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
|
||||
>>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
|
||||
>>> # the prompt depth can be None, and the model will output a monocular relative depth.
|
||||
|
||||
>>> # prepare image for the model
|
||||
>>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> # interpolate to original size
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... target_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> # visualize the prediction
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = predicted_depth * 1000
|
||||
>>> depth = depth.detach().cpu().numpy()
|
||||
>>> depth = Image.fromarray(depth.astype("uint16")) # mm
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Prompt Depth Anything.
|
||||
|
||||
- [Prompt Depth Anything Demo](https://huggingface.co/spaces/depth-anything/PromptDA)
|
||||
- [Prompt Depth Anything Interactive Results](https://promptda.github.io/interactive.html)
|
||||
|
||||
If you are interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
## PromptDepthAnythingConfig
|
||||
|
||||
[[autodoc]] PromptDepthAnythingConfig
|
||||
|
||||
## PromptDepthAnythingForDepthEstimation
|
||||
|
||||
[[autodoc]] PromptDepthAnythingForDepthEstimation
|
||||
- forward
|
||||
|
||||
## PromptDepthAnythingImageProcessor
|
||||
|
||||
[[autodoc]] PromptDepthAnythingImageProcessor
|
||||
- preprocess
|
||||
- post_process_depth_estimation
|
||||
@ -14,51 +14,137 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Qwen2
|
||||
|
||||
Qwen2 is the new model series of large language models from the Qwen team. Previously, we released the Qwen series, including Qwen2-0.5B, Qwen2-1.5B, Qwen2-7B, Qwen2-57B-A14B, Qwen2-72B, Qwen2-Audio, etc.
|
||||
[Qwen2](https://huggingface.co/papers/2407.10671) is a family of large language models (pretrained, instruction-tuned and mixture-of-experts) available in sizes from 0.5B to 72B parameters. The models are built on the Transformer architecture featuring enhancements like group query attention (GQA), rotary positional embeddings (RoPE), a mix of sliding window and full attention, and dual chunk attention with YARN for training stability. Qwen2 models support multiple languages and context lengths up to 131,072 tokens.
|
||||
|
||||
### Model Details
|
||||
You can find all the official Qwen2 checkpoints under the [Qwen2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f) collection.
|
||||
|
||||
Qwen2 is a language model series including decoder language models of different model sizes. For each size, we release the base language model and the aligned chat model. It is based on the Transformer architecture with SwiGLU activation, attention QKV bias, group query attention, mixture of sliding window attention and full attention, etc. Additionally, we have an improved tokenizer adaptive to multiple natural languages and codes.
|
||||
> [!TIP]
|
||||
> Click on the Qwen2 models in the right sidebar for more examples of how to apply Qwen2 to different language tasks.
|
||||
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line using the instruction-tuned models.
|
||||
|
||||
## Usage tips
|
||||
|
||||
`Qwen2-7B` and `Qwen2-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
|
||||
|
||||
In the following, we demonstrate how to use `Qwen2-7B-Instruct` for the inference. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose.
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
>>> device = "cuda" # the device to load the model onto
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct", device_map="auto")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
|
||||
pipe = pipeline(
|
||||
task="text-generation",
|
||||
model="Qwen/Qwen2-1.5B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=0
|
||||
)
|
||||
|
||||
>>> prompt = "Give me a short introduction to large language model."
|
||||
|
||||
>>> messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
>>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
>>> model_inputs = tokenizer([text], return_tensors="pt").to(device)
|
||||
|
||||
>>> generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512, do_sample=True)
|
||||
|
||||
>>> generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
|
||||
|
||||
>>> response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Tell me about the Qwen2 model family."},
|
||||
]
|
||||
outputs = pipe(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
|
||||
print(outputs[0]["generated_text"][-1]['content'])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen2-1.5B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
|
||||
|
||||
prompt = "Give me a short introduction to large language models."
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to("cuda")
|
||||
|
||||
generated_ids = model.generate(
|
||||
model_inputs.input_ids,
|
||||
cache_implementation="static",
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
top_p=0.95
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
transformers-cli chat --model_name_or_path Qwen/Qwen2-7B-Instruct --torch_dtype auto --attn_implementation flash_attention_2 --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen2-7B",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config,
|
||||
attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
inputs = tokenizer("The Qwen2 model family is", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Ensure your Transformers library version is up-to-date. Qwen2 requires Transformers>=4.37.0 for full support.
|
||||
|
||||
## Qwen2Config
|
||||
|
||||
[[autodoc]] Qwen2Config
|
||||
|
||||
@ -14,45 +14,74 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen2.5-VL
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white"> </div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Qwen2.5-VL
|
||||
|
||||
The [Qwen2.5-VL](https://qwenlm.github.io/blog/qwen2_5-vl/) model is an update to [Qwen2-VL](https://arxiv.org/abs/2409.12191) from Qwen team, Alibaba Group.
|
||||
[Qwen2.5-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model, available in 3B, 7B, and 72B parameters, pretrained on 4.1T tokens. The model introduces window attention in the ViT encoder to accelerate training and inference, dynamic FPS sampling on the spatial and temporal dimensions for better video understanding across different sampling rates, and an upgraded MRoPE (multi-resolutional rotary positional encoding) mechanism to better capture and learn temporal dynamics.
|
||||
|
||||
The abstract from this update is the following:
|
||||
|
||||
*Qwen2.5-VL marks a major step forward from Qwen2-VL, built upon the latest Qwen2.5 LLM. We've accelerated training and testing through the strategic implementation of window attention within the ViT. The ViT architecture itself has been refined with SwiGLU and RMSNorm, aligning it more closely with the LLM's structure. A key innovation is the expansion of native dynamic resolution to encompass the temporal dimension, in addition to spatial aspects. Furthermore, we've upgraded MRoPE, incorporating absolute time alignment on the time axis to allow the model to effectively capture temporal dynamics, regardless of frame rate, leading to superior video understanding.*
|
||||
You can find all the original Qwen2.5-VL checkpoints under the [Qwen2.5-VL](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5) collection.
|
||||
|
||||
## Usage example
|
||||
> [!TIP]
|
||||
> Click on the Qwen2.5-VL models in the right sidebar for more examples of how to apply Qwen2.5-VL to different vision and language tasks.
|
||||
|
||||
### Single Media inference
|
||||
The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
The model can accept both images and videos as input. Here's an example code for inference.
|
||||
|
||||
```python
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from transformers import pipeline
|
||||
pipe = pipeline(
|
||||
task="image-text-to-text",
|
||||
model="Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
device=0,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
|
||||
},
|
||||
{ "type": "text", "text": "Describe this image."},
|
||||
]
|
||||
}
|
||||
]
|
||||
pipe(text=messages,max_new_tokens=20, return_full_text=False)
|
||||
|
||||
# Load the model in half-precision on the available device(s)
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", device_map="auto")
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
||||
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
|
||||
|
||||
conversation = [
|
||||
messages = [
|
||||
{
|
||||
"role":"user",
|
||||
"content":[
|
||||
{
|
||||
"type":"image",
|
||||
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
|
||||
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
},
|
||||
{
|
||||
"type":"text",
|
||||
@ -60,212 +89,145 @@ conversation = [
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
).to("cuda")
|
||||
|
||||
|
||||
# Inference: Generation of the output
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
|
||||
# Video
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": "/path/to/video.mp4"},
|
||||
{"type": "text", "text": "What happened in the video?"},
|
||||
],
|
||||
}
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
video_fps=1,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
# Inference: Generation of the output
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
### Batch Mixed Media Inference
|
||||
|
||||
The model can batch inputs composed of mixed samples of various types such as images, videos, and text. Here is an example.
|
||||
|
||||
```python
|
||||
# Conversation for the first image
|
||||
conversation1 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image1.jpg"},
|
||||
{"type": "text", "text": "Describe this image."}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Conversation with two images
|
||||
conversation2 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image2.jpg"},
|
||||
{"type": "image", "path": "/path/to/image3.jpg"},
|
||||
{"type": "text", "text": "What is written in the pictures?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Conversation with pure text
|
||||
conversation3 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "who are you?"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# Conversation with mixed midia
|
||||
conversation4 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image3.jpg"},
|
||||
{"type": "image", "path": "/path/to/image4.jpg"},
|
||||
{"type": "video", "path": "/path/to/video.jpg"},
|
||||
{"type": "text", "text": "What are the common elements in these medias?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
conversations = [conversation1, conversation2, conversation3, conversation4]
|
||||
# Preparation for batch inference
|
||||
ipnuts = processor.apply_chat_template(
|
||||
conversations,
|
||||
video_fps=1,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
|
||||
# Batch Inference
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
### Usage Tips
|
||||
|
||||
#### Image Resolution trade-off
|
||||
|
||||
The model supports a wide range of resolution inputs. By default, it uses the native resolution for input, but higher resolutions can enhance performance at the cost of more computation. Users can set the minimum and maximum number of pixels to achieve an optimal configuration for their needs.
|
||||
|
||||
```python
|
||||
min_pixels = 224*224
|
||||
max_pixels = 2048*2048
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
|
||||
In case of limited GPU RAM, one can reduce the resolution as follows:
|
||||
|
||||
```python
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 1024*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
This ensures each image gets encoded using a number between 256-1024 tokens. The 28 comes from the fact that the model uses a patch size of 14 and a temporal patch size of 2 (14 x 2 = 28).
|
||||
|
||||
#### Multiple Image Inputs
|
||||
|
||||
By default, images and video content are directly included in the conversation. When handling multiple images, it's helpful to add labels to the images and videos for better reference. Users can control this behavior with the following settings:
|
||||
|
||||
```python
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "Hello, how are you?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well, thank you for asking. How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you describe these images and video?"},
|
||||
{"type": "image"},
|
||||
{"type": "image"},
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "These are from my vacation."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'd be happy to describe the images and video for you. Could you please provide more context about your vacation?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "It was a trip to the mountains. Can you see the details in the images and video?"
|
||||
}
|
||||
]
|
||||
|
||||
# default:
|
||||
prompt_without_id = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|><|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
|
||||
|
||||
# add ids
|
||||
prompt_with_id = processor.apply_chat_template(conversation, add_generation_prompt=True, add_vision_id=True)
|
||||
# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?Picture 2: <|vision_start|><|image_pad|><|vision_end|>Picture 3: <|vision_start|><|image_pad|><|vision_end|>Video 1: <|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
|
||||
```
|
||||
|
||||
#### Flash-Attention 2 to speed up generation
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Also, you should have hardware that is compatible with FlashAttention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`.
|
||||
|
||||
To load and run a model using FlashAttention-2, add `attn_implementation="flash_attention_2"` when loading the model:
|
||||
|
||||
```python
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import TorchAoConfig, Gemma3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
```
|
||||
### Notes
|
||||
|
||||
- Use Qwen2.5-VL for video inputs by setting `"type": "video"` as shown below.
|
||||
```python
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": "/path/to/video.mp4"},
|
||||
{"type": "text", "text": "What happened in the video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
video_fps=1,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
# Inference: Generation of the output
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
```
|
||||
- Use Qwen2.5-VL for a mixed batch of inputs (images, videos, text). Add labels when handling multiple images or videos for better reference
|
||||
as show below.
|
||||
```python
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
||||
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "Hello, how are you?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well, thank you for asking. How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you describe these images and video?"},
|
||||
{"type": "image"},
|
||||
{"type": "image"},
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "These are from my vacation."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'd be happy to describe the images and video for you. Could you please provide more context about your vacation?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "It was a trip to the mountains. Can you see the details in the images and video?"
|
||||
}
|
||||
]
|
||||
|
||||
# default:
|
||||
prompt_without_id = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|><|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
|
||||
|
||||
# add ids
|
||||
prompt_with_id = processor.apply_chat_template(conversation, add_generation_prompt=True, add_vision_id=True)
|
||||
# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?Picture 2: <|vision_start|><|image_pad|><|vision_end|>Picture 3: <|vision_start|><|image_pad|><|vision_end|>Video 1: <|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
```
|
||||
|
||||
- Use the `min_pixels` and `max_pixels` parameters in [`AutoProcessor`] to set the resolution.
|
||||
|
||||
```python
|
||||
min_pixels = 224*224
|
||||
max_pixels = 2048*2048
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
|
||||
Higher resolution can require more compute whereas reducing the resolution can save memory as follows:
|
||||
|
||||
```python
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 1024*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
## Qwen2_5_VLConfig
|
||||
|
||||
[[autodoc]] Qwen2_5_VLConfig
|
||||
|
||||
@ -29,7 +29,7 @@ The Qwen2-Audio is the new model series of large audio-language models from the
|
||||
* voice chat: users can freely engage in voice interactions with Qwen2-Audio without text input
|
||||
* audio analysis: users could provide audio and text instructions for analysis during the interaction
|
||||
|
||||
It was proposed in [Qwen2-Audio Technical Report](https://arxiv.org/abs/2407.10759) by Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, Jingren Zhou.
|
||||
It was proposed in [Qwen2-Audio Technical Report](https://arxiv.org/abs/2407.10759) by Yunfei Chu, Jin Xu, Qian Yang, Haojie Wei, Xipin Wei, Zhifang Guo, Yichong Leng, Yuanjun Lv, Jinzheng He, Junyang Lin, Chang Zhou, Jingren Zhou.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
@ -100,7 +100,7 @@ for message in conversation:
|
||||
for ele in message["content"]:
|
||||
if ele["type"] == "audio":
|
||||
audios.append(librosa.load(
|
||||
BytesIO(urlopen(ele['audio_url']).read()),
|
||||
BytesIO(urlopen(ele['audio_url']).read()),
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
@ -125,7 +125,7 @@ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct", device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{'role': 'system', 'content': 'You are a helpful assistant.'},
|
||||
{'role': 'system', 'content': 'You are a helpful assistant.'},
|
||||
{"role": "user", "content": [
|
||||
{"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
|
||||
{"type": "text", "text": "What's that sound?"},
|
||||
@ -148,7 +148,7 @@ for message in conversation:
|
||||
if ele["type"] == "audio":
|
||||
audios.append(
|
||||
librosa.load(
|
||||
BytesIO(urlopen(ele['audio_url']).read()),
|
||||
BytesIO(urlopen(ele['audio_url']).read()),
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
@ -203,7 +203,7 @@ for conversation in conversations:
|
||||
if ele["type"] == "audio":
|
||||
audios.append(
|
||||
librosa.load(
|
||||
BytesIO(urlopen(ele['audio_url']).read()),
|
||||
BytesIO(urlopen(ele['audio_url']).read()),
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
@ -221,7 +221,7 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
|
||||
|
||||
[[autodoc]] Qwen2AudioConfig
|
||||
|
||||
## Qwen2AudioConfig
|
||||
## Qwen2AudioEncoderConfig
|
||||
|
||||
[[autodoc]] Qwen2AudioEncoderConfig
|
||||
|
||||
@ -229,6 +229,11 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
|
||||
|
||||
[[autodoc]] Qwen2AudioProcessor
|
||||
|
||||
## Qwen2AudioEncoder
|
||||
|
||||
[[autodoc]] Qwen2AudioEncoder
|
||||
- forward
|
||||
|
||||
## Qwen2AudioForConditionalGeneration
|
||||
|
||||
[[autodoc]] Qwen2AudioForConditionalGeneration
|
||||
|
||||
59
docs/source/en/model_doc/qwen3.md
Normal file
59
docs/source/en/model_doc/qwen3.md
Normal file
@ -0,0 +1,59 @@
|
||||
<!--Copyright 2024 The Qwen Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen3
|
||||
|
||||
## Overview
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
### Model Details
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
|
||||
## Usage tips
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Qwen3Config
|
||||
|
||||
[[autodoc]] Qwen3Config
|
||||
|
||||
## Qwen3Model
|
||||
|
||||
[[autodoc]] Qwen3Model
|
||||
- forward
|
||||
|
||||
## Qwen3ForCausalLM
|
||||
|
||||
[[autodoc]] Qwen3ForCausalLM
|
||||
- forward
|
||||
|
||||
## Qwen3ForSequenceClassification
|
||||
|
||||
[[autodoc]] Qwen3ForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Qwen3ForTokenClassification
|
||||
|
||||
[[autodoc]] Qwen3ForTokenClassification
|
||||
- forward
|
||||
|
||||
## Qwen3ForQuestionAnswering
|
||||
|
||||
[[autodoc]] Qwen3ForQuestionAnswering
|
||||
- forward
|
||||
58
docs/source/en/model_doc/qwen3_moe.md
Normal file
58
docs/source/en/model_doc/qwen3_moe.md
Normal file
@ -0,0 +1,58 @@
|
||||
<!--Copyright 2024 The Qwen Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen3MoE
|
||||
|
||||
## Overview
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
### Model Details
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Usage tips
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Qwen3MoeConfig
|
||||
|
||||
[[autodoc]] Qwen3MoeConfig
|
||||
|
||||
## Qwen3MoeModel
|
||||
|
||||
[[autodoc]] Qwen3MoeModel
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForCausalLM
|
||||
|
||||
[[autodoc]] Qwen3MoeForCausalLM
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForSequenceClassification
|
||||
|
||||
[[autodoc]] Qwen3MoeForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForTokenClassification
|
||||
|
||||
[[autodoc]] Qwen3MoeForTokenClassification
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForQuestionAnswering
|
||||
|
||||
[[autodoc]] Qwen3MoeForQuestionAnswering
|
||||
- forward
|
||||
@ -149,12 +149,24 @@ alt="drawing" width="900"/>
|
||||
[[autodoc]] SamImageProcessor
|
||||
|
||||
|
||||
## SamVisionModel
|
||||
|
||||
[[autodoc]] SamVisionModel
|
||||
- forward
|
||||
|
||||
|
||||
## SamModel
|
||||
|
||||
[[autodoc]] SamModel
|
||||
- forward
|
||||
|
||||
|
||||
## TFSamVisionModel
|
||||
|
||||
[[autodoc]] TFSamVisionModel
|
||||
- call
|
||||
|
||||
|
||||
## TFSamModel
|
||||
|
||||
[[autodoc]] TFSamModel
|
||||
|
||||
100
docs/source/en/model_doc/shieldgemma2.md
Normal file
100
docs/source/en/model_doc/shieldgemma2.md
Normal file
@ -0,0 +1,100 @@
|
||||
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# ShieldGemma 2
|
||||
|
||||
## Overview
|
||||
|
||||
The ShieldGemma 2 model was proposed in a [technical report](https://arxiv.org/abs/2504.01081) by Google. ShieldGemma 2, built on [Gemma 3](https://ai.google.dev/gemma/docs/core/model_card_3), is a 4 billion (4B) parameter model that checks the safety of both synthetic and natural images against key categories to help you build robust datasets and models. With this addition to the Gemma family of models, researchers and developers can now easily minimize the risk of harmful content in their models across key areas of harm as defined below:
|
||||
|
||||
- No Sexually Explicit content: The image shall not contain content that depicts explicit or graphic sexual acts (e.g., pornography, erotic nudity, depictions of rape or sexual assault).
|
||||
- No Dangerous Content: The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide).
|
||||
- No Violence/Gore content: The image shall not contain content that depicts shocking, sensational, or gratuitous violence (e.g., excessive blood and gore, gratuitous violence against animals, extreme injury or moment of death).
|
||||
|
||||
We recommend using ShieldGemma 2 as an input filter to vision language models, or as an output filter of image generation systems. To train a robust image safety model, we curated training datasets of natural and synthetic images and instruction-tuned Gemma 3 to demonstrate strong performance.
|
||||
|
||||
This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins).
|
||||
|
||||
## Usage Example
|
||||
|
||||
- ShieldGemma 2 provides a Processor that accepts a list of `images` and an optional list of `policies` as input, and constructs a batch of prompts as the product of these two lists using the provided chat template.
|
||||
- You can extend ShieldGemma's built-in in policies with the `custom_policies` argument to the Processor. Using the same key as one of the built-in policies will overwrite that policy with your custom defintion.
|
||||
- ShieldGemma 2 does not support the image cropping capabilities used by Gemma 3.
|
||||
|
||||
### Classification against Built-in Policies
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
import requests
|
||||
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
|
||||
|
||||
model_id = "google/shieldgemma-2-4b-it"
|
||||
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(images=[image], return_tensors="pt").to(model.device)
|
||||
|
||||
output = model(**inputs)
|
||||
print(output.probabilities)
|
||||
```
|
||||
|
||||
### Classification against Custom Policies
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
import requests
|
||||
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
|
||||
|
||||
model_id = "google/shieldgemma-2-4b-it"
|
||||
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
custom_policies = {
|
||||
"key_a": "descrition_a",
|
||||
"key_b": "descrition_b",
|
||||
}
|
||||
|
||||
inputs = processor(
|
||||
images=[image],
|
||||
custom_policies=custom_policies,
|
||||
policies=["dangerous", "key_a", "key_b"],
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
output = model(**inputs)
|
||||
print(output.probabilities)
|
||||
```
|
||||
|
||||
|
||||
## ShieldGemma2Processor
|
||||
|
||||
[[autodoc]] ShieldGemma2Processor
|
||||
|
||||
## ShieldGemma2Config
|
||||
|
||||
[[autodoc]] ShieldGemma2Config
|
||||
|
||||
## ShieldGemma2ForImageClassification
|
||||
|
||||
[[autodoc]] ShieldGemma2ForImageClassification
|
||||
- forward
|
||||
@ -14,349 +14,104 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# T5
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# T5
|
||||
|
||||
The T5 model was presented in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/pdf/1910.10683.pdf) by [Colin Raffel](https://huggingface.co/craffel), Noam Shazeer, [Adam Roberts](https://huggingface.co/adarob), Katherine Lee, Sharan Narang,
|
||||
Michael Matena, Yanqi Zhou, Wei Li, [Peter J. Liu](https://huggingface.co/peterjliu).
|
||||
[T5](https://huggingface.co/papers/1910.10683) is a encoder-decoder transformer available in a range of sizes from 60M to 11B parameters. It is designed to handle a wide range of NLP tasks by treating them all as text-to-text problems. This eliminates the need for task-specific architectures because T5 converts every NLP task into a text generation task.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
To formulate every task as text generation, each task is prepended with a task-specific prefix (e.g., translate English to German: ..., summarize: ...). This enables T5 to handle tasks like translation, summarization, question answering, and more.
|
||||
|
||||
*Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream
|
||||
task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning
|
||||
has given rise to a diversity of approaches, methodology, and practice. In this paper, we explore the landscape of
|
||||
transfer learning techniques for NLP by introducing a unified framework that converts every language problem into a
|
||||
text-to-text format. Our systematic study compares pretraining objectives, architectures, unlabeled datasets, transfer
|
||||
approaches, and other factors on dozens of language understanding tasks. By combining the insights from our exploration
|
||||
with scale and our new "Colossal Clean Crawled Corpus", we achieve state-of-the-art results on many benchmarks covering
|
||||
summarization, question answering, text classification, and more. To facilitate future work on transfer learning for
|
||||
NLP, we release our dataset, pre-trained models, and code.*
|
||||
You can find all official T5 checkpoints under the [T5](https://huggingface.co/collections/google/t5-release-65005e7c520f8d7b4d037918) collection.
|
||||
|
||||
All checkpoints can be found on the [hub](https://huggingface.co/models?search=t5).
|
||||
> [!TIP]
|
||||
> Click on the T5 models in the right sidebar for more examples of how to apply T5 to different language tasks.
|
||||
|
||||
This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The original code can be found [here](https://github.com/google-research/text-to-text-transfer-transformer).
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and how to translate with T5 from the command line.
|
||||
|
||||
## Usage tips
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
|
||||
each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
|
||||
different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
|
||||
for summarization: *summarize: ...*.
|
||||
- The pretraining includes both supervised and self-supervised training. Supervised training is conducted on downstream tasks provided by the GLUE and SuperGLUE benchmarks (converting them into text-to-text tasks as explained above).
|
||||
- Self-supervised training uses corrupted tokens, by randomly removing 15% of the tokens and replacing them with individual sentinel tokens (if several consecutive tokens are marked for removal, the whole group is replaced with a single sentinel token). The input of the encoder is the corrupted sentence, the input of the decoder is the original sentence and the target is then the dropped out tokens delimited by their sentinel tokens.
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
|
||||
|
||||
- See the [training](#training), [inference](#inference) and [resources](#resources) sections below for all details regarding usage.
|
||||
|
||||
T5 comes in different sizes:
|
||||
|
||||
- [google-t5/t5-small](https://huggingface.co/google-t5/t5-small)
|
||||
|
||||
- [google-t5/t5-base](https://huggingface.co/google-t5/t5-base)
|
||||
|
||||
- [google-t5/t5-large](https://huggingface.co/google-t5/t5-large)
|
||||
|
||||
- [google-t5/t5-3b](https://huggingface.co/google-t5/t5-3b)
|
||||
|
||||
- [google-t5/t5-11b](https://huggingface.co/google-t5/t5-11b).
|
||||
|
||||
Based on the original T5 model, Google has released some follow-up works:
|
||||
|
||||
- **T5v1.1**: T5v1.1 is an improved version of T5 with some architectural tweaks, and is pre-trained on C4 only without
|
||||
mixing in the supervised tasks. Refer to the documentation of T5v1.1 which can be found [here](t5v1.1).
|
||||
|
||||
- **mT5**: mT5 is a multilingual T5 model. It is pre-trained on the mC4 corpus, which includes 101 languages. Refer to
|
||||
the documentation of mT5 which can be found [here](mt5).
|
||||
|
||||
- **byT5**: byT5 is a T5 model pre-trained on byte sequences rather than SentencePiece subword token sequences. Refer
|
||||
to the documentation of byT5 which can be found [here](byt5).
|
||||
|
||||
- **UL2**: UL2 is a T5 like model pretrained on various denoising objectives
|
||||
|
||||
- **Flan-T5**: Flan is a pretraining methods that is based on prompting. The Flan-T5 are T5 models trained on the Flan collection of
|
||||
datasets which include: `taskmaster2`, `djaym7/wiki_dialog`, `deepmind/code_contests`, `lambada`, `gsm8k`, `aqua_rat`, `esnli`, `quasc` and `qed`.
|
||||
|
||||
- **FLan-UL2** : the UL2 model finetuned using the "Flan" prompt tuning and dataset collection.
|
||||
|
||||
- **UMT5**: UmT5 is a multilingual T5 model trained on an improved and refreshed mC4 multilingual corpus, 29 trillion characters across 107 language, using a new sampling method, UniMax. Refer to
|
||||
the documentation of mT5 which can be found [here](umt5).
|
||||
|
||||
## Training
|
||||
|
||||
T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher
|
||||
forcing. This means that for training, we always need an input sequence and a corresponding target sequence. The input
|
||||
sequence is fed to the model using `input_ids`. The target sequence is shifted to the right, i.e., prepended by a
|
||||
start-sequence token and fed to the decoder using the `decoder_input_ids`. In teacher-forcing style, the target
|
||||
sequence is then appended by the EOS token and corresponds to the `labels`. The PAD token is hereby used as the
|
||||
start-sequence token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
|
||||
|
||||
One can use [`T5ForConditionalGeneration`] (or the Tensorflow/Flax variant), which includes the
|
||||
language modeling head on top of the decoder.
|
||||
|
||||
- Unsupervised denoising training
|
||||
|
||||
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
|
||||
the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
|
||||
sentinel token represents a unique mask token for this sentence and should start with `<extra_id_0>`,
|
||||
`<extra_id_1>`, ... up to `<extra_id_99>`. As a default, 100 sentinel tokens are available in
|
||||
[`T5Tokenizer`].
|
||||
|
||||
For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
|
||||
processed as follows:
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
|
||||
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
|
||||
|
||||
>>> # the forward function automatically creates the correct decoder_input_ids
|
||||
>>> loss = model(input_ids=input_ids, labels=labels).loss
|
||||
>>> loss.item()
|
||||
3.7837
|
||||
pipeline = pipeline(
|
||||
task="text2text-generation",
|
||||
model="google-t5/t5-base",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("translate English to French: The weather is nice today.")
|
||||
```
|
||||
|
||||
If you're interested in pre-training T5 on a new corpus, check out the [run_t5_mlm_flax.py](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) script in the Examples
|
||||
directory.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
- Supervised training
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
|
||||
Suppose that we want to fine-tune the model for translation for example, and we have a training example: the input
|
||||
sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
|
||||
the model as follows:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google-t5/t5-base"
|
||||
)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"google-t5/t5-base",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
input_ids = tokenizer("translate English to French: The weather is nice today.", return_tensors="pt").to("cuda")
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
|
||||
>>> labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
|
||||
|
||||
>>> # the forward function automatically creates the correct decoder_input_ids
|
||||
>>> loss = model(input_ids=input_ids, labels=labels).loss
|
||||
>>> loss.item()
|
||||
0.2542
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
As you can see, only 2 inputs are required for the model in order to compute a loss: `input_ids` (which are the
|
||||
`input_ids` of the encoded input sequence) and `labels` (which are the `input_ids` of the encoded
|
||||
target sequence). The model will automatically create the `decoder_input_ids` based on the `labels`, by
|
||||
shifting them one position to the right and prepending the `config.decoder_start_token_id`, which for T5 is
|
||||
equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
|
||||
English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
|
||||
during T5's pre-training.
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
However, the example above only shows a single training example. In practice, one trains deep learning models in
|
||||
batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
|
||||
typically defines a `max_source_length` and `max_target_length`, which determine the maximum length of the
|
||||
input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
|
||||
the task.
|
||||
|
||||
In addition, we must make sure that padding token id's of the `labels` are not taken into account by the loss
|
||||
function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the `ignore_index`
|
||||
of the `CrossEntropyLoss`. In Flax, one can use the `decoder_attention_mask` to ignore padded tokens from
|
||||
the loss (see the [Flax summarization script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization) for details). We also pass
|
||||
`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
|
||||
ignored. The code example below illustrates all of this.
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> # the following 2 hyperparameters are task-specific
|
||||
>>> max_source_length = 512
|
||||
>>> max_target_length = 128
|
||||
|
||||
>>> # Suppose we have the following 2 training examples:
|
||||
>>> input_sequence_1 = "Welcome to NYC"
|
||||
>>> output_sequence_1 = "Bienvenue à NYC"
|
||||
|
||||
>>> input_sequence_2 = "HuggingFace is a company"
|
||||
>>> output_sequence_2 = "HuggingFace est une entreprise"
|
||||
|
||||
>>> # encode the inputs
|
||||
>>> task_prefix = "translate English to French: "
|
||||
>>> input_sequences = [input_sequence_1, input_sequence_2]
|
||||
|
||||
>>> encoding = tokenizer(
|
||||
... [task_prefix + sequence for sequence in input_sequences],
|
||||
... padding="longest",
|
||||
... max_length=max_source_length,
|
||||
... truncation=True,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
|
||||
>>> input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
|
||||
|
||||
>>> # encode the targets
|
||||
>>> target_encoding = tokenizer(
|
||||
... [output_sequence_1, output_sequence_2],
|
||||
... padding="longest",
|
||||
... max_length=max_target_length,
|
||||
... truncation=True,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
>>> labels = target_encoding.input_ids
|
||||
|
||||
>>> # replace padding token id's of the labels by -100 so it's ignored by the loss
|
||||
>>> labels[labels == tokenizer.pad_token_id] = -100
|
||||
|
||||
>>> # forward pass
|
||||
>>> loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
|
||||
>>> loss.item()
|
||||
0.188
|
||||
```bash
|
||||
echo -e "translate English to French: The weather is nice today." | transformers-cli run --task text2text-generation --model google-t5/t5-base --device 0
|
||||
```
|
||||
|
||||
Additional training tips:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
- T5 models need a slightly higher learning rate than the default one set in the `Trainer` when using the AdamW
|
||||
optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
|
||||
answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
According to [this forum post](https://discuss.huggingface.co/t/t5-finetuning-tips/684), task prefixes matter when
|
||||
(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
|
||||
pre-training mixture (see Appendix D of the [paper](https://arxiv.org/pdf/1910.10683.pdf) for the task prefixes
|
||||
used).
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
|
||||
*pad_to_multiple_of* to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
|
||||
batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
|
||||
encountered during training thus significantly slowing down the training. only padding up to the longest example in a
|
||||
batch) leads to very slow training on TPU.
|
||||
```py
|
||||
# pip install torchao
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
## Inference
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"google/t5-v1_1-xl",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
At inference time, it is recommended to use [`~generation.GenerationMixin.generate`]. This
|
||||
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
|
||||
and auto-regressively generates the decoder output. Check out [this blog post](https://huggingface.co/blog/how-to-generate) to know all the details about generating text with Transformers.
|
||||
There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encoder-decoder) which explains how
|
||||
generation works in general in encoder-decoder models.
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
|
||||
input_ids = tokenizer("translate English to French: The weather is nice today.", return_tensors="pt").to("cuda")
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
|
||||
>>> outputs = model.generate(input_ids)
|
||||
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
Das Haus ist wunderbar.
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
|
||||
[`~generation.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
|
||||
## Notes
|
||||
|
||||
The example above only shows a single example. You can also do batched inference, like so:
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> task_prefix = "translate English to German: "
|
||||
>>> # use different length sentences to test batching
|
||||
>>> sentences = ["The house is wonderful.", "I like to work in NYC."]
|
||||
|
||||
>>> inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)
|
||||
|
||||
>>> output_sequences = model.generate(
|
||||
... input_ids=inputs["input_ids"],
|
||||
... attention_mask=inputs["attention_mask"],
|
||||
... do_sample=False, # disable sampling to test if batching affects output
|
||||
... )
|
||||
|
||||
>>> print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
|
||||
['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
|
||||
```
|
||||
|
||||
Because T5 has been trained with the span-mask denoising objective,
|
||||
it can be used to predict the sentinel (masked-out) tokens during inference.
|
||||
The predicted tokens will then be placed between the sentinel tokens.
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
|
||||
|
||||
>>> sequence_ids = model.generate(input_ids)
|
||||
>>> sequences = tokenizer.batch_decode(sequence_ids)
|
||||
>>> sequences
|
||||
['<pad> <extra_id_0> park offers <extra_id_1> the <extra_id_2> park.</s>']
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
If you'd like a faster training and inference performance, install [NVIDIA APEX](https://github.com/NVIDIA/apex#quick-start) for NVIDIA GPUs, or [ROCm APEX](https://github.com/ROCmSoftwarePlatform/apex) for AMD GPUs and then the model will automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`. The former uses an optimized fused kernel which is several times faster than the latter.
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with T5. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
|
||||
- A notebook for how to [finetune T5 for classification and multiple choice](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb).
|
||||
- A notebook for how to [finetune T5 for sentiment span extraction](https://colab.research.google.com/github/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb). 🌎
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
|
||||
- A notebook for how to [finetune T5 for named entity recognition](https://colab.research.google.com/drive/1obr78FY_cBmWY5ODViCmzdY6O1KB65Vc?usp=sharing). 🌎
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
|
||||
- A notebook for [Finetuning CodeT5 for generating docstrings from Ruby code](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/T5/Fine_tune_CodeT5_for_generating_docstrings_from_Ruby_code.ipynb).
|
||||
|
||||
<PipelineTag pipeline="summarization"/>
|
||||
|
||||
- A notebook to [Finetune T5-base-dutch to perform Dutch abstractive summarization on a TPU](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/T5/Fine_tuning_Dutch_T5_base_on_CNN_Daily_Mail_for_summarization_(on_TPU_using_HuggingFace_Accelerate).ipynb).
|
||||
- A notebook for how to [finetune T5 for summarization in PyTorch and track experiments with WandB](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb#scrollTo=OKRpFvYhBauC). 🌎
|
||||
- A blog post on [Distributed Training: Train BART/T5 for Summarization using 🤗 Transformers and Amazon SageMaker](https://huggingface.co/blog/sagemaker-distributed-training-seq2seq).
|
||||
- [`T5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization.ipynb).
|
||||
- [`TFT5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/summarization) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization-tf.ipynb).
|
||||
- [`FlaxT5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization).
|
||||
- [Summarization](https://huggingface.co/course/chapter7/5?fw=pt#summarization) chapter of the 🤗 Hugging Face course.
|
||||
- [Summarization task guide](../tasks/summarization)
|
||||
|
||||
<PipelineTag pipeline="fill-mask"/>
|
||||
|
||||
- [`FlaxT5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#t5-like-span-masked-language-modeling) for training T5 with a span-masked language model objective. The script also shows how to train a T5 tokenizer. [`FlaxT5ForConditionalGeneration`] is also supported by this [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb).
|
||||
|
||||
<PipelineTag pipeline="translation"/>
|
||||
|
||||
- [`T5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/translation) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation.ipynb).
|
||||
- [`TFT5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/translation) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation-tf.ipynb).
|
||||
- [Translation task guide](../tasks/translation)
|
||||
|
||||
<PipelineTag pipeline="question-answering"/>
|
||||
|
||||
- A notebook on how to [finetune T5 for question answering with TensorFlow 2](https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb). 🌎
|
||||
- A notebook on how to [finetune T5 for question answering on a TPU](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb#scrollTo=QLGiFCDqvuil).
|
||||
|
||||
🚀 **Deploy**
|
||||
- A blog post on how to deploy [T5 11B for inference for less than $500](https://www.philschmid.de/deploy-t5-11b).
|
||||
- You can pad the encoder inputs on the left or right because T5 uses relative scalar embeddings.
|
||||
- T5 models need a slightly higher learning rate than the default used in [`Trainer`]. Typically, values of `1e-4` and `3e-4` work well for most tasks.
|
||||
|
||||
## T5Config
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -14,143 +14,83 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Vision Transformer (ViT)
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Vision Transformer (ViT)
|
||||
|
||||
The Vision Transformer (ViT) model was proposed in [An Image is Worth 16x16 Words: Transformers for Image Recognition
|
||||
at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk
|
||||
Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob
|
||||
Uszkoreit, Neil Houlsby. It's the first paper that successfully trains a Transformer encoder on ImageNet, attaining
|
||||
very good results compared to familiar convolutional architectures.
|
||||
[Vision Transformer (ViT)](https://huggingface.co/papers/2010.11929) is a transformer adapted for computer vision tasks. An image is split into smaller fixed-sized patches which are treated as a sequence of tokens, similar to words for NLP tasks. ViT requires less resources to pretrain compared to convolutional architectures and its performance on large datasets can be transferred to smaller downstream tasks.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original ViT checkpoints under the [Google](https://huggingface.co/google?search_models=vit) organization.
|
||||
|
||||
*While the Transformer architecture has become the de-facto standard for natural language processing tasks, its
|
||||
applications to computer vision remain limited. In vision, attention is either applied in conjunction with
|
||||
convolutional networks, or used to replace certain components of convolutional networks while keeping their overall
|
||||
structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to
|
||||
sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of
|
||||
data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.),
|
||||
Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring
|
||||
substantially fewer computational resources to train.*
|
||||
> [!TIP]
|
||||
> Click on the ViT models in the right sidebar for more examples of how to apply ViT to different computer vision tasks.
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vit_architecture.jpg"
|
||||
alt="drawing" width="600"/>
|
||||
The example below demonstrates how to classify an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
<small> ViT architecture. Taken from the <a href="https://arxiv.org/abs/2010.11929">original paper.</a> </small>
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
Following the original Vision Transformer, some follow-up works have been made:
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
- [DeiT](deit) (Data-efficient Image Transformers) by Facebook AI. DeiT models are distilled vision transformers.
|
||||
The authors of DeiT also released more efficiently trained ViT models, which you can directly plug into [`ViTModel`] or
|
||||
[`ViTForImageClassification`]. There are 4 variants available (in 3 different sizes): *facebook/deit-tiny-patch16-224*,
|
||||
*facebook/deit-small-patch16-224*, *facebook/deit-base-patch16-224* and *facebook/deit-base-patch16-384*. Note that one should
|
||||
use [`DeiTImageProcessor`] in order to prepare images for the model.
|
||||
|
||||
- [BEiT](beit) (BERT pre-training of Image Transformers) by Microsoft Research. BEiT models outperform supervised pre-trained
|
||||
vision transformers using a self-supervised method inspired by BERT (masked image modeling) and based on a VQ-VAE.
|
||||
|
||||
- DINO (a method for self-supervised training of Vision Transformers) by Facebook AI. Vision Transformers trained using
|
||||
the DINO method show very interesting properties not seen with convolutional models. They are capable of segmenting
|
||||
objects, without having ever been trained to do so. DINO checkpoints can be found on the [hub](https://huggingface.co/models?other=dino).
|
||||
|
||||
- [MAE](vit_mae) (Masked Autoencoders) by Facebook AI. By pre-training Vision Transformers to reconstruct pixel values for a high portion
|
||||
(75%) of masked patches (using an asymmetric encoder-decoder architecture), the authors show that this simple method outperforms
|
||||
supervised pre-training after fine-tuning.
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
|
||||
found [here](https://github.com/google-research/vision_transformer).
|
||||
|
||||
Note that we converted the weights from Ross Wightman's [timm library](https://github.com/rwightman/pytorch-image-models),
|
||||
who already converted the weights from JAX to PyTorch. Credits go to him!
|
||||
|
||||
## Usage tips
|
||||
|
||||
- To feed images to the Transformer encoder, each image is split into a sequence of fixed-size non-overlapping patches,
|
||||
which are then linearly embedded. A [CLS] token is added to serve as representation of an entire image, which can be
|
||||
used for classification. The authors also add absolute position embeddings, and feed the resulting sequence of
|
||||
vectors to a standard Transformer encoder.
|
||||
- As the Vision Transformer expects each image to be of the same size (resolution), one can use
|
||||
[`ViTImageProcessor`] to resize (or rescale) and normalize images for the model.
|
||||
- Both the patch resolution and image resolution used during pre-training or fine-tuning are reflected in the name of
|
||||
each checkpoint. For example, `google/vit-base-patch16-224` refers to a base-sized architecture with patch
|
||||
resolution of 16x16 and fine-tuning resolution of 224x224. All checkpoints can be found on the [hub](https://huggingface.co/models?search=vit).
|
||||
- The available checkpoints are either (1) pre-trained on [ImageNet-21k](http://www.image-net.org/) (a collection of
|
||||
14 million images and 21k classes) only, or (2) also fine-tuned on [ImageNet](http://www.image-net.org/challenges/LSVRC/2012/) (also referred to as ILSVRC 2012, a collection of 1.3 million
|
||||
images and 1,000 classes).
|
||||
- The Vision Transformer was pre-trained using a resolution of 224x224. During fine-tuning, it is often beneficial to
|
||||
use a higher resolution than pre-training [(Touvron et al., 2019)](https://arxiv.org/abs/1906.06423), [(Kolesnikov
|
||||
et al., 2020)](https://arxiv.org/abs/1912.11370). In order to fine-tune at higher resolution, the authors perform
|
||||
2D interpolation of the pre-trained position embeddings, according to their location in the original image.
|
||||
- The best results are obtained with supervised pre-training, which is not the case in NLP. The authors also performed
|
||||
an experiment with a self-supervised pre-training objective, namely masked patched prediction (inspired by masked
|
||||
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
|
||||
improvement of 2% to training from scratch, but still 4% behind supervised pre-training.
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```
|
||||
from transformers import ViTForImageClassification
|
||||
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
|
||||
...
|
||||
pipeline = pipeline(
|
||||
task="image-classification",
|
||||
model="google/vit-base-patch16-224",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline(images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg")
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-base-patch16-224` model, we saw the following speedups during inference.
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
||||
|
||||
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|
||||
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
|
||||
| 1 | 7 | 6 | 1.17 |
|
||||
| 2 | 8 | 6 | 1.33 |
|
||||
| 4 | 8 | 6 | 1.33 |
|
||||
| 8 | 8 | 6 | 1.33 |
|
||||
image_processor = AutoImageProcessor.from_pretrained(
|
||||
"google/vit-base-patch16-224",
|
||||
use_fast=True,
|
||||
)
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
"google/vit-base-patch16-224",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
inputs = image_processor(image, return_tensors="pt").to("cuda")
|
||||
|
||||
## Resources
|
||||
with torch.no_grad():
|
||||
logits = model(**inputs).logits
|
||||
predicted_class_id = logits.argmax(dim=-1).item()
|
||||
|
||||
Demo notebooks regarding inference as well as fine-tuning ViT on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer).
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
class_labels = model.config.id2label
|
||||
predicted_class_label = class_labels[predicted_class_id]
|
||||
print(f"The predicted class label is: {predicted_class_label}")
|
||||
```
|
||||
|
||||
`ViTForImageClassification` is supported by:
|
||||
<PipelineTag pipeline="image-classification"/>
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
- A blog post on how to [Fine-Tune ViT for Image Classification with Hugging Face Transformers](https://huggingface.co/blog/fine-tune-vit)
|
||||
- A blog post on [Image Classification with Hugging Face Transformers and `Keras`](https://www.philschmid.de/image-classification-huggingface-transformers-keras)
|
||||
- A notebook on [Fine-tuning for Image Classification with Hugging Face Transformers](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
|
||||
- A notebook on how to [Fine-tune the Vision Transformer on CIFAR-10 with the Hugging Face Trainer](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_the_%F0%9F%A4%97_Trainer.ipynb)
|
||||
- A notebook on how to [Fine-tune the Vision Transformer on CIFAR-10 with PyTorch Lightning](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb)
|
||||
## Notes
|
||||
|
||||
⚗️ Optimization
|
||||
|
||||
- A blog post on how to [Accelerate Vision Transformer (ViT) with Quantization using Optimum](https://www.philschmid.de/optimizing-vision-transformer)
|
||||
|
||||
⚡️ Inference
|
||||
|
||||
- A notebook on [Quick demo: Vision Transformer (ViT) by Google Brain](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Quick_demo_of_HuggingFace_version_of_Vision_Transformer_inference.ipynb)
|
||||
|
||||
🚀 Deploy
|
||||
|
||||
- A blog post on [Deploying Tensorflow Vision Models in Hugging Face with TF Serving](https://huggingface.co/blog/tf-serving-vision)
|
||||
- A blog post on [Deploying Hugging Face ViT on Vertex AI](https://huggingface.co/blog/deploy-vertex-ai)
|
||||
- A blog post on [Deploying Hugging Face ViT on Kubernetes with TF Serving](https://huggingface.co/blog/deploy-tfserving-kubernetes)
|
||||
- The best results are obtained with supervised pretraining, and during fine-tuning, it may be better to use images with a resolution higher than 224x224.
|
||||
- Use [`ViTImageProcessorFast`] to resize (or rescale) and normalize images to the expected size.
|
||||
- The patch and image resolution are reflected in the checkpoint name. For example, google/vit-base-patch16-224, is the **base-sized** architecture with a patch resolution of 16x16 and fine-tuning resolution of 224x224.
|
||||
|
||||
## ViTConfig
|
||||
|
||||
@ -171,9 +111,6 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] ViTImageProcessorFast
|
||||
- preprocess
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
## ViTModel
|
||||
|
||||
[[autodoc]] ViTModel
|
||||
@ -189,9 +126,6 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] ViTForImageClassification
|
||||
- forward
|
||||
|
||||
</pt>
|
||||
<tf>
|
||||
|
||||
## TFViTModel
|
||||
|
||||
[[autodoc]] TFViTModel
|
||||
@ -202,9 +136,6 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
[[autodoc]] TFViTForImageClassification
|
||||
- call
|
||||
|
||||
</tf>
|
||||
<jax>
|
||||
|
||||
## FlaxVitModel
|
||||
|
||||
[[autodoc]] FlaxViTModel
|
||||
@ -214,6 +145,3 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
|
||||
[[autodoc]] FlaxViTForImageClassification
|
||||
- __call__
|
||||
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
@ -19,6 +19,7 @@ rendered properly in your Markdown viewer.
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -14,152 +14,86 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Whisper
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Whisper
|
||||
|
||||
The Whisper model was proposed in [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
|
||||
[Whisper](https://hf.co/papers/2212.04356) is a encoder-decoder (sequence-to-sequence) transformer pretrained on 680,000 hours of labeled audio data. This amount of pretraining data enables zero-shot performance on audio tasks in English and many other languages. The decoder allows Whisper to map the encoders learned speech representations to useful outputs, such as text, without additional fine-tuning. Whisper just works out of the box.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original Whisper checkpoints under the [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) collection.
|
||||
|
||||
*We study the capabilities of speech processing systems trained simply to predict large amounts of transcripts of audio on the internet. When scaled to 680,000 hours of multilingual and multitask supervision, the resulting models generalize well to standard benchmarks and are often competitive with prior fully supervised results but in a zeroshot transfer setting without the need for any finetuning. When compared to humans, the models approach their accuracy and robustness. We are releasing models and inference code to serve as a foundation for further work on robust speech processing.*
|
||||
> [!TIP]
|
||||
> Click on the Whisper models in the right sidebar for more examples of how to apply Whisper to different audio tasks.
|
||||
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
|
||||
The original code can be found [here](https://github.com/openai/whisper).
|
||||
The example below demonstrates how to automatically transcribe speech into text with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
## Quick usage
|
||||
|
||||
You can run Whisper in less than 4 lines of code and transcribe in less than a minute!
|
||||
|
||||
```python
|
||||
# pip install transformers torch
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
whisper = pipeline("automatic-speech-recognition", "openai/whisper-large-v3", torch_dtype=torch.float16, device="cuda:0")
|
||||
|
||||
transcription = whisper("<audio_file.mp3>")
|
||||
|
||||
print(transcription["text"])
|
||||
pipeline = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
|
||||
```
|
||||
|
||||
Voila! You can swap the model with any [Whisper checkpoints](https://huggingface.co/models?other=whisper&sort=downloads) on the Hugging Face Hub with the same pipeline based on your needs.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
Bonus: You can replace `"cuda"` with `"mps"` to make it seamlessly work on Macs.
|
||||
```py
|
||||
# pip install datasets
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoProcessor, WhisperForConditionalGeneration
|
||||
|
||||
## Usage tips
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
"openai/whisper-large-v3-turbo",
|
||||
)
|
||||
model = WhisperForConditionalGeneration.from_pretrained(
|
||||
"openai/whisper-large-v3-turbo",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
).to("cuda")
|
||||
|
||||
- The model usually performs well without requiring any finetuning.
|
||||
- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation.GenerationMixin.generate`] function for inference.
|
||||
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
audio_sample = ds[0]["audio"]
|
||||
|
||||
- To convert the model and the processor, we recommend using the following:
|
||||
input_features = processor(
|
||||
audio_sample["array"],
|
||||
sampling_rate=audio_sample["sampling_rate"],
|
||||
return_tensors="pt"
|
||||
).input_features
|
||||
input_features = input_features.to("cuda", dtype=torch.float16)
|
||||
|
||||
```bash
|
||||
python src/transformers/models/whisper/convert_openai_to_hf.py --checkpoint_path "" --pytorch_dump_folder_path "Arthur/whisper-3" --convert_preprocessor True
|
||||
```
|
||||
The script will automatically determine all necessary parameters from the OpenAI checkpoint. A `tiktoken` library needs to be installed
|
||||
to perform the conversion of the OpenAI tokenizer to the `tokenizers` version.
|
||||
|
||||
## Inference
|
||||
|
||||
Here is a step-by-step guide to transcribing an audio sample using a pre-trained Whisper model:
|
||||
|
||||
```python
|
||||
>>> from datasets import load_dataset
|
||||
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
||||
|
||||
>>> # Select an audio file and read it:
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> audio_sample = ds[0]["audio"]
|
||||
|
||||
>>> # Load the Whisper model in Hugging Face format:
|
||||
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
|
||||
>>> # Use the model and processor to transcribe the audio:
|
||||
>>> input_features = processor(
|
||||
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
|
||||
... ).input_features
|
||||
|
||||
>>> # Generate token ids
|
||||
>>> predicted_ids = model.generate(input_features)
|
||||
|
||||
>>> # Decode token ids to text
|
||||
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||
|
||||
>>> transcription[0]
|
||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
||||
predicted_ids = model.generate(input_features, cache_implementation="static")
|
||||
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||
transcription[0]
|
||||
```
|
||||
|
||||
Whisper is compatible with the following optimisations for both short and long-form generation:
|
||||
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
|
||||
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
|
||||
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
As an example, the following codesnippet enables SDPA and `torch.compile` for up to 5x faster inference:
|
||||
## Notes
|
||||
|
||||
```python
|
||||
>>> from datasets import load_dataset
|
||||
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
||||
|
||||
>>> # Select an audio file and read it:
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> audio_sample = ds[0]["audio"]
|
||||
|
||||
>>> # Load the Whisper model with SDPA attention
|
||||
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
|
||||
|
||||
>>> # Enable static cache and compile the forward pass
|
||||
>>> model.generation_config.cache_implementation = "static"
|
||||
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
>>> # Use the model and processor to transcribe the audio:
|
||||
>>> input_features = processor(
|
||||
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
|
||||
... ).input_features
|
||||
|
||||
>>> # Compile the forward pass
|
||||
>>> for _ in range(2):
|
||||
>>> model.generate(input_features)
|
||||
|
||||
>>> # Generate token ids using compiled graph (fast!)
|
||||
>>> predicted_ids = model.generate(input_features)
|
||||
|
||||
>>> # Decode token ids to text
|
||||
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||
|
||||
>>> transcription[0]
|
||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
||||
```
|
||||
|
||||
For more details on each optimisation, refer to the documentation linked above.
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
- [Fine-tune Whisper](https://huggingface.co/blog/fine-tune-whisper) on your own dataset for better downstream performance.
|
||||
- [Distil-Whisper](https://huggingface.co/distil-whisper): Upto 6x faster, 2x smaller distilled Whisper models for English. We release the [model checkpoints](https://huggingface.co/distil-whisper), and [distillation code](https://github.com/huggingface/distil-whisper).
|
||||
- A fork with a script to [convert a Whisper model in Hugging Face format to OpenAI format](https://github.com/zuazo-forks/transformers/blob/convert_hf_to_openai/src/transformers/models/whisper/convert_hf_to_openai.py). 🌎
|
||||
Usage example:
|
||||
```bash
|
||||
pip install -U openai-whisper
|
||||
python convert_hf_to_openai.py \
|
||||
--checkpoint openai/whisper-tiny \
|
||||
--whisper_dump_path whisper-tiny-openai.pt
|
||||
```
|
||||
- Whisper relies on [`~GenerationMixin.generate`] for inference.
|
||||
- The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text.
|
||||
|
||||
## WhisperConfig
|
||||
|
||||
@ -205,9 +139,6 @@ python convert_hf_to_openai.py \
|
||||
- batch_decode
|
||||
- decode
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
## WhisperModel
|
||||
|
||||
[[autodoc]] WhisperModel
|
||||
@ -230,9 +161,6 @@ python convert_hf_to_openai.py \
|
||||
[[autodoc]] WhisperForAudioClassification
|
||||
- forward
|
||||
|
||||
</pt>
|
||||
<tf>
|
||||
|
||||
## TFWhisperModel
|
||||
|
||||
[[autodoc]] TFWhisperModel
|
||||
@ -243,9 +171,6 @@ python convert_hf_to_openai.py \
|
||||
[[autodoc]] TFWhisperForConditionalGeneration
|
||||
- call
|
||||
|
||||
</tf>
|
||||
<jax>
|
||||
|
||||
## FlaxWhisperModel
|
||||
|
||||
[[autodoc]] FlaxWhisperModel
|
||||
@ -260,7 +185,3 @@ python convert_hf_to_openai.py \
|
||||
|
||||
[[autodoc]] FlaxWhisperForAudioClassification
|
||||
- __call__
|
||||
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
|
||||
@ -78,7 +78,7 @@ class RobertaModel(BertModel):
|
||||
super().__init__(config)
|
||||
self.embeddings = RobertaEmbeddings(config)
|
||||
|
||||
|
||||
|
||||
# The model heads now only need to redefine the model inside to `RobertaModel`
|
||||
class RobertaForMaskedLM(BertForMaskedLM):
|
||||
def __init__(self, config):
|
||||
@ -546,7 +546,7 @@ This makes it very easy to switch decorators and makes it explicit that the only
|
||||
|
||||
## Docstring variables
|
||||
|
||||
If an object defined in both the modular and modeling file from which it inherits, the modular definition has precedence unless for assignments containing the pattern `DOCSTRING`. These variables are typically used in `MODEL_START_DOCSTRING` and `MODEL_INPUT_DOCSTRING` in the modeling files. They are big blocks of docstrings and the linter rewrites the names everywhere. For this reason, assignments containing the `DOCSTRING` variable always uses the definition found in the source file instead of the modular file.
|
||||
If an object defined in both the modular and modeling file from which it inherits, the modular definition has precedence unless for assignments containing the pattern `DOCSTRING`. These variables are typically used in `MODEL_START_DOCSTRING` and `MODEL_INPUT_DOCSTRING` in the modeling files. They are big blocks of docstrings and the linter rewrites the names everywhere. For this reason, assignments containing the `DOCSTRING` variable can use the definition found in the source file without copying the whole docstring, by simply setting the variable to `None` in the modular file.
|
||||
|
||||
This is very useful if you need the variable reference somewhere but you don't want to clutter the modular file with docstrings which are always the same. The example code below allows you to automatically use the same docstrings from [Mistral](./model_doc/mistral) in [Starcoder2](./model_doc/starcoder2).
|
||||
|
||||
@ -561,6 +561,8 @@ class Starcoder2Model(MistralModel):
|
||||
...
|
||||
```
|
||||
|
||||
Setting the variable to anything other than `None` will override the docstring, so that you can customize the docstrings if needed.
|
||||
|
||||
## Special naming
|
||||
|
||||
The linter automatically renames everything when inheriting from a class. For consistency, you should always use the same class name prefix when inheriting from different classes from the same file.
|
||||
@ -586,7 +588,7 @@ We detected multiple prefix names when inheriting from transformers.models.llama
|
||||
If there are automatic dependencies with a prefix, but you want another one, explicitly rename the classes locally with a `pass` class as shown in the following.
|
||||
|
||||
```py
|
||||
class Emu3TextMLP(LlamaMLP):
|
||||
class Emu3TextMLP(LlamaMLP):
|
||||
pass
|
||||
```
|
||||
|
||||
|
||||
@ -44,11 +44,6 @@ import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# initialize distributed environment
|
||||
rank = int(os.environ["RANK"])
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.distributed.init_process_group("nccl", device_id=device)
|
||||
|
||||
# enable tensor parallelism
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@ -59,7 +54,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
# prepare input tokens
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
prompt = "Can I help"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
|
||||
|
||||
# distributed run
|
||||
outputs = model(inputs)
|
||||
@ -71,6 +66,13 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
|
||||
torchrun --nproc-per-node 4 demo.py
|
||||
```
|
||||
|
||||
For CPU, please binding different socket on each rank. For example, if you are using Intel 4th Gen Xeon:
|
||||
```bash
|
||||
export OMP_NUM_THREADS=56
|
||||
numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait
|
||||
```
|
||||
The CPU benchmark data will be released soon.
|
||||
|
||||
You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.
|
||||
|
||||
For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.
|
||||
|
||||
@ -29,8 +29,8 @@ import requests
|
||||
|
||||
processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
prompt = "answer en Where is the cow standing?"
|
||||
url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||
prompt = "answer en Where is the cat standing?"
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(text=prompt, images=image, return_tensors="pt")
|
||||
|
||||
@ -20,7 +20,10 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that aims to make large language model inference more accessible without significant degradation. Unlike naive 8-bit quantization, which can result in loss of critical information and accuracy, LLM.int8() dynamically adapts to ensure sensitive components of the computation retain higher precision when needed.
|
||||
|
||||
QLoRA, or 4-bit quantization, compresses a model even further to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allowing training.
|
||||
QLoRA, or 4-bit quantization, compresses a model even further to 4-bits and inserts a small set of trainable low-rank adaptation (LoRA) weights to allowing training.
|
||||
|
||||
> **Note:** For a user-friendly quantization experience, you can use the `bitsandbytes` [community space](https://huggingface.co/spaces/bnb-community/bnb-my-repo).
|
||||
|
||||
|
||||
Run the command below to install bitsandbytes.
|
||||
|
||||
|
||||
@ -40,10 +40,20 @@ Use the Space below to help you pick a quantization method depending on your har
|
||||
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
|
||||
| [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
|
||||
| [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
|
||||
| [Quark](./quark) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |
|
||||
|
||||
## Resources
|
||||
|
||||
If you are new to quantization, we recommend checking out these beginner-friendly quantization courses in collaboration with DeepLearning.AI.
|
||||
|
||||
* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
|
||||
* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth
|
||||
* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth)
|
||||
|
||||
## User-Friendly Quantization Tools
|
||||
|
||||
If you are looking for a user-friendly quantization experience, you can use the following community spaces and notebooks:
|
||||
|
||||
* [Bitsandbytes Space](https://huggingface.co/spaces/bnb-community/bnb-my-repo)
|
||||
* [GGUF Space](https://huggingface.co/spaces/ggml-org/gguf-my-repo)
|
||||
* [MLX Space](https://huggingface.co/spaces/mlx-community/mlx-my-repo)
|
||||
* [AuoQuant Notebook](https://colab.research.google.com/drive/1b6nqC7UZVt8bx4MksX7s656GXPM-eWw4?usp=sharing#scrollTo=ZC9Nsr9u5WhN)
|
||||
|
||||
84
docs/source/en/quantization/quark.md
Normal file
84
docs/source/en/quantization/quark.md
Normal file
@ -0,0 +1,84 @@
|
||||
<!--Copyright 2025 Advanced Micro Devices, Inc. and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Quark
|
||||
|
||||
[Quark](https://quark.docs.amd.com/latest/) is a deep learning quantization toolkit designed to be agnostic to specific data types, algorithms, and hardware. Different pre-processing strategies, algorithms and data-types can be combined in Quark.
|
||||
|
||||
The PyTorch support integrated through 🤗 Transformers primarily targets AMD CPUs and GPUs, and is primarily meant to be used for evaluation purposes. For example, it is possible to use [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) with 🤗 Transformers backend and evaluate a wide range of models quantized through Quark seamlessly.
|
||||
|
||||
Users interested in Quark can refer to its [documentation](https://quark.docs.amd.com/latest/) to get started quantizing models and using them in supported open-source libraries!
|
||||
|
||||
Although Quark has its own checkpoint / [configuration format](https://huggingface.co/amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test/blob/main/config.json#L26), the library also supports producing models with a serialization layout compliant with other quantization/runtime implementations ([AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq), [native fp8 in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8)).
|
||||
|
||||
To be able to load Quark quantized models in Transformers, the library first needs to be installed:
|
||||
|
||||
```bash
|
||||
pip install amd-quark
|
||||
```
|
||||
|
||||
## Support matrix
|
||||
|
||||
Models quantized through Quark support a large range of features, that can be combined together. All quantized models independently of their configuration can seamlessly be reloaded through `PretrainedModel.from_pretrained`.
|
||||
|
||||
The table below shows a few features supported by Quark:
|
||||
|
||||
| **Feature** | **Supported subset in Quark** | |
|
||||
|---------------------------------|-----------------------------------------------------------------------------------------------------------|---|
|
||||
| Data types | int8, int4, int2, bfloat16, float16, fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, OCP MX, MX6, MX9, bfp16 | |
|
||||
| Pre-quantization transformation | SmoothQuant, QuaRot, SpinQuant, AWQ | |
|
||||
| Quantization algorithm | GPTQ | |
|
||||
| Supported operators | ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``, ``nn.Embedding``, ``nn.EmbeddingBag`` | |
|
||||
| Granularity | per-tensor, per-channel, per-block, per-layer, per-layer type | |
|
||||
| KV cache | fp8 | |
|
||||
| Activation calibration | MinMax / Percentile / MSE | |
|
||||
| Quantization strategy | weight-only, static, dynamic, with or without output quantization | |
|
||||
|
||||
## Models on Hugging Face Hub
|
||||
|
||||
Public models using Quark native serialization can be found at https://huggingface.co/models?other=quark.
|
||||
|
||||
Although Quark also supports [models using `quant_method="fp8"`](https://huggingface.co/models?other=fp8) and [models using `quant_method="awq"`](https://huggingface.co/models?other=awq), Transformers loads these models rather through [AutoAWQ](https://huggingface.co/docs/transformers/quantization/awq) or uses the [native fp8 support in 🤗 Transformers](https://huggingface.co/docs/transformers/quantization/finegrained_fp8).
|
||||
|
||||
## Using Quark models in Transformers
|
||||
|
||||
Here is an example of how one can load a Quark model in Transformers:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_id = "EmbeddedLLM/Llama-3.1-8B-Instruct-w_fp8_per_channel_sym"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
model = model.to("cuda")
|
||||
|
||||
print(model.model.layers[0].self_attn.q_proj)
|
||||
# QParamsLinear(
|
||||
# (weight_quantizer): ScaledRealQuantizer()
|
||||
# (input_quantizer): ScaledRealQuantizer()
|
||||
# (output_quantizer): ScaledRealQuantizer()
|
||||
# )
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inp = tokenizer("Where is a good place to cycle around Tokyo?", return_tensors="pt")
|
||||
inp = inp.to("cuda")
|
||||
|
||||
res = model.generate(**inp, min_new_tokens=50, max_new_tokens=100)
|
||||
|
||||
print(tokenizer.batch_decode(res)[0])
|
||||
# <|begin_of_text|>Where is a good place to cycle around Tokyo? There are several places in Tokyo that are suitable for cycling, depending on your skill level and interests. Here are a few suggestions:
|
||||
# 1. Yoyogi Park: This park is a popular spot for cycling and has a wide, flat path that's perfect for beginners. You can also visit the Meiji Shrine, a famous Shinto shrine located in the park.
|
||||
# 2. Imperial Palace East Garden: This beautiful garden has a large, flat path that's perfect for cycling. You can also visit the
|
||||
```
|
||||
@ -20,18 +20,95 @@ Install torchao with the following command.
|
||||
pip install --upgrade torch torchao transformers
|
||||
```
|
||||
|
||||
torchao supports many quantization types for different data types (int4, float8, weight only, etc.), but the Transformers integration only currently supports int8 weight quantization and int8 dynamic quantization of weights.
|
||||
torchao supports many quantization types for different data types (int4, float8, weight only, etc.).
|
||||
Starting with version 0.10.0, torchao provides enhanced flexibility through the `AOBaseConfig` API, allowing for more customized quantization configurations.
|
||||
And full access to the techniques offered in the torchao library.
|
||||
|
||||
You can manually choose the quantization types and settings or automatically select the quantization types.
|
||||
|
||||
<hfoptions id="torchao">
|
||||
<hfoption id="manual">
|
||||
|
||||
|
||||
Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.
|
||||
|
||||
> [!TIP]
|
||||
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
|
||||
|
||||
In torchao 0.10.0+, you can use the more flexible `AOBaseConfig` approach instead of string identifiers:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
# Using AOBaseConfig instance (torchao >= 0.10.0)
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Meta-Llama-3-8B",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Available Quantization Schemes
|
||||
|
||||
TorchAO provides a variety of quantization configurations:
|
||||
|
||||
- `Int4WeightOnlyConfig`
|
||||
- `Int8WeightOnlyConfig`
|
||||
- `Int8DynamicActivationInt8WeightConfig`
|
||||
- `Float8WeightOnlyConfig`
|
||||
|
||||
Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures.
|
||||
|
||||
For a complete list of available configurations, see our [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).
|
||||
|
||||
> **⚠️ DEPRECATION WARNING**
|
||||
>
|
||||
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
|
||||
>
|
||||
> Please use the new `AOBaseConfig`-based approach instead:
|
||||
>
|
||||
> ```python
|
||||
> # Old way (deprecated)
|
||||
> quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
>
|
||||
> # New way (recommended)
|
||||
> from torchao.quantization import Int4WeightOnlyConfig
|
||||
> quant_config = Int4WeightOnlyConfig(group_size=128)
|
||||
> quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
> ```
|
||||
>
|
||||
> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.
|
||||
>
|
||||
> ## Migration Guide
|
||||
>
|
||||
> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents:
|
||||
>
|
||||
> | Old String API | New `AOBaseConfig` API |
|
||||
> |----------------|------------------------|
|
||||
> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` |
|
||||
> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` |
|
||||
> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` |
|
||||
>
|
||||
> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`).
|
||||
|
||||
|
||||
Below is the API for for torchao < `0.9.0`
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
@ -73,12 +150,15 @@ output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementatio
|
||||
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> For best performance, you can use recommended settings by calling `torchao.quantization.utils.recommended_inductor_config_setter()`
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="automatic">
|
||||
|
||||
The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on input type and shape and compiling a single linear layer.
|
||||
|
||||
Create a [`TorchAoConfig`] and set to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.
|
||||
Create a [`TorchAoConfig`] and set to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.
|
||||
|
||||
> [!TIP]
|
||||
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
|
||||
@ -131,7 +211,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
|
||||
|
||||
## Serialization
|
||||
|
||||
torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchaco.
|
||||
torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.
|
||||
|
||||
To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
|
||||
|
||||
|
||||
@ -220,7 +220,7 @@ Pasa tu texto al tokenizador:
|
||||
El tokenizador devolverá un diccionario conteniendo:
|
||||
|
||||
* [input_ids](./glossary#input-ids): representaciones numéricas de los tokens.
|
||||
* [atttention_mask](.glossary#attention-mask): indica cuáles tokens deben ser atendidos.
|
||||
* [attention_mask](.glossary#attention-mask): indica cuáles tokens deben ser atendidos.
|
||||
|
||||
Como con el [`pipeline`], el tokenizador aceptará una lista de inputs. Además, el tokenizador también puede rellenar (pad, en inglés) y truncar el texto para devolver un lote (batch, en inglés) de longitud uniforme:
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ Abbiamo integrato di recente `BetterTransformer` per fare inferenza più rapidam
|
||||
|
||||
## PyTorch JIT-mode (TorchScript)
|
||||
|
||||
TorchScript è un modo di creare modelli serializzabili e ottimizzabili da codice PyTorch. Ogni programmma TorchScript può esere salvato da un processo Python e caricato in un processo dove non ci sono dipendenze Python.
|
||||
TorchScript è un modo di creare modelli serializzabili e ottimizzabili da codice PyTorch. Ogni programma TorchScript può esere salvato da un processo Python e caricato in un processo dove non ci sono dipendenze Python.
|
||||
Comparandolo con l'eager mode di default, jit mode in PyTorch normalmente fornisce prestazioni migliori per l'inferenza del modello da parte di metodologie di ottimizzazione come la operator fusion.
|
||||
|
||||
Per una prima introduzione a TorchScript, vedi la Introduction to [PyTorch TorchScript tutorial](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#tracing-modules).
|
||||
|
||||
@ -22,10 +22,6 @@ rendered properly in your Markdown viewer.
|
||||
- `_LRSchedule` から継承するスケジュール オブジェクトの形式のいくつかのスケジュール:
|
||||
- 複数のバッチの勾配を累積するための勾配累積クラス
|
||||
|
||||
## AdamW (PyTorch)
|
||||
|
||||
[[autodoc]] AdamW
|
||||
|
||||
## AdaFactor (PyTorch)
|
||||
|
||||
[[autodoc]] Adafactor
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user