mirror of
https://github.com/huggingface/transformers.git
synced 2025-11-03 03:14:36 +08:00
Compare commits
195 Commits
check_env_
...
fix_requir
| Author | SHA1 | Date | |
|---|---|---|---|
| 1a67ca199c | |||
| 897874748b | |||
| 6a75528cbc | |||
| 6cef03ba66 | |||
| a563999a02 | |||
| 3c39c07939 | |||
| f797e3d98a | |||
| 442d356aa5 | |||
| 7e9b57ce62 | |||
| 54a123f068 | |||
| 931126b929 | |||
| c7064cdba1 | |||
| 371c44d0ef | |||
| 7ff896c0f2 | |||
| 94ea20ad41 | |||
| 10907e2846 | |||
| 7d76876498 | |||
| dac443414e | |||
| 6daec12d0b | |||
| 0ea1151222 | |||
| 9c0c323e12 | |||
| bde41d69b4 | |||
| 7ecc5b88c0 | |||
| 5ae9b2cac0 | |||
| d9e76656ae | |||
| 1ae8d54b04 | |||
| 10144ff116 | |||
| aa478567f8 | |||
| ae5ce22664 | |||
| 4f139f5a50 | |||
| a2c2fb0108 | |||
| 0ddad2d655 | |||
| fbb2054ed5 | |||
| 6d8b0b3378 | |||
| f5865d32a2 | |||
| e39c732644 | |||
| bc0150bb04 | |||
| 9cda4265d6 | |||
| e032d12e8a | |||
| f834ca2c19 | |||
| c5c648dd74 | |||
| 71b35387fd | |||
| ad340908e4 | |||
| 2527f71a47 | |||
| 7ae0be722e | |||
| e3eda6d188 | |||
| 1e6ff5fd55 | |||
| 6f4058aee3 | |||
| 08e3217baf | |||
| 4d0de5f73a | |||
| c15a7adb28 | |||
| 121f91d36c | |||
| 4321b0648c | |||
| aab0878327 | |||
| 35f0f5b5da | |||
| 530322ccb6 | |||
| 8064cd9b4f | |||
| cdfb018d03 | |||
| 1e6b546ea6 | |||
| 0fc683d1cd | |||
| 2515a5a290 | |||
| 2da82e432d | |||
| 794fde7b1c | |||
| b54c2f4689 | |||
| 754a370bca | |||
| 31a62c2eb8 | |||
| f830105183 | |||
| e2b0224d94 | |||
| 6cc109c354 | |||
| 8bbcdf5409 | |||
| 3a826a45ca | |||
| 5e855095a2 | |||
| 416b5a875d | |||
| f8a16805c5 | |||
| 48e179857c | |||
| 832cb684a0 | |||
| 22065bd645 | |||
| f789f960c8 | |||
| 12bf24d6ae | |||
| e7ad077012 | |||
| 99f9f1042f | |||
| 0fb8d49e88 | |||
| 08f36771b3 | |||
| 9db31ea585 | |||
| debfe904c9 | |||
| 54538ebee3 | |||
| d1b92369ca | |||
| 25b7f27234 | |||
| aa40fda346 | |||
| e94571580b | |||
| 84aa13dd85 | |||
| 0ef339ff1b | |||
| 46d73910d5 | |||
| 579135a2f6 | |||
| 8cd57eb731 | |||
| ebe47ce3e9 | |||
| 531e4fcf0e | |||
| a4e55fcff8 | |||
| 878562b68d | |||
| 8ebc435267 | |||
| ad3d157188 | |||
| 3d40bda30e | |||
| acbcb5d07d | |||
| 4ba0989eab | |||
| 352ec8ef22 | |||
| edd345b52e | |||
| b016de1ae4 | |||
| f74d7da836 | |||
| d130cd0e16 | |||
| 41b9b92b52 | |||
| 8dd0a2b89c | |||
| 15ac2b6ac5 | |||
| b552708694 | |||
| 2b84831a93 | |||
| 2d46a08b63 | |||
| 1b29409d89 | |||
| 8a828a747e | |||
| 3f6af96732 | |||
| 9a1c1fe7ed | |||
| 782d7d945d | |||
| afafb84b59 | |||
| 34ccfebf32 | |||
| f697b3f824 | |||
| 2099287a59 | |||
| a0803a9555 | |||
| 6ce238fe7a | |||
| 12048990a9 | |||
| 98601cc818 | |||
| c9302c0983 | |||
| 2056287940 | |||
| 3e96a0c32b | |||
| 199d7adf10 | |||
| 126abe3461 | |||
| 3d133cc557 | |||
| e90d55ebcc | |||
| cbfa14823b | |||
| 7613cf1a45 | |||
| 32c12aaec3 | |||
| 764ab0d46a | |||
| c94c6ed397 | |||
| e94d607c8b | |||
| adfc91cd46 | |||
| 6f5dc9c82e | |||
| a165458901 | |||
| ed95493ce0 | |||
| 211e4dc9a4 | |||
| 800510c67b | |||
| 41f5c3216c | |||
| bc2dea3f54 | |||
| 35253076f4 | |||
| bf41e54fc8 | |||
| 3249c5dc15 | |||
| 24e311f42b | |||
| 897ff9af0e | |||
| c0bd8048a5 | |||
| 60b75d99b6 | |||
| fac70ff3c0 | |||
| ae34bd75fd | |||
| 8f6b27eb5c | |||
| 737cbd2109 | |||
| 3a6ab46a0b | |||
| 4b13a02920 | |||
| 786d9c5ed9 | |||
| a1e389e637 | |||
| f304318f5f | |||
| 8805600406 | |||
| e686fed635 | |||
| a03cee7a1d | |||
| 3b07ca78bb | |||
| 475664e2c6 | |||
| 0710e9b1e8 | |||
| f99c279d20 | |||
| d1efaf0318 | |||
| 19919689b2 | |||
| d0b65bb479 | |||
| ad63d20dff | |||
| 286393fbb1 | |||
| 4705b04c74 | |||
| 2b4734bd49 | |||
| bd41b9c1ac | |||
| 6acd5aecb3 | |||
| 0d6a60fe55 | |||
| b7fc2daf8b | |||
| bab605dd04 | |||
| 9fd9476005 | |||
| 257bc670fb | |||
| 2bea6bf24e | |||
| a86dad56bc | |||
| d6064754ea | |||
| 581cf96e0c | |||
| eca74d1367 | |||
| 52cc204dd7 | |||
| aa3778afc2 | |||
| c90e6e9625 | |||
| 1fcaad6df9 |
@ -151,7 +151,6 @@ class CircleCIJob:
|
||||
"checkout",
|
||||
{"attach_workspace": {"at": "test_preparation"}},
|
||||
{"run": "apt-get update && apt-get install -y curl"},
|
||||
{"run": "echo $TRANSFORMERS_IS_CI"},
|
||||
{"run": " && ".join(self.install_steps)},
|
||||
{"run": {"name": "Download NLTK files", "command": """python -c "import nltk; nltk.download('punkt', quiet=True)" """} if "example" in self.name else "echo Skipping"},
|
||||
{"run": {
|
||||
@ -172,6 +171,7 @@ class CircleCIJob:
|
||||
"command": f"TESTS=$(circleci tests split --split-by=timings {self.job_name}_test_list.txt) && echo $TESTS > splitted_tests.txt && echo $TESTS | tr ' ' '\n'" if self.parallelism else f"awk '{{printf \"%s \", $0}}' {self.job_name}_test_list.txt > splitted_tests.txt"
|
||||
}
|
||||
},
|
||||
{"run": {"name": "fetch hub objects before pytest", "command": "python3 utils/fetch_hub_objects_for_ci.py"}},
|
||||
{"run": {
|
||||
"name": "Run tests",
|
||||
"command": f"({timeout_cmd} python3 -m pytest {marker_cmd} -n {self.pytest_num_workers} {junit_flags} {repeat_on_failure_flags} {' '.join(pytest_flags)} $(cat splitted_tests.txt) | tee tests_output.txt)"}
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
4
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -48,11 +48,11 @@ body:
|
||||
- pipelines: @Rocketknight1
|
||||
- tensorflow: @gante and @Rocketknight1
|
||||
- tokenizers: @ArthurZucker and @itazap
|
||||
- trainer: @muellerzr @SunMarc
|
||||
- trainer: @zach-huggingface @SunMarc
|
||||
|
||||
Integrations:
|
||||
|
||||
- deepspeed: HF Trainer/Accelerate: @muellerzr
|
||||
- deepspeed: HF Trainer/Accelerate: @SunMarc @zach-huggingface
|
||||
- ray/raytune: @richardliaw, @amogkam
|
||||
- Big Model Inference: @SunMarc
|
||||
- quantization (bitsandbytes, autogpt): @SunMarc @MekkCyber
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/i18n.md
vendored
2
.github/ISSUE_TEMPLATE/i18n.md
vendored
@ -23,7 +23,7 @@ Some notes:
|
||||
* Please translate in a gender-neutral way.
|
||||
* Add your translations to the folder called `<languageCode>` inside the [source folder](https://github.com/huggingface/transformers/tree/main/docs/source).
|
||||
* Register your translation in `<languageCode>/_toctree.yml`; please follow the order of the [English version](https://github.com/huggingface/transformers/blob/main/docs/source/en/_toctree.yml).
|
||||
* Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu and @MKhalusova for review.
|
||||
* Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu for review.
|
||||
* 🙋 If you'd like others to help you with the translation, you can also post in the 🤗 [forums](https://discuss.huggingface.co/).
|
||||
|
||||
## Get Started section
|
||||
|
||||
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -51,12 +51,12 @@ Library:
|
||||
- pipelines: @Rocketknight1
|
||||
- tensorflow: @gante and @Rocketknight1
|
||||
- tokenizers: @ArthurZucker
|
||||
- trainer: @muellerzr and @SunMarc
|
||||
- trainer: @zach-huggingface and @SunMarc
|
||||
- chat templates: @Rocketknight1
|
||||
|
||||
Integrations:
|
||||
|
||||
- deepspeed: HF Trainer/Accelerate: @muellerzr
|
||||
- deepspeed: HF Trainer/Accelerate: @SunMarc @zach-huggingface
|
||||
- ray/raytune: @richardliaw, @amogkam
|
||||
- Big Model Inference: @SunMarc
|
||||
- quantization (bitsandbytes, autogpt): @SunMarc @MekkCyber
|
||||
|
||||
6
.github/scripts/codeowners_for_review_action
vendored
6
.github/scripts/codeowners_for_review_action
vendored
@ -14,7 +14,7 @@ docs/ @stevhliu
|
||||
# Owners of subsections of the library
|
||||
/src/transformers/generation/ @gante
|
||||
/src/transformers/pipeline/ @Rocketknight1 @yonigozlan
|
||||
/src/transformers/integrations/ @SunMarc @MekkCyber @muellerzr
|
||||
/src/transformers/integrations/ @SunMarc @MekkCyber @zach-huggingface
|
||||
/src/transformers/quantizers/ @SunMarc @MekkCyber
|
||||
tests/ @ydshieh
|
||||
tests/generation/ @gante
|
||||
@ -27,8 +27,8 @@ tests/generation/ @gante
|
||||
# Specific files come after the sections/globs, so they take priority
|
||||
/.circleci/config.yml @ArthurZucker @ydshieh
|
||||
/utils/tests_fetcher.py @ydshieh
|
||||
trainer.py @muellerzr @SunMarc
|
||||
trainer_utils.py @muellerzr @SunMarc
|
||||
trainer.py @zach-huggingface @SunMarc
|
||||
trainer_utils.py @zach-huggingface @SunMarc
|
||||
/utils/modular_model_converter.py @Cyrilvallez @ArthurZucker
|
||||
|
||||
# Owners of individual models are specific / high priority, and so they come last
|
||||
|
||||
36
.github/workflows/build-docker-images.yml
vendored
36
.github/workflows/build-docker-images.yml
vendored
@ -63,14 +63,14 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER }}
|
||||
title: 🤗 Results of the transformers-all-latest-gpu-push-ci docker build
|
||||
title: 🤗 Results of the transformers-all-latest-gpu-push-ci docker build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
latest-torch-deepspeed-docker:
|
||||
name: "Latest PyTorch + DeepSpeed"
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
group: aws-g4dn-2xlarge-cache
|
||||
steps:
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
@ -99,7 +99,7 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER}}
|
||||
title: 🤗 Results of the transformers-pytorch-deepspeed-latest-gpu docker build
|
||||
title: 🤗 Results of the transformers-pytorch-deepspeed-latest-gpu docker build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
@ -140,7 +140,7 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER }}
|
||||
title: 🤗 Results of the transformers-pytorch-deepspeed-latest-gpu-push-ci docker build
|
||||
title: 🤗 Results of the transformers-pytorch-deepspeed-latest-gpu-push-ci docker build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
@ -176,7 +176,7 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER }}
|
||||
title: 🤗 Results of the huggingface/transformers-doc-builder docker build
|
||||
title: 🤗 Results of the huggingface/transformers-doc-builder docker build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
@ -214,7 +214,7 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER }}
|
||||
title: 🤗 Results of the huggingface/transformers-pytorch-gpudocker build
|
||||
title: 🤗 Results of the huggingface/transformers-pytorch-gpudocker build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
@ -223,19 +223,19 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
-
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
-
|
||||
-
|
||||
name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
-
|
||||
-
|
||||
name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
-
|
||||
-
|
||||
name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
@ -263,7 +263,7 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER }}
|
||||
title: 🤗 Results of the huggingface/transformers-pytorch-amd-gpu-push-ci build
|
||||
title: 🤗 Results of the huggingface/transformers-pytorch-amd-gpu-push-ci build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
@ -301,7 +301,7 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER }}
|
||||
title: 🤗 Results of the huggingface/transformers-tensorflow-gpu build
|
||||
title: 🤗 Results of the huggingface/transformers-tensorflow-gpu build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
@ -310,19 +310,19 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
-
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
-
|
||||
-
|
||||
name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
-
|
||||
-
|
||||
name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
-
|
||||
-
|
||||
name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
@ -350,7 +350,7 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER }}
|
||||
title: 🤗 Results of the transformers-pytorch-deepspeed-amd-gpu build
|
||||
title: 🤗 Results of the transformers-pytorch-deepspeed-amd-gpu build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
@ -388,6 +388,6 @@ jobs:
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ secrets.CI_SLACK_CHANNEL_DOCKER }}
|
||||
title: 🤗 Results of the transformers-quantization-latest-gpu build
|
||||
title: 🤗 Results of the transformers-quantization-latest-gpu build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
@ -42,7 +42,7 @@ jobs:
|
||||
nightly-torch-deepspeed-docker:
|
||||
name: "Nightly PyTorch + DeepSpeed"
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
group: aws-g4dn-2xlarge-cache
|
||||
steps:
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
|
||||
20
.github/workflows/model_jobs.yml
vendored
20
.github/workflows/model_jobs.yml
vendored
@ -18,6 +18,10 @@ on:
|
||||
docker:
|
||||
required: true
|
||||
type: string
|
||||
report_name_prefix:
|
||||
required: false
|
||||
default: run_models_gpu
|
||||
type: string
|
||||
|
||||
env:
|
||||
HF_HOME: /mnt/cache
|
||||
@ -116,23 +120,23 @@ jobs:
|
||||
|
||||
- name: Run all tests on GPU
|
||||
working-directory: /transformers
|
||||
run: python3 -m pytest -rsfE -v --make-reports=${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports tests/${{ matrix.folders }}
|
||||
run: python3 -m pytest -rsfE -v --make-reports=${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ matrix.folders }}_test_reports tests/${{ matrix.folders }}
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
continue-on-error: true
|
||||
run: cat /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports/failures_short.txt
|
||||
run: cat /transformers/reports/${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ matrix.folders }}_test_reports/failures_short.txt
|
||||
|
||||
- name: Run test
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports
|
||||
echo "hello" > /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports/hello.txt
|
||||
echo "${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports"
|
||||
mkdir -p /transformers/reports/${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ matrix.folders }}_test_reports
|
||||
echo "hello" > /transformers/reports/${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ matrix.folders }}_test_reports/hello.txt
|
||||
echo "${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ matrix.folders }}_test_reports"
|
||||
|
||||
- name: "Test suite reports artifacts: ${{ env.machine_type }}_run_models_gpu_${{ env.matrix_folders }}_test_reports"
|
||||
- name: "Test suite reports artifacts: ${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ env.matrix_folders }}_test_reports"
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ env.machine_type }}_run_models_gpu_${{ env.matrix_folders }}_test_reports
|
||||
path: /transformers/reports/${{ env.machine_type }}_run_models_gpu_${{ matrix.folders }}_test_reports
|
||||
name: ${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ env.matrix_folders }}_test_reports
|
||||
path: /transformers/reports/${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ matrix.folders }}_test_reports
|
||||
|
||||
2
.github/workflows/self-comment-ci.yml
vendored
2
.github/workflows/self-comment-ci.yml
vendored
@ -29,7 +29,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
name: Get PR number
|
||||
# For security: only allow team members to run
|
||||
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
|
||||
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
|
||||
outputs:
|
||||
PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
|
||||
steps:
|
||||
|
||||
13
.github/workflows/self-scheduled-caller.yml
vendored
13
.github/workflows/self-scheduled-caller.yml
vendored
@ -54,12 +54,23 @@ jobs:
|
||||
ci_event: Daily CI
|
||||
secrets: inherit
|
||||
|
||||
trainer-fsdp-ci:
|
||||
name: Trainer/FSDP CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_trainer_and_fsdp_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-training"
|
||||
runner: daily-ci
|
||||
docker: huggingface/transformers-all-latest-gpu
|
||||
ci_event: Daily CI
|
||||
secrets: inherit
|
||||
|
||||
deepspeed-ci:
|
||||
name: DeepSpeed CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_torch_cuda_extensions_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-deepspeed"
|
||||
slack_report_channel: "#transformers-ci-daily-training"
|
||||
runner: daily-ci
|
||||
docker: huggingface/transformers-pytorch-deepspeed-latest-gpu
|
||||
ci_event: Daily CI
|
||||
|
||||
35
.github/workflows/self-scheduled.yml
vendored
35
.github/workflows/self-scheduled.yml
vendored
@ -45,7 +45,7 @@ env:
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
if: contains(fromJSON('["run_models_gpu", "run_quantization_torch_gpu"]'), inputs.job)
|
||||
if: contains(fromJSON('["run_models_gpu", "run_trainer_and_fsdp_gpu", "run_quantization_torch_gpu"]'), inputs.job)
|
||||
name: Setup
|
||||
strategy:
|
||||
matrix:
|
||||
@ -77,12 +77,17 @@ jobs:
|
||||
run: pip freeze
|
||||
|
||||
- id: set-matrix
|
||||
if: ${{ inputs.job == 'run_models_gpu' }}
|
||||
if: contains(fromJSON('["run_models_gpu", "run_trainer_and_fsdp_gpu"]'), inputs.job)
|
||||
name: Identify models to test
|
||||
working-directory: /transformers/tests
|
||||
run: |
|
||||
echo "folder_slices=$(python3 ../utils/split_model_tests.py --num_splits ${{ env.NUM_SLICES }})" >> $GITHUB_OUTPUT
|
||||
echo "slice_ids=$(python3 -c 'd = list(range(${{ env.NUM_SLICES }})); print(d)')" >> $GITHUB_OUTPUT
|
||||
if [ "${{ inputs.job }}" = "run_models_gpu" ]; then
|
||||
echo "folder_slices=$(python3 ../utils/split_model_tests.py --num_splits ${{ env.NUM_SLICES }})" >> $GITHUB_OUTPUT
|
||||
echo "slice_ids=$(python3 -c 'd = list(range(${{ env.NUM_SLICES }})); print(d)')" >> $GITHUB_OUTPUT
|
||||
elif [ "${{ inputs.job }}" = "run_trainer_and_fsdp_gpu" ]; then
|
||||
echo "folder_slices=[['trainer'], ['fsdp']]" >> $GITHUB_OUTPUT
|
||||
echo "slice_ids=[0, 1]" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- id: set-matrix-quantization
|
||||
if: ${{ inputs.job == 'run_quantization_torch_gpu' }}
|
||||
@ -113,6 +118,25 @@ jobs:
|
||||
docker: ${{ inputs.docker }}
|
||||
secrets: inherit
|
||||
|
||||
run_trainer_and_fsdp_gpu:
|
||||
if: ${{ inputs.job == 'run_trainer_and_fsdp_gpu' }}
|
||||
name: " "
|
||||
needs: setup
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
slice_id: [0, 1]
|
||||
uses: ./.github/workflows/model_jobs.yml
|
||||
with:
|
||||
folder_slices: ${{ needs.setup.outputs.folder_slices }}
|
||||
machine_type: ${{ matrix.machine_type }}
|
||||
slice_id: ${{ matrix.slice_id }}
|
||||
runner: ${{ inputs.runner }}
|
||||
docker: ${{ inputs.docker }}
|
||||
report_name_prefix: run_trainer_and_fsdp_gpu
|
||||
secrets: inherit
|
||||
|
||||
run_pipelines_torch_gpu:
|
||||
if: ${{ inputs.job == 'run_pipelines_torch_gpu' }}
|
||||
name: PyTorch pipelines
|
||||
@ -382,7 +406,7 @@ jobs:
|
||||
run: pip freeze
|
||||
|
||||
- name: Set `machine_type` for report and artifact names
|
||||
working-directory: /transformers
|
||||
working-directory: ${{ inputs.working-directory-prefix }}/transformers
|
||||
shell: bash
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
@ -541,6 +565,7 @@ jobs:
|
||||
needs: [
|
||||
setup,
|
||||
run_models_gpu,
|
||||
run_trainer_and_fsdp_gpu,
|
||||
run_pipelines_torch_gpu,
|
||||
run_pipelines_tf_gpu,
|
||||
run_examples_gpu,
|
||||
|
||||
@ -70,7 +70,7 @@ Explore the [Hub](https://huggingface.com/) today to find a model and use Transf
|
||||
|
||||
## Installation
|
||||
|
||||
Transformers works with Python 3.9+ [PyTorch](https://pytorch.org/get-started/locally/) 2.0+, [TensorFlow](https://www.tensorflow.org/install/pip) 2.6+, and [Flax](https://flax.readthedocs.io/en/latest/) 0.4.1+.
|
||||
Transformers works with Python 3.9+ [PyTorch](https://pytorch.org/get-started/locally/) 2.1+, [TensorFlow](https://www.tensorflow.org/install/pip) 2.6+, and [Flax](https://flax.readthedocs.io/en/latest/) 0.4.1+.
|
||||
|
||||
Create and activate a virtual environment with [venv](https://docs.python.org/3/library/venv.html) or [uv](https://docs.astral.sh/uv/), a fast Rust-based Python package and project manager.
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
## Writing metrics to the database
|
||||
|
||||
`MetricRecorder` is thread-safe, in the sense of the python [`Thread`](https://docs.python.org/3/library/threading.html#threading.Thread). This means you can start a background thread to do the readings on the device measurements while not blocking the main thread to execute the model measurements.
|
||||
`MetricsRecorder` is thread-safe, in the sense of the python [`Thread`](https://docs.python.org/3/library/threading.html#threading.Thread). This means you can start a background thread to do the readings on the device measurements while not blocking the main thread to execute the model measurements.
|
||||
|
||||
cf [`llama.py`](./llama.py) to see an example of this in practice.
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ import importlib.util
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict
|
||||
import psycopg2
|
||||
import sys
|
||||
|
||||
from psycopg2.extras import Json
|
||||
|
||||
@ -118,7 +118,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
with torch.no_grad():
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate,
|
||||
@ -144,7 +144,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate,
|
||||
@ -187,7 +187,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
# TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate + 10,
|
||||
@ -204,7 +204,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
time_to_first_token = end - start
|
||||
logger.info(f"completed first compile generation in: {time_to_first_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
all_generated_tokens += next_token.tolist()
|
||||
|
||||
cache_position = torch.tensor([seq_length], device=device)
|
||||
### First compile, decoding
|
||||
@ -215,9 +215,9 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
time_to_second_token = end - start
|
||||
logger.info(f"completed second compile generation in: {time_to_first_token}s")
|
||||
logger.info(f"completed second compile generation in: {time_to_second_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
all_generated_tokens += next_token.tolist()
|
||||
|
||||
### Second compile, decoding
|
||||
start = perf_counter()
|
||||
@ -227,15 +227,15 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
time_to_third_token = end - start
|
||||
logger.info(f"completed third compile forward in: {time_to_first_token}s")
|
||||
logger.info(f"completed third compile forward in: {time_to_third_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
all_generated_tokens += next_token.tolist()
|
||||
|
||||
### Using cuda graphs decoding
|
||||
|
||||
start = perf_counter()
|
||||
for _ in range(1, num_tokens_to_generate):
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
all_generated_tokens += next_token.tolist()
|
||||
next_token = decode_one_token(
|
||||
model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
|
||||
)
|
||||
@ -254,7 +254,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
@ -271,7 +271,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
@ -287,7 +287,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
@ -298,12 +298,12 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
output = model.generate(**inputs, past_key_values=past_key_values)
|
||||
end = perf_counter()
|
||||
third_compile_generate_time = end - start
|
||||
logger.info(f"completed second compile generation in: {third_compile_generate_time}s")
|
||||
logger.info(f"completed third compile generation in: {third_compile_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
max_batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
@ -313,7 +313,7 @@ def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str,
|
||||
output = model.generate(**inputs, past_key_values=past_key_values)
|
||||
end = perf_counter()
|
||||
fourth_compile_generate_time = end - start
|
||||
logger.info(f"completed second compile generation in: {fourth_compile_generate_time}s")
|
||||
logger.info(f"completed fourth compile generation in: {fourth_compile_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
metrics_recorder.collect_model_measurements(
|
||||
|
||||
@ -46,10 +46,6 @@ NOT_DEVICE_TESTS = {
|
||||
"test_keep_in_fp32_modules",
|
||||
"test_gradient_checkpointing_backward_compatibility",
|
||||
"test_gradient_checkpointing_enable_disable",
|
||||
"test_save_load_fast_init_from_base",
|
||||
"test_fast_init_context_manager",
|
||||
"test_fast_init_tied_embeddings",
|
||||
"test_save_load_fast_init_to_base",
|
||||
"test_torch_save_load",
|
||||
"test_initialization",
|
||||
"test_forward_signature",
|
||||
|
||||
@ -7,5 +7,5 @@ ENV UV_PYTHON=/usr/local/bin/python
|
||||
RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
|
||||
RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken,num2words]"
|
||||
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken,num2words,video]"
|
||||
RUN uv pip uninstall transformers
|
||||
|
||||
@ -14,6 +14,8 @@ ARG PYTORCH='2.6.0'
|
||||
ARG INTEL_TORCH_EXT='2.3.0'
|
||||
# Example: `cu102`, `cu113`, etc.
|
||||
ARG CUDA='cu121'
|
||||
# Disable kernel mapping for now until all tests pass
|
||||
ENV DISABLE_KERNEL_MAPPING=1
|
||||
|
||||
RUN apt update
|
||||
RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs
|
||||
@ -57,7 +59,8 @@ RUN python3 -m pip uninstall -y ninja
|
||||
|
||||
# For `dinat` model
|
||||
# The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent)
|
||||
RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels
|
||||
# pin `0.17.4` otherwise `cannot import name 'natten2dav' from 'natten.functional'`
|
||||
RUN python3 -m pip install --no-cache-dir natten==0.17.4+torch250cu121 -f https://shi-labs.com/natten/wheels
|
||||
|
||||
# For `nougat` tokenizer
|
||||
RUN python3 -m pip install --no-cache-dir python-Levenshtein
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-11.html#rel-23-11
|
||||
FROM nvcr.io/nvidia/pytorch:23.11-py3
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
|
||||
FROM nvcr.io/nvidia/pytorch:24.08-py3
|
||||
LABEL maintainer="Hugging Face"
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG PYTORCH='2.2.0'
|
||||
ARG PYTORCH='2.6.0'
|
||||
# Example: `cu102`, `cu113`, etc.
|
||||
ARG CUDA='cu121'
|
||||
ARG CUDA='cu126'
|
||||
|
||||
RUN apt -y update
|
||||
RUN apt install -y libaio-dev
|
||||
@ -15,7 +15,8 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
ARG REF=main
|
||||
RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir ./transformers[deepspeed-testing]
|
||||
# `datasets` requires pandas, pandas has some modules compiled with numpy=1.x causing errors
|
||||
RUN python3 -m pip install --no-cache-dir './transformers[deepspeed-testing]' 'pandas<2' 'numpy<2'
|
||||
|
||||
# Install latest release PyTorch
|
||||
# (PyTorch must be installed before pre-compiling any DeepSpeed c++/cuda ops.)
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-11.html#rel-23-11
|
||||
FROM nvcr.io/nvidia/pytorch:23.11-py3
|
||||
FROM nvcr.io/nvidia/pytorch:24.08-py3
|
||||
LABEL maintainer="Hugging Face"
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Example: `cu102`, `cu113`, etc.
|
||||
ARG CUDA='cu121'
|
||||
ARG CUDA='cu126'
|
||||
|
||||
RUN apt -y update
|
||||
RUN apt install -y libaio-dev
|
||||
@ -21,7 +21,8 @@ RUN python3 -m pip uninstall -y torch torchvision torchaudio
|
||||
# (https://www.deepspeed.ai/tutorials/advanced-install/#pre-install-deepspeed-ops)
|
||||
RUN python3 -m pip install --no-cache-dir -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/$CUDA
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir ./transformers[deepspeed-testing]
|
||||
# `datasets` requires pandas, pandas has some modules compiled with numpy=1.x causing errors
|
||||
RUN python3 -m pip install --no-cache-dir './transformers[deepspeed-testing]' 'pandas<2' 'numpy<2'
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate
|
||||
|
||||
|
||||
@ -12,6 +12,8 @@ SHELL ["sh", "-lc"]
|
||||
ARG PYTORCH='2.6.0'
|
||||
# Example: `cu102`, `cu113`, etc.
|
||||
ARG CUDA='cu121'
|
||||
# Disable kernel mapping for quantization tests
|
||||
ENV DISABLE_KERNEL_MAPPING=1
|
||||
|
||||
RUN apt update
|
||||
RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg
|
||||
|
||||
@ -674,29 +674,7 @@ use_cpu: false
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Tensor Parallelism with PyTorch 2">
|
||||
|
||||
```yml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
tp_config:
|
||||
tp_size: 4
|
||||
distributed_type: TP
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
يُعد أمر [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) هو الطريقة المُوصى بها لتشغيل نص البرمجى للتدريب على نظام موزع باستخدام Accelerate و [`Trainer`] مع المعلمات المحددة في `config_file.yaml`. يتم حفظ هذا الملف في مجلد ذاكرة التخزين المؤقت لـ Accelerate ويتم تحميله تلقائيًا عند تشغيل `accelerate_launch`.
|
||||
|
||||
|
||||
@ -161,6 +161,8 @@
|
||||
sections:
|
||||
- local: quantization/overview
|
||||
title: Overview
|
||||
- local: quantization/selecting
|
||||
title: Selecting a quantization method
|
||||
- local: quantization/aqlm
|
||||
title: AQLM
|
||||
- local: quantization/awq
|
||||
@ -415,6 +417,8 @@
|
||||
title: DeBERTa
|
||||
- local: model_doc/deberta-v2
|
||||
title: DeBERTa-v2
|
||||
- local: model_doc/deepseek_v3
|
||||
title: DeepSeek-V3
|
||||
- local: model_doc/dialogpt
|
||||
title: DialoGPT
|
||||
- local: model_doc/diffllama
|
||||
@ -459,6 +463,8 @@
|
||||
title: Gemma2
|
||||
- local: model_doc/glm
|
||||
title: GLM
|
||||
- local: model_doc/glm4
|
||||
title: glm4
|
||||
- local: model_doc/openai-gpt
|
||||
title: GPT
|
||||
- local: model_doc/gpt_neo
|
||||
@ -505,6 +511,8 @@
|
||||
title: Llama2
|
||||
- local: model_doc/llama3
|
||||
title: Llama3
|
||||
- local: model_doc/llama4
|
||||
title: Llama4
|
||||
- local: model_doc/longformer
|
||||
title: Longformer
|
||||
- local: model_doc/longt5
|
||||
@ -601,6 +609,10 @@
|
||||
title: Qwen2
|
||||
- local: model_doc/qwen2_moe
|
||||
title: Qwen2MoE
|
||||
- local: model_doc/qwen3
|
||||
title: Qwen3
|
||||
- local: model_doc/qwen3_moe
|
||||
title: Qwen3MoE
|
||||
- local: model_doc/rag
|
||||
title: RAG
|
||||
- local: model_doc/realm
|
||||
@ -1066,6 +1078,8 @@
|
||||
title: Utilities for Audio processing
|
||||
- local: internal/file_utils
|
||||
title: General Utilities
|
||||
- local: internal/import_utils
|
||||
title: Importing Utilities
|
||||
- local: internal/time_series_utils
|
||||
title: Utilities for Time Series
|
||||
title: Internal helpers
|
||||
|
||||
@ -23,13 +23,13 @@ supported models.
|
||||
Most recent models can now switch from one attention function used in the Attention layer to the other, thanks to a simple mapping.
|
||||
By default, we provide the implementation for [`sdpa`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
|
||||
[`flash_attention_2`](https://github.com/Dao-AILab/flash-attention) and [`flex_attention`](https://pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention)
|
||||
as well as `eager`, which is simple matrix multiplication without any optimization on top.
|
||||
as well as `eager`, which is a simple matrix multiplication without any optimization on top.
|
||||
This is the setting you can usually choose when instantiating a model:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B
|
||||
model_id = "meta-llama/Llama-3.2-1B"
|
||||
|
||||
# Here, using flash attention as an example
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
|
||||
@ -43,7 +43,7 @@ from transformers import AutoModelForCausalLM, AttentionInterface
|
||||
from transformers.integrations.sdpa_attention import sdpa_attention_forward
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B
|
||||
model_id = "meta-llama/Llama-3.2-1B"
|
||||
|
||||
def my_new_sdpa(*args, **kwargs):
|
||||
print("I just entered the attention computation")
|
||||
@ -56,7 +56,7 @@ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_n
|
||||
model(torch.ones(1, 5, dtype=int))
|
||||
```
|
||||
|
||||
You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times.
|
||||
You will see it prints "I just entered the attention computation" as many times as there are layers in the model (with this example, 16 times).
|
||||
|
||||
## Dynamically switching attention function
|
||||
|
||||
@ -70,12 +70,12 @@ model(torch.ones(1, 5, dtype=int))
|
||||
```
|
||||
|
||||
and it will stop printing the statements, as it now uses the `sdpa` attention.
|
||||
This allows to quickly change attention function, without needing to reload the model!
|
||||
This allows to quickly change an attention function, without needing to reload the model!
|
||||
|
||||
## What about new args needed in my custom function?
|
||||
## What about new args needed in my custom attention function?
|
||||
|
||||
But indeed, what if the new function requires a new arg to be properly used? It's no issue! Models supporting the
|
||||
`AttentionInterface` propagates kwargs all the way to the Attention layers, and to the attention function used. That way,
|
||||
`AttentionInterface` propagate kwargs all the way to the Attention layers, and to the used attention function. That way,
|
||||
you can simply pass the arg (as a kwargs, i.e. you need to qualify the name of the arg) in the model's forward, and it will be correctly used in the attention. However, custom attention functions have some limitations. In particular, it must follow the signature and return format of other attention functions, i.e.
|
||||
|
||||
```python
|
||||
@ -103,4 +103,26 @@ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="cust
|
||||
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
|
||||
```
|
||||
|
||||
If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
|
||||
If in doubt about what args/kwargs a given model sends to the attention function, simply check that model's modeling code on [GitHub](https://github.com/huggingface/transformers/tree/main/src/transformers/models)!
|
||||
|
||||
## Accessing current available implementations
|
||||
|
||||
Most of the time, you will simply need to `register` a new function. If, however, you need to access an existing one,
|
||||
and/or perform a few checks, the prefered way is to use the global `ALL_ATTENTION_FUNCTIONS`. It behaves the same way you
|
||||
would expect from a usual Python dictionary:
|
||||
|
||||
```python
|
||||
>>> from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
>>> list(ALL_ATTENTION_FUNCTIONS.keys())
|
||||
>>> ['flash_attention_2', 'flex_attention', 'sdpa']
|
||||
|
||||
>>> ALL_ATTENTION_FUNCTIONS["sdpa"]
|
||||
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
|
||||
|
||||
>>> ALL_ATTENTION_FUNCTIONS.get("sdpa", None)
|
||||
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
|
||||
|
||||
# You can also globally `register` a new function directly on it
|
||||
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
|
||||
```
|
||||
@ -181,35 +181,6 @@ processed_chat = processor.apply_chat_template(
|
||||
print(processed_chat.keys())
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="custom frame sampling">
|
||||
|
||||
Some models don't sample frames *uniformly* and require more complex logic to determine which frames to use. For example, the model may have an *adaptive frame selection* or if the model prioritizes *key moments* in a video rather than evenly spaced frames.
|
||||
|
||||
If a model has a different sampling strategy, you can write a function that customizes frame selection. The function should include the following requirements.
|
||||
|
||||
- Use the `sample_indices_fn` parameter to pass a callable function for sampling.
|
||||
- If provided, this function *overrides* the standard `num_frames` and `fps` parameters.
|
||||
- The function receives all the parameters passed to `load_video` and must return valid frame indices to sample from.
|
||||
|
||||
An example function is shown below. This gives you full control over frame selection, making the model more adaptable to different video scenarios.
|
||||
|
||||
```py
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
# samples only the first and the second frame
|
||||
return [0, 1]
|
||||
|
||||
processed_chat = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
sample_indices_fn=sample_indices_fn,
|
||||
video_load_backend="decord",
|
||||
)
|
||||
print(processed_chat.keys())
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="list of image frames">
|
||||
|
||||
|
||||
@ -43,4 +43,3 @@ Transformers is designed for developers and machine learning engineers and resea
|
||||
</a>
|
||||
</div>
|
||||
|
||||
Join us on the Hugging Face [Hub](https://huggingface.co/), [Discord](https://discord.com/invite/JfAtkvEtRb), or [forum](https://discuss.huggingface.co/) to collaborate and build models, datasets, and applications together.
|
||||
|
||||
@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Installation
|
||||
|
||||
Transformers works with [PyTorch](https://pytorch.org/get-started/locally/), [TensorFlow 2.0](https://www.tensorflow.org/install/pip), and [Flax](https://flax.readthedocs.io/en/latest/). It has been tested on Python 3.9+, PyTorch 2.0+, TensorFlow 2.6+, and Flax 0.4.1+.
|
||||
Transformers works with [PyTorch](https://pytorch.org/get-started/locally/), [TensorFlow 2.0](https://www.tensorflow.org/install/pip), and [Flax](https://flax.readthedocs.io/en/latest/). It has been tested on Python 3.9+, PyTorch 2.1+, TensorFlow 2.6+, and Flax 0.4.1+.
|
||||
|
||||
## Virtual environment
|
||||
|
||||
|
||||
91
docs/source/en/internal/import_utils.md
Normal file
91
docs/source/en/internal/import_utils.md
Normal file
@ -0,0 +1,91 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Import Utilities
|
||||
|
||||
This page goes through the transformers utilities to enable lazy and fast object import.
|
||||
While we strive for minimal dependencies, some models have specific dependencies requirements that cannot be
|
||||
worked around. We don't want for all users of `transformers` to have to install those dependencies to use other models,
|
||||
we therefore mark those as soft dependencies rather than hard dependencies.
|
||||
|
||||
The transformers toolkit is not made to error-out on import of a model that has a specific dependency; instead, an
|
||||
object for which you are lacking a dependency will error-out when calling any method on it. As an example, if
|
||||
`torchvision` isn't installed, the fast image processors will not be available.
|
||||
|
||||
This object is still importable:
|
||||
|
||||
```python
|
||||
>>> from transformers import DetrImageProcessorFast
|
||||
>>> print(DetrImageProcessorFast)
|
||||
<class 'DetrImageProcessorFast'>
|
||||
```
|
||||
|
||||
However, no method can be called on that object:
|
||||
|
||||
```python
|
||||
>>> DetrImageProcessorFast.from_pretrained()
|
||||
ImportError:
|
||||
DetrImageProcessorFast requires the Torchvision library but it was not found in your environment. Checkout the instructions on the
|
||||
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
|
||||
Please note that you may need to restart your runtime after installation.
|
||||
```
|
||||
|
||||
Let's see how to specify specific object dependencies.
|
||||
|
||||
## Specifying Object Dependencies
|
||||
|
||||
### Filename-based
|
||||
|
||||
All objects under a given filename have an automatic dependency to the tool linked to the filename
|
||||
|
||||
**TensorFlow**: All files starting with `modeling_tf_` have an automatic TensorFlow dependency.
|
||||
|
||||
**Flax**: All files starting with `modeling_flax_` have an automatic Flax dependency
|
||||
|
||||
**PyTorch**: All files starting with `modeling_` and not valid with the above (TensorFlow and Flax) have an automatic
|
||||
PyTorch dependency
|
||||
|
||||
**Tokenizers**: All files starting with `tokenization_` and ending with `_fast` have an automatic `tokenizers` dependency
|
||||
|
||||
**Vision**: All files starting with `image_processing_` have an automatic dependency to the `vision` dependency group;
|
||||
at the time of writing, this only contains the `pillow` dependency.
|
||||
|
||||
**Vision + Torch + Torchvision**: All files starting with `image_processing_` and ending with `_fast` have an automatic
|
||||
dependency to `vision`, `torch`, and `torchvision`.
|
||||
|
||||
All of these automatic dependencies are added on top of the explicit dependencies that are detailed below.
|
||||
|
||||
### Explicit Object Dependencies
|
||||
|
||||
We add a method called `requires` that is used to explicitly specify the dependencies of a given object. As an
|
||||
example, the `Trainer` class has two hard dependencies: `torch` and `accelerate`. Here is how we specify these
|
||||
required dependencies:
|
||||
|
||||
```python
|
||||
from .utils.import_utils import requires
|
||||
|
||||
@requires(backends=("torch", "accelerate"))
|
||||
class Trainer:
|
||||
...
|
||||
```
|
||||
|
||||
Backends that can be added here are all the backends that are available in the `import_utils.py` module.
|
||||
|
||||
## Methods
|
||||
|
||||
[[autodoc]] utils.import_utils.define_import_structure
|
||||
|
||||
[[autodoc]] utils.import_utils.requires
|
||||
@ -25,6 +25,10 @@ Most of those are only useful if you are studying the code of the models in the
|
||||
[[autodoc]] AttentionInterface
|
||||
- register
|
||||
|
||||
## Rotary Position Embedding Functions
|
||||
|
||||
[[autodoc]] dynamic_rope_update
|
||||
|
||||
## Pytorch custom modules
|
||||
|
||||
[[autodoc]] pytorch_utils.Conv1D
|
||||
|
||||
@ -93,7 +93,7 @@ model.generation_config.max_new_tokens = 16
|
||||
|
||||
past_key_values = StaticCache(
|
||||
config=model.config,
|
||||
batch_size=1,
|
||||
max_batch_size=1,
|
||||
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
|
||||
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
|
||||
device=model.device,
|
||||
@ -159,7 +159,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
with torch.no_grad():
|
||||
past_key_values = StaticCache(
|
||||
config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
|
||||
)
|
||||
cache_position = torch.arange(seq_length, device=torch_device)
|
||||
generated_ids = torch.zeros(
|
||||
|
||||
@ -88,6 +88,11 @@ The original code can be found [here](https://github.com/salesforce/BLIP).
|
||||
[[autodoc]] BlipTextModel
|
||||
- forward
|
||||
|
||||
## BlipTextLMHeadModel
|
||||
|
||||
[[autodoc]] BlipTextLMHeadModel
|
||||
- forward
|
||||
|
||||
## BlipVisionModel
|
||||
|
||||
[[autodoc]] BlipVisionModel
|
||||
@ -123,6 +128,11 @@ The original code can be found [here](https://github.com/salesforce/BLIP).
|
||||
[[autodoc]] TFBlipTextModel
|
||||
- call
|
||||
|
||||
## TFBlipTextLMHeadModel
|
||||
|
||||
[[autodoc]] TFBlipTextLMHeadModel
|
||||
- forward
|
||||
|
||||
## TFBlipVisionModel
|
||||
|
||||
[[autodoc]] TFBlipVisionModel
|
||||
|
||||
@ -14,221 +14,77 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# CLIP
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# CLIP
|
||||
|
||||
The CLIP model was proposed in [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh,
|
||||
Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. CLIP
|
||||
(Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be
|
||||
instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing
|
||||
for the task, similarly to the zero-shot capabilities of GPT-2 and 3.
|
||||
[CLIP](https://huggingface.co/papers/2103.00020) is a is a multimodal vision and language model motivated by overcoming the fixed number of object categories when training a computer vision model. CLIP learns about images directly from raw text by jointly training on 400M (image, text) pairs. Pretraining on this scale enables zero-shot transfer to downstream tasks. CLIP uses an image encoder and text encoder to get visual features and text features. Both features are projected to a latent space with the same number of dimensions and their dot product gives a similarity score.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original CLIP checkpoints under the [OpenAI](https://huggingface.co/openai?search_models=clip) organization.
|
||||
|
||||
*State-of-the-art computer vision systems are trained to predict a fixed set of predetermined object categories. This
|
||||
restricted form of supervision limits their generality and usability since additional labeled data is needed to specify
|
||||
any other visual concept. Learning directly from raw text about images is a promising alternative which leverages a
|
||||
much broader source of supervision. We demonstrate that the simple pre-training task of predicting which caption goes
|
||||
with which image is an efficient and scalable way to learn SOTA image representations from scratch on a dataset of 400
|
||||
million (image, text) pairs collected from the internet. After pre-training, natural language is used to reference
|
||||
learned visual concepts (or describe new ones) enabling zero-shot transfer of the model to downstream tasks. We study
|
||||
the performance of this approach by benchmarking on over 30 different existing computer vision datasets, spanning tasks
|
||||
such as OCR, action recognition in videos, geo-localization, and many types of fine-grained object classification. The
|
||||
model transfers non-trivially to most tasks and is often competitive with a fully supervised baseline without the need
|
||||
for any dataset specific training. For instance, we match the accuracy of the original ResNet-50 on ImageNet zero-shot
|
||||
without needing to use any of the 1.28 million training examples it was trained on. We release our code and pre-trained
|
||||
model weights at this https URL.*
|
||||
> [!TIP]
|
||||
> Click on the CLIP models in the right sidebar for more examples of how to apply CLIP to different image and language tasks.
|
||||
|
||||
This model was contributed by [valhalla](https://huggingface.co/valhalla). The original code can be found [here](https://github.com/openai/CLIP).
|
||||
The example below demonstrates how to calculate similarity scores between multiple text descriptions and an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
## Usage tips and example
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
CLIP is a multi-modal vision and language model. It can be used for image-text similarity and for zero-shot image
|
||||
classification. CLIP uses a ViT like transformer to get visual features and a causal language model to get the text
|
||||
features. Both the text and visual features are then projected to a latent space with identical dimension. The dot
|
||||
product between the projected image and text features is then used as a similar score.
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
To feed images to the Transformer encoder, each image is split into a sequence of fixed-size non-overlapping patches,
|
||||
which are then linearly embedded. A [CLS] token is added to serve as representation of an entire image. The authors
|
||||
also add absolute position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder.
|
||||
The [`CLIPImageProcessor`] can be used to resize (or rescale) and normalize images for the model.
|
||||
|
||||
The [`CLIPTokenizer`] is used to encode the text. The [`CLIPProcessor`] wraps
|
||||
[`CLIPImageProcessor`] and [`CLIPTokenizer`] into a single instance to both
|
||||
encode the text and prepare the images. The following example shows how to get the image-text similarity scores using
|
||||
[`CLIPProcessor`] and [`CLIPModel`].
|
||||
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
clip = pipeline(
|
||||
task="zero-shot-image-classification",
|
||||
model="openai/clip-vit-base-patch32",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=0
|
||||
)
|
||||
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]
|
||||
clip("http://images.cocodataset.org/val2017/000000039769.jpg", candidate_labels=labels)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
### Combining CLIP and Flash Attention 2
|
||||
```py
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, AutoModel
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2.
|
||||
model = AutoModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16, attn_implementation="sdpa")
|
||||
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]
|
||||
|
||||
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
|
||||
|
||||
outputs = model(**inputs)
|
||||
logits_per_image = outputs.logits_per_image
|
||||
probs = logits_per_image.softmax(dim=1)
|
||||
most_likely_idx = probs.argmax(dim=1).item()
|
||||
most_likely_label = labels[most_likely_idx]
|
||||
print(f"Most likely label: {most_likely_label} with probability: {probs[0][most_likely_idx].item():.3f}")
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip warning={true}>
|
||||
## Notes
|
||||
|
||||
For small batch sizes, you might notice a slowdown in your model when using flash attention. Refer to the section [Expected speedups with Flash Attention and SDPA](#Expected-speedups-with-Flash-Attention-and-SDPA) below and select an appropriate attention implementation.
|
||||
|
||||
</Tip>
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
>>> device = "cuda"
|
||||
>>> torch_dtype = torch.float16
|
||||
|
||||
>>> model = CLIPModel.from_pretrained(
|
||||
... "openai/clip-vit-base-patch32",
|
||||
... attn_implementation="flash_attention_2",
|
||||
... device_map=device,
|
||||
... torch_dtype=torch_dtype,
|
||||
... )
|
||||
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
|
||||
>>> inputs.to(device)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... with torch.autocast(device):
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
>>> print(probs)
|
||||
tensor([[0.9946, 0.0052]], device='cuda:0', dtype=torch.float16)
|
||||
```
|
||||
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```python
|
||||
from transformers import CLIPModel
|
||||
|
||||
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
### Expected speedups with Flash Attention and SDPA
|
||||
|
||||
On a local benchmark (NVIDIA A10G, PyTorch 2.3.1+cu121) with `float16`, we saw the following speedups during inference for `"openai/clip-vit-large-patch14"` checkpoint ([code](https://gist.github.com/qubvel/ac691a54e54f9fae8144275f866a7ff8)):
|
||||
|
||||
#### CLIPTextModel
|
||||
|
||||
| Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 4 | 0.009 | 0.012 | 0.737 | 0.007 | 1.269 |
|
||||
| 16 | 0.009 | 0.014 | 0.659 | 0.008 | 1.187 |
|
||||
| 32 | 0.018 | 0.021 | 0.862 | 0.016 | 1.142 |
|
||||
| 64 | 0.034 | 0.034 | 1.001 | 0.03 | 1.163 |
|
||||
| 128 | 0.063 | 0.058 | 1.09 | 0.054 | 1.174 |
|
||||
|
||||

|
||||
|
||||
#### CLIPVisionModel
|
||||
|
||||
| Image batch size | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|-------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 1 | 0.016 | 0.013 | 1.247 | 0.012 | 1.318 |
|
||||
| 4 | 0.025 | 0.021 | 1.198 | 0.021 | 1.202 |
|
||||
| 16 | 0.093 | 0.075 | 1.234 | 0.075 | 1.24 |
|
||||
| 32 | 0.181 | 0.147 | 1.237 | 0.146 | 1.241 |
|
||||
|
||||

|
||||
|
||||
#### CLIPModel
|
||||
|
||||
| Image batch size | Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|
||||
|-------------------:|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
|
||||
| 1 | 4 | 0.025 | 0.026 | 0.954 | 0.02 | 1.217 |
|
||||
| 1 | 16 | 0.026 | 0.028 | 0.918 | 0.02 | 1.287 |
|
||||
| 1 | 64 | 0.042 | 0.046 | 0.906 | 0.036 | 1.167 |
|
||||
| 4 | 4 | 0.028 | 0.033 | 0.849 | 0.024 | 1.189 |
|
||||
| 4 | 16 | 0.034 | 0.035 | 0.955 | 0.029 | 1.169 |
|
||||
| 4 | 64 | 0.059 | 0.055 | 1.072 | 0.05 | 1.179 |
|
||||
| 16 | 4 | 0.096 | 0.088 | 1.091 | 0.078 | 1.234 |
|
||||
| 16 | 16 | 0.102 | 0.09 | 1.129 | 0.083 | 1.224 |
|
||||
| 16 | 64 | 0.127 | 0.11 | 1.157 | 0.105 | 1.218 |
|
||||
| 32 | 4 | 0.185 | 0.159 | 1.157 | 0.149 | 1.238 |
|
||||
| 32 | 16 | 0.19 | 0.162 | 1.177 | 0.154 | 1.233 |
|
||||
| 32 | 64 | 0.216 | 0.181 | 1.19 | 0.176 | 1.228 |
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with CLIP.
|
||||
|
||||
- [Fine tuning CLIP with Remote Sensing (Satellite) images and captions](https://huggingface.co/blog/fine-tune-clip-rsicd), a blog post about how to fine-tune CLIP with [RSICD dataset](https://github.com/201528014227051/RSICD_optimal) and comparison of performance changes due to data augmentation.
|
||||
- This [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/contrastive-image-text) shows how to train a CLIP-like vision-text dual encoder model using a pre-trained vision and text encoder using [COCO dataset](https://cocodataset.org/#home).
|
||||
|
||||
<PipelineTag pipeline="image-to-text"/>
|
||||
|
||||
- A [notebook](https://colab.research.google.com/drive/1tuoAC5F4sC7qid56Z0ap-stR3rwdk0ZV?usp=sharing) on how to use a pretrained CLIP for inference with beam search for image captioning. 🌎
|
||||
|
||||
**Image retrieval**
|
||||
|
||||
- A [notebook](https://colab.research.google.com/drive/1bLVwVKpAndpEDHqjzxVPr_9nGrSbuOQd?usp=sharing) on image retrieval using pretrained CLIP and computing MRR(Mean Reciprocal Rank) score. 🌎
|
||||
- A [notebook](https://colab.research.google.com/github/deep-diver/image_search_with_natural_language/blob/main/notebooks/Image_Search_CLIP.ipynb) on image retrieval and showing the similarity score. 🌎
|
||||
- A [notebook](https://colab.research.google.com/drive/1xO-wC_m_GNzgjIBQ4a4znvQkvDoZJvH4?usp=sharing) on how to map images and texts to the same vector space using Multilingual CLIP. 🌎
|
||||
- A [notebook](https://colab.research.google.com/github/vivien000/clip-demo/blob/master/clip.ipynb#scrollTo=uzdFhRGqiWkR) on how to run CLIP on semantic image search using [Unsplash](https://unsplash.com) and [TMDB](https://www.themoviedb.org/) datasets. 🌎
|
||||
|
||||
**Explainability**
|
||||
|
||||
- A [notebook](https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb) on how to visualize similarity between input token and image segment. 🌎
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we will review it.
|
||||
The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
- Use [`CLIPImageProcessor`] to resize (or rescale) and normalizes images for the model.
|
||||
|
||||
## CLIPConfig
|
||||
|
||||
|
||||
@ -14,108 +14,154 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# CodeLlama
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# CodeLlama
|
||||
|
||||
The Code Llama model was proposed in [Code Llama: Open Foundation Models for Code](https://ai.meta.com/research/publications/code-llama-open-foundation-models-for-code/) by Baptiste Rozière, Jonas Gehring, Fabian Gloeckle, Sten Sootla, Itai Gat, Xiaoqing Ellen Tan, Yossi Adi, Jingyu Liu, Tal Remez, Jérémy Rapin, Artyom Kozhevnikov, Ivan Evtimov, Joanna Bitton, Manish Bhatt, Cristian Canton Ferrer, Aaron Grattafiori, Wenhan Xiong, Alexandre Défossez, Jade Copet, Faisal Azhar, Hugo Touvron, Louis Martin, Nicolas Usunier, Thomas Scialom, Gabriel Synnaeve.
|
||||
[Code Llama](https://huggingface.co/papers/2308.12950) is a specialized family of large language models based on [Llama 2](./llama2) for coding tasks. It comes in different flavors - general code, Python-specific, and instruction-following variant - all available in 7B, 13B, 34B, and 70B parameters. Code Llama models can generate, explain, and even fill in missing parts of your code (called "infilling"). It can also handle very long contexts with stable generation up to 100k tokens, even though it was trained on sequences of 16K tokens.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original Code Llama checkpoints under the [Code Llama](https://huggingface.co/collections/meta-llama/code-llama-family-661da32d0a9d678b6f55b933) collection.
|
||||
|
||||
*We release Code Llama, a family of large language models for code based on Llama 2 providing state-of-the-art performance among open models, infilling capabilities, support for large input contexts, and zero-shot instruction following ability for programming tasks. We provide multiple flavors to cover a wide range of applications: foundation models (Code Llama), Python specializations (Code Llama - Python), and instruction-following models (Code Llama - Instruct) with 7B, 13B and 34B parameters each. All models are trained on sequences of 16k tokens and show improvements on inputs with up to 100k tokens. 7B and 13B Code Llama and Code Llama - Instruct variants support infilling based on surrounding content. Code Llama reaches state-of-the-art performance among open models on several code benchmarks, with scores of up to 53% and 55% on HumanEval and MBPP, respectively. Notably, Code Llama - Python 7B outperforms Llama 2 70B on HumanEval and MBPP, and all our models outperform every other publicly available model on MultiPL-E. We release Code Llama under a permissive license that allows for both research and commercial use.*
|
||||
> [!TIP]
|
||||
> Click on the Code Llama models in the right sidebar for more examples of how to apply Code Llama to different coding tasks.
|
||||
|
||||
Check out all Code Llama model checkpoints [here](https://huggingface.co/models?search=code_llama) and the officially released ones in the [Meta Llama org](https://huggingface.co/meta-llama).
|
||||
The example below demonstrates how to generate code with [`Pipeline`], or the [`AutoModel`], and from the command line.
|
||||
|
||||
This model was contributed by [ArthurZucker](https://huggingface.co/ArthurZ). The original code of the authors can be found [here](https://github.com/facebookresearch/llama).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
## Usage tips and examples
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model="meta-llama/CodeLlama-7b-hf",
|
||||
torch_dtype=torch.float16,
|
||||
device_map=0
|
||||
)
|
||||
|
||||
<Tip warning={true}>
|
||||
# basic code generation
|
||||
result = pipe("# Function to calculate the factorial of a number\ndef factorial(n):", max_new_tokens=256)
|
||||
print(result[0]['generated_text'])
|
||||
|
||||
The `Llama2` family models, on which Code Llama is based, were trained using `bfloat16`, but the original inference uses `float16`. Let's look at the different precisions:
|
||||
# infilling
|
||||
infill_result = pipe("def remove_non_ascii(s: str) -> str:\n \"\"\" <FILL_ME>\n return result", max_new_tokens=200)
|
||||
print(infill_result[0]['generated_text'])
|
||||
```
|
||||
|
||||
* `float32`: PyTorch convention on model initialization is to load models in `float32`, no matter with which `dtype` the model weights were stored. `transformers` also follows this convention for consistency with PyTorch. This will be picked by default. If you want the `AutoModel` API to load the checkpoints with the storage weights type, you must specify `torch_dtype="auto"`, e.g. `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`.
|
||||
* `bfloat16`: Code Llama was trained with this precision, so we recommend using it for further training or fine-tuning.
|
||||
* `float16`: We recommend running inference using this precision, as it's usually faster than `bfloat16`, and evaluation metrics show no discernible degradation with respect to `bfloat16`. You can also run inference using `bfloat16`, and we recommend you check inference results with both `float16` and `bfloat16` after fine-tuning.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
As mentioned above, the `dtype` of the storage weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model using. The reason is that the model will first be downloaded (using the `dtype` of the checkpoints online) and then will be casted to the default `dtype` of `torch` (becomes `torch.float32`). If there is a specified `torch_dtype`, it will be used instead.
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
</Tip>
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/CodeLlama-7b-hf",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
# basic code generation
|
||||
prompt = "# Function to calculate the factorial of a number\ndef factorial(n):"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
|
||||
Tips:
|
||||
- The infilling task is supported out of the box. You should be using the `tokenizer.fill_token` where you want your input to be filled.
|
||||
- The model conversion script is the same as for the `Llama2` family:
|
||||
output = model.generate(
|
||||
**input_ids,
|
||||
max_new_tokens=256,
|
||||
cache_implementation="static"
|
||||
)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
Here is a sample usage:
|
||||
# infilling
|
||||
infill_prompt = "def remove_non_ascii(s: str) -> str:\n \"\"\" <FILL_ME>\n return result"
|
||||
input_ids = tokenizer(infill_prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
filled_output = model.generate(**input_ids, max_new_tokens=200)
|
||||
filled_text = tokenizer.decode(filled_output[0], skip_special_tokens=True)
|
||||
print(filled_text)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
echo -e "# Function to calculate the factorial of a number\ndef factorial(n):" | transformers-cli run --task text-generation --model meta-llama/CodeLlama-7b-hf --device 0
|
||||
```
|
||||
|
||||
Note that executing the script requires enough CPU RAM to host the whole model in float16 precision (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
After conversion, the model and tokenizer can be loaded via:
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
```python
|
||||
>>> from transformers import LlamaForCausalLM, CodeLlamaTokenizer
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to 4-bits.
|
||||
|
||||
>>> tokenizer = CodeLlamaTokenizer.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
>>> PROMPT = '''def remove_non_ascii(s: str) -> str:
|
||||
... """ <FILL_ME>
|
||||
... return result
|
||||
... '''
|
||||
>>> input_ids = tokenizer(PROMPT, return_tensors="pt")["input_ids"]
|
||||
>>> generated_ids = model.generate(input_ids, max_new_tokens=128)
|
||||
```py
|
||||
# pip install bitsandbytes
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, CodeLlamaTokenizer, BitsAndBytesConfig
|
||||
|
||||
>>> filling = tokenizer.batch_decode(generated_ids[:, input_ids.shape[1]:], skip_special_tokens = True)[0]
|
||||
>>> print(PROMPT.replace("<FILL_ME>", filling))
|
||||
def remove_non_ascii(s: str) -> str:
|
||||
""" Remove non-ASCII characters from a string.
|
||||
<BLANKLINE>
|
||||
Args:
|
||||
s: The string to remove non-ASCII characters from.
|
||||
<BLANKLINE>
|
||||
Returns:
|
||||
The string with non-ASCII characters removed.
|
||||
"""
|
||||
result = ""
|
||||
for c in s:
|
||||
if ord(c) < 128:
|
||||
result += c
|
||||
return result
|
||||
<BLANKLINE>
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True)
|
||||
tokenizer = CodeLlamaTokenizer.from_pretrained("meta-llama/CodeLlama-34b-hf")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/CodeLlama-34b-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
prompt = "# Write a Python function to check if a string is a palindrome\ndef is_palindrome(s):"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, max_new_tokens=200, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
If you only want the infilled part:
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
>>> import torch
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
>>> generator = pipeline("text-generation",model="meta-llama/CodeLlama-7b-hf",torch_dtype=torch.float16, device_map="auto")
|
||||
>>> generator('def remove_non_ascii(s: str) -> str:\n """ <FILL_ME>\n return result', max_new_tokens = 128)
|
||||
[{'generated_text': 'def remove_non_ascii(s: str) -> str:\n """ <FILL_ME>\n return resultRemove non-ASCII characters from a string. """\n result = ""\n for c in s:\n if ord(c) < 128:\n result += c'}]
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
visualizer = AttentionMaskVisualizer("meta-llama/CodeLlama-7b-hf")
|
||||
visualizer("""def func(a, b):
|
||||
return a + b""")
|
||||
```
|
||||
|
||||
Under the hood, the tokenizer [automatically splits by `<FILL_ME>`](https://huggingface.co/docs/transformers/main/model_doc/code_llama#transformers.CodeLlamaTokenizer.fill_token) to create a formatted input string that follows [the original training pattern](https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L402). This is more robust than preparing the pattern yourself: it avoids pitfalls, such as token glueing, that are very hard to debug. To see how much CPU and GPU memory you need for this model or others, try [this calculator](https://huggingface.co/spaces/hf-accelerate/model-memory-usage) which can help determine that value.
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/codellama-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string.
|
||||
|
||||
<Tip>
|
||||
|
||||
Code Llama has the same architecture as the `Llama2` models, refer to [Llama2's documentation page](llama2) for the API reference.
|
||||
Find Code Llama tokenizer reference below.
|
||||
</Tip>
|
||||
## Notes
|
||||
|
||||
- Infilling is only available in the 7B and 13B base models, and not in the Python, Instruct, 34B, or 70B models.
|
||||
- Use the `<FILL_ME>` token where you want your input to be filled. The tokenizer splits this token to create a formatted input string that follows the [original training pattern](https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L402). This is more robust than preparing the pattern yourself.
|
||||
```py
|
||||
from transformers import LlamaForCausalLM, CodeLlamaTokenizer
|
||||
|
||||
tokenizer = CodeLlamaTokenizer.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/CodeLlama-7b-hf")
|
||||
PROMPT = '''def remove_non_ascii(s: str) -> str:
|
||||
""" <FILL_ME>
|
||||
return result
|
||||
'''
|
||||
input_ids = tokenizer(PROMPT, return_tensors="pt")["input_ids"]
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=128)
|
||||
|
||||
filling = tokenizer.batch_decode(generated_ids[:, input_ids.shape[1]:], skip_special_tokens = True)[0]
|
||||
print(PROMPT.replace("<FILL_ME>", filling))
|
||||
```
|
||||
- Use `bfloat16` for further training or fine-tuning and `float16` for inference.
|
||||
- The `BOS` character is not used for infilling when encoding the prefix or suffix, but only at the beginning of each prompt.
|
||||
- The tokenizer is a byte-pair encoding model based on [SentencePiece](https://github.com/google/sentencepiece). During decoding, if the first token is the start of the word (for example, “Banana”), the tokenizer doesn’t prepend the prefix space to the string.
|
||||
|
||||
## CodeLlamaTokenizer
|
||||
|
||||
|
||||
@ -1,124 +1,115 @@
|
||||
# Cohere
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
The Cohere Command-R model was proposed in the blogpost [Command-R: Retrieval Augmented Generation at Production Scale](https://txt.cohere.com/command-r/) by the Cohere Team.
|
||||
# Cohere
|
||||
|
||||
The abstract from the paper is the following:
|
||||
Cohere Command-R is a 35B parameter multilingual large language model designed for long context tasks like retrieval-augmented generation (RAG) and calling external APIs and tools. The model is specifically trained for grounded generation and supports both single-step and multi-step tool use. It supports a context length of 128K tokens.
|
||||
|
||||
*Command-R is a scalable generative model targeting RAG and Tool Use to enable production-scale AI for enterprise. Today, we are introducing Command-R, a new LLM aimed at large-scale production workloads. Command-R targets the emerging “scalable” category of models that balance high efficiency with strong accuracy, enabling companies to move beyond proof of concept, and into production.*
|
||||
You can find all the original Command-R checkpoints under the [Command Models](https://huggingface.co/collections/CohereForAI/command-models-67652b401665205e17b192ad) collection.
|
||||
|
||||
*Command-R is a generative model optimized for long context tasks such as retrieval augmented generation (RAG) and using external APIs and tools. It is designed to work in concert with our industry-leading Embed and Rerank models to provide best-in-class integration for RAG applications and excel at enterprise use cases. As a model built for companies to implement at scale, Command-R boasts:
|
||||
- Strong accuracy on RAG and Tool Use
|
||||
- Low latency, and high throughput
|
||||
- Longer 128k context and lower pricing
|
||||
- Strong capabilities across 10 key languages
|
||||
- Model weights available on HuggingFace for research and evaluation
|
||||
|
||||
Checkout model checkpoints [here](https://huggingface.co/CohereForAI/c4ai-command-r-v01).
|
||||
This model was contributed by [Saurabh Dash](https://huggingface.co/saurabhdash) and [Ahmet Üstün](https://huggingface.co/ahmetustun). The code of the implementation in Hugging Face is based on GPT-NeoX [here](https://github.com/EleutherAI/gpt-neox).
|
||||
> [!TIP]
|
||||
> Click on the Cohere models in the right sidebar for more examples of how to apply Cohere to different language tasks.
|
||||
|
||||
## Usage tips
|
||||
The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`], and from the command line.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The checkpoints uploaded on the Hub use `torch_dtype = 'float16'`, which will be
|
||||
used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.
|
||||
|
||||
The `dtype` of the online weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online), then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`), and finally, if there is a `torch_dtype` provided in the config, it will be used.
|
||||
|
||||
Training the model in `float16` is not recommended and is known to produce `nan`; as such, the model should be trained in `bfloat16`.
|
||||
|
||||
</Tip>
|
||||
The model and tokenizer can be loaded via:
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```python
|
||||
# pip install transformers
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="CohereForAI/c4ai-command-r-v01",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_id = "CohereForAI/c4ai-command-r-v01"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
model = AutoModelForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa")
|
||||
|
||||
# Format message with the command-r chat template
|
||||
messages = [{"role": "user", "content": "Hello, how are you?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
||||
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
|
||||
gen_tokens = model.generate(
|
||||
# format message with the Command-R chat template
|
||||
messages = [{"role": "user", "content": "How do plants make energy?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
|
||||
output = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
gen_text = tokenizer.decode(gen_tokens[0])
|
||||
print(gen_text)
|
||||
cache_implementation="static",
|
||||
)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, it is simply specifying either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because the Flash Attention only support `fp16` and `bf16` data type.
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Command-R. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
|
||||
Loading FP16 model
|
||||
```python
|
||||
# pip install transformers
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_id = "CohereForAI/c4ai-command-r-v01"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
# Format message with the command-r chat template
|
||||
messages = [{"role": "user", "content": "Hello, how are you?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
||||
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
|
||||
gen_tokens = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
gen_text = tokenizer.decode(gen_tokens[0])
|
||||
print(gen_text)
|
||||
```bash
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
transformers-cli chat --model_name_or_path CohereForAI/c4ai-command-r-v01 --torch_dtype auto --attn_implementation flash_attention_2
|
||||
```
|
||||
|
||||
Loading bitsnbytes 4bit quantized model
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
# pip install transformers bitsandbytes accelerate
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
model = AutoModelForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", torch_dtype=torch.float16, device_map="auto", quantization_config=bnb_config, attn_implementation="sdpa")
|
||||
|
||||
model_id = "CohereForAI/c4ai-command-r-v01"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
|
||||
|
||||
gen_tokens = model.generate(
|
||||
# format message with the Command-R chat template
|
||||
messages = [{"role": "user", "content": "How do plants make energy?"}]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
|
||||
output = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
gen_text = tokenizer.decode(gen_tokens[0])
|
||||
print(gen_text)
|
||||
cache_implementation="static",
|
||||
)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
visualizer = AttentionMaskVisualizer("CohereForAI/c4ai-command-r-v01")
|
||||
visualizer("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/cohere-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
|
||||
## Notes
|
||||
- Don’t use the torch_dtype parameter in [`~AutoModel.from_pretrained`] if you’re using FlashAttention-2 because it only supports fp16 or bf16. You should use [Automatic Mixed Precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html), set fp16 or bf16 to True if using [`Trainer`], or use [torch.autocast](https://pytorch.org/docs/stable/amp.html#torch.autocast).
|
||||
|
||||
## CohereConfig
|
||||
|
||||
@ -143,5 +134,3 @@ print(gen_text)
|
||||
|
||||
[[autodoc]] CohereForCausalLM
|
||||
- forward
|
||||
|
||||
|
||||
|
||||
184
docs/source/en/model_doc/deepseek_v3.md
Normal file
184
docs/source/en/model_doc/deepseek_v3.md
Normal file
@ -0,0 +1,184 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# DeepSeek-V3
|
||||
|
||||
## Overview
|
||||
|
||||
The DeepSeek-V3 model was proposed in [DeepSeek-V3 Technical Report](https://arxiv.org/abs/2412.19437) by DeepSeek-AI Team.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 671B total parameters with 37B activated for each token. To achieve efficient inference and cost-effective training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA) and DeepSeekMoE architectures, which were thoroughly validated in DeepSeek-V2. Furthermore, DeepSeek-V3 pioneers an auxiliary-loss-free strategy for load balancing and sets a multi-token prediction training objective for stronger performance. We pre-train DeepSeek-V3 on 14.8 trillion diverse and high-quality tokens, followed by Supervised Fine-Tuning and Reinforcement Learning stages to fully harness its capabilities. Comprehensive evaluations reveal that DeepSeek-V3 outperforms other open-source models and achieves performance comparable to leading closed-source models. Despite its excellent performance, DeepSeek-V3 requires only 2.788M H800 GPU hours for its full training. In addition, its training process is remarkably stable. Throughout the entire training process, we did not experience any irrecoverable loss spikes or perform any rollbacks. The model checkpoints are available at https://github.com/deepseek-ai/DeepSeek-V3.
|
||||
|
||||
## Limitations and call for contribution!
|
||||
|
||||
We are super happy to make this code community-powered, and would love to see how you can best optimize the following:
|
||||
|
||||
- current implementation uses the "naive" attention compution (so not really MLA)
|
||||
- current implementation loops through the experts. This should be replaced. Pointers to use `get_packed_weights` from `intetrations/tensor_parallel`.
|
||||
- current implementation uses the eleuther formula for ROPE, using the orginal one would be more efficient! (should still follow our API)
|
||||
- static cache is not supported (this should be just a generation config issue / config shape issues)
|
||||
|
||||
### Usage tips
|
||||
The model uses Multi-head Latent Attention (MLA) and DeepSeekMoE architectures for efficient inference and cost-effective training. It employs an auxiliary-loss-free strategy for load balancing and multi-token prediction training objective. The model can be used for various language tasks after being pre-trained on 14.8 trillion tokens and going through Supervised Fine-Tuning and Reinforcement Learning stages.
|
||||
|
||||
You can run the model in `FP8` automatically, using 2 nodes of 8 H100 should be more than enough!
|
||||
|
||||
```python
|
||||
# `run_deepseek_v1.py`
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
torch.manual_seed(30)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("deepseek-r1")
|
||||
|
||||
chat = [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||
]
|
||||
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("deepseek-r1", device_map="auto", torch_dtype=torch.bfloat16)
|
||||
inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
||||
import time
|
||||
start = time.time()
|
||||
outputs = model.generate(inputs, max_new_tokens=50)
|
||||
print(tokenizer.batch_decode(outputs))
|
||||
print(time.time()-start)
|
||||
```
|
||||
This generated:
|
||||
|
||||
``````
|
||||
<|Assistant|><think>
|
||||
Okay, the user wants to demonstrate how chat templating works. Let me break down what that means. Chat templating is about structuring the conversation data, especially for models that need specific input formats. Maybe they're referring to something like how messages are formatted with roles (user, assistant, system) in APIs like OpenAI.
|
||||
|
||||
First, I should explain what chat templating is. It's the process of formatting conversation data into a structured format that the model can understand. This usually includes roles and content. For example, user messages, assistant responses, and system messages each have their own role tags.
|
||||
|
||||
They might want an example. Let me think of a simple conversation. The user says "Hello, how are you?" and the assistant responds "I'm doing great. How can I help you today?" Then the user follows up with wanting to show off chat templating. So the example should include the history and the new message.
|
||||
|
||||
In some frameworks, like Hugging Face's Transformers, chat templates are applied using Jinja2 templates. The template might look something like combining system messages, then looping through user and assistant messages with appropriate tags. For instance, using {% for message in messages %} and assigning roles like <|user|>, <|assistant|>, etc.
|
||||
|
||||
I should structure the example with the messages array, showing each role and content. Then apply a hypothetical template to convert that into a formatted string the model uses. Also, mention that different models have different templating requirements, like using special tokens or varying role labels.
|
||||
|
||||
Wait, the user mentioned "chat templating" in the context of showing off. Maybe they want a practical example they can present. So providing a code snippet or a structured data example would be helpful. Let me outline a typical messages array and then the templated output.
|
||||
|
||||
Also, it's important to note that proper templating ensures the model knows the conversation flow, which is crucial for generating coherent responses. Maybe include a note about why it's important, like maintaining context and role-specific processing.
|
||||
|
||||
Let me check if there are any common mistakes or things to avoid. For example, not closing tags properly, or mismatching roles. But maybe that's too detailed unless the user asks. Focus on the positive example first.
|
||||
|
||||
Putting it all together, the response should have an example messages array, the applied template, and the final formatted string. Maybe use angle brackets or special tokens as placeholders. Also, mention that this helps in training or fine-tuning models with structured data.
|
||||
|
||||
I think that's a solid approach. Let me structure it step by step to make it clear.
|
||||
</think>
|
||||
|
||||
Chat templating is a way to structure conversation data (e.g., user/assistant interactions) into a format that language models understand. This is especially important for models trained to handle multi-turn dialogues, where the input must explicitly separate roles (user, assistant, system, etc.) and messages. Let’s break this down with an example!
|
||||
|
||||
---
|
||||
|
||||
### **Step 1: Raw Conversation History**
|
||||
Suppose we have this conversation:
|
||||
- **User**: "Hello, how are you?"
|
||||
- **Assistant**: "I'm doing great. How can I help you today?"
|
||||
- **User**: "I'd like to show off how chat templating works!"
|
||||
|
||||
---
|
||||
|
||||
### **Step 2: Structured Messages**
|
||||
In frameworks like Hugging Face Transformers or OpenAI, conversations are often formatted as a list of dictionaries with `role` and `content`:
|
||||
```python
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **Step 3: Apply a Chat Template**
|
||||
A **chat template** converts this structured data into a single string formatted for the model. For example, using a Jinja-style template (common in Hugging Face):
|
||||
|
||||
```jinja
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
<|user|>{{ message['content'] }}<|end|>
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
<|assistant|>{{ message['content'] }}<|end|>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
<|assistant|>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### **Step 4: Final Templated Output**
|
||||
Applying the template to our `messages` list would produce:
|
||||
```text
|
||||
<|user|>Hello, how are you?<|end|>
|
||||
<|assistant|>I'm doing great. How can I help you today?<|end|>
|
||||
<|user|>I'd like to show off how chat templating works!<|end|>
|
||||
<|assistant|>
|
||||
```
|
||||
|
||||
This tells the model:
|
||||
1. The conversation history (user/assistant turns).
|
||||
2. The model’s turn to generate a response (`<|assistant|>` at the end).
|
||||
|
||||
---
|
||||
|
||||
### **Key Notes**:
|
||||
- **Role Separation**: Tags like `<|user|>` and `<|assistant|>` help the model distinguish speakers.
|
||||
- **Special Tokens**: Models often use unique tokens (e.g., `<|end|>`) to mark message boundaries.
|
||||
- **Flexibility**: Templates vary by model (e.g., OpenAI uses `{"role": "user", "content": "..."}` instead of tags).
|
||||
|
||||
---
|
||||
|
||||
### **Why This Matters**:
|
||||
- **Consistency**: Ensures the model understands dialogue structure.
|
||||
- **Context Preservation**: Maintains the flow of multi-turn conversations.
|
||||
- **Alignment**: Matches the format the model was trained on for better performance.
|
||||
|
||||
Want to dive deeper or see a specific framework’s implementation (e.g., OpenAI, Llama, Mistral)? Let me know! 😊<|end▁of▁sentence|>
|
||||
``````
|
||||
|
||||
Use the following to run it
|
||||
```bash
|
||||
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0|1 --rdzv-id an_id --rdzv-backend c10d --rdzv-endpoint master_addr:master_port run_deepseek_r1.py
|
||||
```
|
||||
|
||||
If you have:
|
||||
```bash
|
||||
[rank0]: ncclInternalError: Internal check failed.
|
||||
[rank0]: Last error:
|
||||
[rank0]: Bootstrap : no socket interface found
|
||||
```
|
||||
error, it means NCCL was probably not loaded.
|
||||
|
||||
|
||||
## DeepseekV3Config
|
||||
|
||||
[[autodoc]] DeepseekV3Config
|
||||
|
||||
## DeepseekV3Model
|
||||
|
||||
[[autodoc]] DeepseekV3Model
|
||||
- forward
|
||||
|
||||
## DeepseekV3ForCausalLM
|
||||
|
||||
[[autodoc]] DeepseekV3ForCausalLM
|
||||
- forward
|
||||
@ -14,101 +14,69 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Depth Anything
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Depth Anything
|
||||
|
||||
The Depth Anything model was proposed in [Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data](https://arxiv.org/abs/2401.10891) by Lihe Yang, Bingyi Kang, Zilong Huang, Xiaogang Xu, Jiashi Feng, Hengshuang Zhao. Depth Anything is based on the [DPT](dpt) architecture, trained on ~62 million images, obtaining state-of-the-art results for both relative and absolute depth estimation.
|
||||
[Depth Anything](https://huggingface.co/papers/2401.10891) is designed to be a foundation model for monocular depth estimation (MDE). It is jointly trained on labeled and ~62M unlabeled images to enhance the dataset. It uses a pretrained [DINOv2](./dinov2) model as an image encoder to inherit its existing rich semantic priors, and [DPT](./dpt) as the decoder. A teacher model is trained on unlabeled images to create pseudo-labels. The student model is trained on a combination of the pseudo-labels and labeled images. To improve the student model's performance, strong perturbations are added to the unlabeled images to challenge the student model to learn more visual knowledge from the image.
|
||||
|
||||
<Tip>
|
||||
You can find all the original Depth Anything checkpoints under the [Depth Anything](https://huggingface.co/collections/LiheYoung/depth-anything-release-65b317de04eec72abf6b55aa) collection.
|
||||
|
||||
[Depth Anything V2](depth_anything_v2) was released in June 2024. It uses the same architecture as Depth Anything and therefore it is compatible with all code examples and existing workflows. However, it leverages synthetic data and a larger capacity teacher model to achieve much finer and robust depth predictions.
|
||||
> [!TIP]
|
||||
> Click on the Depth Anything models in the right sidebar for more examples of how to apply Depth Anything to different vision tasks.
|
||||
|
||||
</Tip>
|
||||
The example below demonstrates how to obtain a depth map with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
*This work presents Depth Anything, a highly practical solution for robust monocular depth estimation. Without pursuing novel technical modules, we aim to build a simple yet powerful foundation model dealing with any images under any circumstances. To this end, we scale up the dataset by designing a data engine to collect and automatically annotate large-scale unlabeled data (~62M), which significantly enlarges the data coverage and thus is able to reduce the generalization error. We investigate two simple yet effective strategies that make data scaling-up promising. First, a more challenging optimization target is created by leveraging data augmentation tools. It compels the model to actively seek extra visual knowledge and acquire robust representations. Second, an auxiliary supervision is developed to enforce the model to inherit rich semantic priors from pre-trained encoders. We evaluate its zero-shot capabilities extensively, including six public datasets and randomly captured photos. It demonstrates impressive generalization ability. Further, through fine-tuning it with metric depth information from NYUv2 and KITTI, new SOTAs are set. Our better depth model also results in a better depth-conditioned ControlNet.*
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/depth_anything_overview.jpg"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Depth Anything overview. Taken from the <a href="https://arxiv.org/abs/2401.10891">original paper</a>.</small>
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr).
|
||||
The original code can be found [here](https://github.com/LiheYoung/Depth-Anything).
|
||||
|
||||
## Usage example
|
||||
|
||||
There are 2 main ways to use Depth Anything: either using the pipeline API, which abstracts away all the complexity for you, or by using the `DepthAnythingForDepthEstimation` class yourself.
|
||||
|
||||
### Pipeline API
|
||||
|
||||
The pipeline allows to use the model in a few lines of code:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> # load pipe
|
||||
>>> pipe = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")
|
||||
|
||||
>>> # load image
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> # inference
|
||||
>>> depth = pipe(image)["depth"]
|
||||
pipe = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-base-hf", torch_dtype=torch.bfloat16, device=0)
|
||||
pipe("http://images.cocodataset.org/val2017/000000039769.jpg")["depth"]
|
||||
```
|
||||
|
||||
### Using the model yourself
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
If you want to do the pre- and postprocessing yourself, here's how to do that:
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-base-hf")
|
||||
model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-base-hf", torch_dtype=torch.bfloat16)
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
inputs = image_processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
|
||||
>>> model = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf")
|
||||
|
||||
>>> # prepare image for the model
|
||||
>>> inputs = image_processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> # interpolate to original size and visualize the prediction
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... target_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
|
||||
>>> depth = depth.detach().cpu().numpy() * 255
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
post_processed_output = image_processor.post_process_depth_estimation(
|
||||
outputs,
|
||||
target_sizes=[(image.height, image.width)],
|
||||
)
|
||||
predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
|
||||
depth = depth.detach().cpu().numpy() * 255
|
||||
Image.fromarray(depth.astype("uint8"))
|
||||
```
|
||||
|
||||
## Resources
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Depth Anything.
|
||||
## Notes
|
||||
|
||||
- [Monocular depth estimation task guide](../tasks/monocular_depth_estimation)
|
||||
- A notebook showcasing inference with [`DepthAnythingForDepthEstimation`] can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Depth%20Anything/Predicting_depth_in_an_image_with_Depth_Anything.ipynb). 🌎
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
- [DepthAnythingV2](./depth_anything_v2), released in June 2024, uses the same architecture as Depth Anything and is compatible with all code examples and existing workflows. It uses synthetic data and a larger capacity teacher model to achieve much finer and robust depth predictions.
|
||||
|
||||
## DepthAnythingConfig
|
||||
|
||||
|
||||
@ -10,71 +10,134 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# DINOv2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
The DINOv2 model was proposed in [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) by
|
||||
Maxime Oquab, Timothée Darcet, Théo Moutakanni, Huy Vo, Marc Szafraniec, Vasil Khalidov, Pierre Fernandez, Daniel Haziza, Francisco Massa, Alaaeldin El-Nouby, Mahmoud Assran, Nicolas Ballas, Wojciech Galuba, Russell Howes, Po-Yao Huang, Shang-Wen Li, Ishan Misra, Michael Rabbat, Vasu Sharma, Gabriel Synnaeve, Hu Xu, Hervé Jegou, Julien Mairal, Patrick Labatut, Armand Joulin, Piotr Bojanowski.
|
||||
DINOv2 is an upgrade of [DINO](https://arxiv.org/abs/2104.14294), a self-supervised method applied on [Vision Transformers](vit). This method enables all-purpose visual features, i.e., features that work across image distributions and tasks without finetuning.
|
||||
# DINOv2
|
||||
|
||||
The abstract from the paper is the following:
|
||||
[DINOv2](https://huggingface.co/papers/2304.07193) is a vision foundation model that uses [ViT](./vit) as a feature extractor for multiple downstream tasks like image classification and depth estimation. It focuses on stabilizing and accelerating training through techniques like a faster memory-efficient attention, sequence packing, improved stochastic depth, Fully Sharded Data Parallel (FSDP), and model distillation.
|
||||
|
||||
*The recent breakthroughs in natural language processing for model pretraining on large quantities of data have opened the way for similar foundation models in computer vision. These models could greatly simplify the use of images in any system by producing all-purpose visual features, i.e., features that work across image distributions and tasks without finetuning. This work shows that existing pretraining methods, especially self-supervised methods, can produce such features if trained on enough curated data from diverse sources. We revisit existing approaches and combine different techniques to scale our pretraining in terms of data and model size. Most of the technical contributions aim at accelerating and stabilizing the training at scale. In terms of data, we propose an automatic pipeline to build a dedicated, diverse, and curated image dataset instead of uncurated data, as typically done in the self-supervised literature. In terms of models, we train a ViT model (Dosovitskiy et al., 2020) with 1B parameters and distill it into a series of smaller models that surpass the best available all-purpose features, OpenCLIP (Ilharco et al., 2021) on most of the benchmarks at image and pixel levels.*
|
||||
You can find all the original DINOv2 checkpoints under the [Dinov2](https://huggingface.co/collections/facebook/dinov2-6526c98554b3d2576e071ce3) collection.
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr).
|
||||
The original code can be found [here](https://github.com/facebookresearch/dinov2).
|
||||
> [!TIP]
|
||||
> Click on the DINOv2 models in the right sidebar for more examples of how to apply DINOv2 to different vision tasks.
|
||||
|
||||
## Usage tips
|
||||
The example below demonstrates how to obtain an image embedding with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
The model can be traced using `torch.jit.trace` which leverages JIT compilation to optimize the model making it faster to run. Note this still produces some mis-matched elements and the difference between the original model and the traced model is of the order of 1e-4.
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```python
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from PIL import Image
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(
|
||||
task="image-classification",
|
||||
model="facebook/dinov2-small-imagenet1k-1-layer",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
|
||||
pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import requests
|
||||
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||
from PIL import Image
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-small-imagenet1k-1-layer")
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
"facebook/dinov2-small-imagenet1k-1-layer",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt")
|
||||
logits = model(**inputs).logits
|
||||
predicted_class_idx = logits.argmax(-1).item()
|
||||
print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
```py
|
||||
# pip install torchao
|
||||
import requests
|
||||
from transformers import TorchAoConfig, AutoImageProcessor, AutoModelForImageClassification
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
from PIL import Image
|
||||
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
|
||||
model = AutoModel.from_pretrained('facebook/dinov2-base')
|
||||
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-giant-imagenet1k-1-layer')
|
||||
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
'facebook/dinov2-giant-imagenet1k-1-layer',
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
inputs = processor(images=image, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
last_hidden_states = outputs[0]
|
||||
|
||||
# We have to force return_dict=False for tracing
|
||||
model.config.return_dict = False
|
||||
|
||||
with torch.no_grad():
|
||||
traced_model = torch.jit.trace(model, [inputs.pixel_values])
|
||||
traced_outputs = traced_model(inputs.pixel_values)
|
||||
|
||||
print((last_hidden_states - traced_outputs[0]).abs().max())
|
||||
logits = outputs.logits
|
||||
predicted_class_idx = logits.argmax(-1).item()
|
||||
print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
```
|
||||
|
||||
## Resources
|
||||
## Notes
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DINOv2.
|
||||
- Use [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) to speedup inference. However, it will produce some mismatched elements. The difference between the original and traced model is 1e-4.
|
||||
|
||||
- Demo notebooks for DINOv2 can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DINOv2). 🌎
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
<PipelineTag pipeline="image-classification"/>
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
- [`Dinov2ForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
|
||||
- See also: [Image classification task guide](../tasks/image_classification)
|
||||
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
|
||||
model = AutoModel.from_pretrained('facebook/dinov2-base')
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
inputs = processor(images=image, return_tensors="pt")
|
||||
outputs = model(**inputs)
|
||||
last_hidden_states = outputs[0]
|
||||
|
||||
# We have to force return_dict=False for tracing
|
||||
model.config.return_dict = False
|
||||
|
||||
with torch.no_grad():
|
||||
traced_model = torch.jit.trace(model, [inputs.pixel_values])
|
||||
traced_outputs = traced_model(inputs.pixel_values)
|
||||
|
||||
print((last_hidden_states - traced_outputs[0]).abs().max())
|
||||
```
|
||||
|
||||
## Dinov2Config
|
||||
|
||||
|
||||
@ -14,199 +14,91 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# DistilBERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# DistilBERT
|
||||
|
||||
The DistilBERT model was proposed in the blog post [Smaller, faster, cheaper, lighter: Introducing DistilBERT, a
|
||||
distilled version of BERT](https://medium.com/huggingface/distilbert-8cf3380435b5), and the paper [DistilBERT, a
|
||||
distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108). DistilBERT is a
|
||||
small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% less parameters than
|
||||
*google-bert/bert-base-uncased*, runs 60% faster while preserving over 95% of BERT's performances as measured on the GLUE language
|
||||
understanding benchmark.
|
||||
[DistilBERT](https://huggingface.co/papers/1910.01108) is pretrained by knowledge distillation to create a smaller model with faster inference and requires less compute to train. Through a triple loss objective during pretraining, language modeling loss, distillation loss, cosine-distance loss, DistilBERT demonstrates similar performance to a larger transformer language model.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original DistilBERT checkpoints under the [DistilBERT](https://huggingface.co/distilbert) organization.
|
||||
|
||||
*As Transfer Learning from large-scale pre-trained models becomes more prevalent in Natural Language Processing (NLP),
|
||||
operating these large models in on-the-edge and/or under constrained computational training or inference budgets
|
||||
remains challenging. In this work, we propose a method to pre-train a smaller general-purpose language representation
|
||||
model, called DistilBERT, which can then be fine-tuned with good performances on a wide range of tasks like its larger
|
||||
counterparts. While most prior work investigated the use of distillation for building task-specific models, we leverage
|
||||
knowledge distillation during the pretraining phase and show that it is possible to reduce the size of a BERT model by
|
||||
40%, while retaining 97% of its language understanding capabilities and being 60% faster. To leverage the inductive
|
||||
biases learned by larger models during pretraining, we introduce a triple loss combining language modeling,
|
||||
distillation and cosine-distance losses. Our smaller, faster and lighter model is cheaper to pre-train and we
|
||||
demonstrate its capabilities for on-device computations in a proof-of-concept experiment and a comparative on-device
|
||||
study.*
|
||||
> [!TIP]
|
||||
> Click on the DistilBERT models in the right sidebar for more examples of how to apply DistilBERT to different language tasks.
|
||||
|
||||
This model was contributed by [victorsanh](https://huggingface.co/victorsanh). This model jax version was
|
||||
contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code can be found [here](https://github.com/huggingface/transformers-research-projects/tree/main/distillation).
|
||||
The example below demonstrates how to classify text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
## Usage tips
|
||||
<hfoptions id="usage">
|
||||
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
classifier = pipeline(
|
||||
task="text-classification",
|
||||
model="distilbert-base-uncased-finetuned-sst-2-english",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
|
||||
result = classifier("I love using Hugging Face Transformers!")
|
||||
print(result)
|
||||
# Output: [{'label': 'POSITIVE', 'score': 0.9998}]
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"distilbert/distilbert-base-uncased-finetuned-sst-2-english",
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"distilbert/distilbert-base-uncased-finetuned-sst-2-english",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
inputs = tokenizer("I love using Hugging Face Transformers!", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
predicted_class_id = torch.argmax(outputs.logits, dim=-1).item()
|
||||
predicted_label = model.config.id2label[predicted_class_id]
|
||||
print(f"Predicted label: {predicted_label}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "I love using Hugging Face Transformers!" | transformers-cli run --task text-classification --model distilbert-base-uncased-finetuned-sst-2-english
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
</hfoptions>
|
||||
|
||||
## Notes
|
||||
|
||||
- DistilBERT doesn't have `token_type_ids`, you don't need to indicate which token belongs to which segment. Just
|
||||
separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`).
|
||||
- DistilBERT doesn't have options to select the input positions (`position_ids` input). This could be added if
|
||||
necessary though, just let us know if you need this option.
|
||||
- Same as BERT but smaller. Trained by distillation of the pretrained BERT model, meaning it’s been trained to predict the same probabilities as the larger model. The actual objective is a combination of:
|
||||
|
||||
* finding the same probabilities as the teacher model
|
||||
* predicting the masked tokens correctly (but no next-sentence objective)
|
||||
* a cosine similarity between the hidden states of the student and the teacher model
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```
|
||||
from transformers import DistilBertModel
|
||||
model = DistilBertModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16` and the `distilbert-base-uncased` model with
|
||||
a MaskedLM head, we saw the following speedups during training and inference.
|
||||
|
||||
#### Training
|
||||
|
||||
| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|
||||
|--------------------|------------|---------|---------|----------------------------|---------------------------|-------------|---------------------|--------------------|----------------|
|
||||
| 100 | 1 | 128 | False | 0.010 | 0.008 | 28.870 | 397.038 | 399.629 | -0.649 |
|
||||
| 100 | 1 | 256 | False | 0.011 | 0.009 | 20.681 | 412.505 | 412.606 | -0.025 |
|
||||
| 100 | 2 | 128 | False | 0.011 | 0.009 | 23.741 | 412.213 | 412.606 | -0.095 |
|
||||
| 100 | 2 | 256 | False | 0.015 | 0.013 | 16.502 | 427.491 | 425.787 | 0.400 |
|
||||
| 100 | 4 | 128 | False | 0.015 | 0.013 | 13.828 | 427.491 | 425.787 | 0.400 |
|
||||
| 100 | 4 | 256 | False | 0.025 | 0.022 | 12.882 | 594.156 | 502.745 | 18.182 |
|
||||
| 100 | 8 | 128 | False | 0.023 | 0.022 | 8.010 | 545.922 | 502.745 | 8.588 |
|
||||
| 100 | 8 | 256 | False | 0.046 | 0.041 | 12.763 | 983.450 | 798.480 | 23.165 |
|
||||
|
||||
#### Inference
|
||||
|
||||
| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|
||||
|-------------|------------|---------|---------|---------|----------|-----------------------------|-----------------------------|-------------|----------------|--------------|---------------|
|
||||
| 50 | 2 | 64 | True | True | True | 0.032 | 0.025 | 28.192 | 154.532 | 155.531 | -0.642 |
|
||||
| 50 | 2 | 128 | True | True | True | 0.033 | 0.025 | 32.636 | 157.286 | 157.482 | -0.125 |
|
||||
| 50 | 4 | 64 | True | True | True | 0.032 | 0.026 | 24.783 | 157.023 | 157.449 | -0.271 |
|
||||
| 50 | 4 | 128 | True | True | True | 0.034 | 0.028 | 19.299 | 162.794 | 162.269 | 0.323 |
|
||||
| 50 | 8 | 64 | True | True | True | 0.035 | 0.028 | 25.105 | 160.958 | 162.204 | -0.768 |
|
||||
| 50 | 8 | 128 | True | True | True | 0.052 | 0.046 | 12.375 | 173.155 | 171.844 | 0.763 |
|
||||
| 50 | 16 | 64 | True | True | True | 0.051 | 0.045 | 12.882 | 172.106 | 171.713 | 0.229 |
|
||||
| 50 | 16 | 128 | True | True | True | 0.096 | 0.081 | 18.524 | 191.257 | 191.517 | -0.136 |
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DistilBERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
|
||||
- A blog post on [Getting Started with Sentiment Analysis using Python](https://huggingface.co/blog/sentiment-analysis-python) with DistilBERT.
|
||||
- A blog post on how to [train DistilBERT with Blurr for sequence classification](https://huggingface.co/blog/fastai).
|
||||
- A blog post on how to use [Ray to tune DistilBERT hyperparameters](https://huggingface.co/blog/ray-tune).
|
||||
- A blog post on how to [train DistilBERT with Hugging Face and Amazon SageMaker](https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face).
|
||||
- A notebook on how to [finetune DistilBERT for multi-label classification](https://colab.research.google.com/github/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb). 🌎
|
||||
- A notebook on how to [finetune DistilBERT for multiclass classification with PyTorch](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb). 🌎
|
||||
- A notebook on how to [finetune DistilBERT for text classification in TensorFlow](https://colab.research.google.com/github/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb). 🌎
|
||||
- [`DistilBertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb).
|
||||
- [`TFDistilBertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb).
|
||||
- [`FlaxDistilBertForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification_flax.ipynb).
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
|
||||
- [`DistilBertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/token-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb).
|
||||
- [`TFDistilBertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/token-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification-tf.ipynb).
|
||||
- [`FlaxDistilBertForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/token-classification).
|
||||
- [Token classification](https://huggingface.co/course/chapter7/2?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
|
||||
|
||||
<PipelineTag pipeline="fill-mask"/>
|
||||
|
||||
- [`DistilBertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#robertabertdistilbert-and-masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
|
||||
- [`TFDistilBertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_mlmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
|
||||
- [`FlaxDistilBertForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb).
|
||||
- [Masked language modeling](https://huggingface.co/course/chapter7/3?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
|
||||
<PipelineTag pipeline="question-answering"/>
|
||||
|
||||
- [`DistilBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb).
|
||||
- [`TFDistilBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/question-answering) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering-tf.ipynb).
|
||||
- [`FlaxDistilBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/question-answering).
|
||||
- [Question answering](https://huggingface.co/course/chapter7/7?fw=pt) chapter of the 🤗 Hugging Face Course.
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
|
||||
**Multiple choice**
|
||||
- [`DistilBertForMultipleChoice`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/multiple-choice) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice.ipynb).
|
||||
- [`TFDistilBertForMultipleChoice`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/multiple-choice) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice-tf.ipynb).
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
|
||||
⚗️ Optimization
|
||||
|
||||
- A blog post on how to [quantize DistilBERT with 🤗 Optimum and Intel](https://huggingface.co/blog/intel).
|
||||
- A blog post on how [Optimizing Transformers for GPUs with 🤗 Optimum](https://www.philschmid.de/optimizing-transformers-with-optimum-gpu).
|
||||
- A blog post on [Optimizing Transformers with Hugging Face Optimum](https://www.philschmid.de/optimizing-transformers-with-optimum).
|
||||
|
||||
⚡️ Inference
|
||||
|
||||
- A blog post on how to [Accelerate BERT inference with Hugging Face Transformers and AWS Inferentia](https://huggingface.co/blog/bert-inferentia-sagemaker) with DistilBERT.
|
||||
- A blog post on [Serverless Inference with Hugging Face's Transformers, DistilBERT and Amazon SageMaker](https://www.philschmid.de/sagemaker-serverless-huggingface-distilbert).
|
||||
|
||||
🚀 Deploy
|
||||
|
||||
- A blog post on how to [deploy DistilBERT on Google Cloud](https://huggingface.co/blog/how-to-deploy-a-pipeline-to-google-clouds).
|
||||
- A blog post on how to [deploy DistilBERT with Amazon SageMaker](https://huggingface.co/blog/deploy-hugging-face-models-easily-with-amazon-sagemaker).
|
||||
- A blog post on how to [Deploy BERT with Hugging Face Transformers, Amazon SageMaker and Terraform module](https://www.philschmid.de/terraform-huggingface-amazon-sagemaker).
|
||||
|
||||
|
||||
## Combining DistilBERT and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
>>> device = "cuda" # the device to load the model onto
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-uncased')
|
||||
>>> model = AutoModel.from_pretrained("distilbert/distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
|
||||
|
||||
>>> text = "Replace me by any text you'd like."
|
||||
|
||||
>>> encoded_input = tokenizer(text, return_tensors='pt').to(device)
|
||||
>>> model.to(device)
|
||||
|
||||
>>> output = model(**encoded_input)
|
||||
```
|
||||
|
||||
|
||||
## DistilBertConfig
|
||||
|
||||
|
||||
@ -13,180 +13,191 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# Donut
|
||||
|
||||
## Overview
|
||||
[Donut (Document Understanding Transformer)](https://huggingface.co/papers2111.15664) is a visual document understanding model that doesn't require an Optical Character Recognition (OCR) engine. Unlike traditional approaches that extract text using OCR before processing, Donut employs an end-to-end Transformer-based architecture to directly analyze document images. This eliminates OCR-related inefficiencies making it more accurate and adaptable to diverse languages and formats.
|
||||
|
||||
The Donut model was proposed in [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by
|
||||
Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
|
||||
Donut consists of an image Transformer encoder and an autoregressive text Transformer decoder to perform document understanding
|
||||
tasks such as document image classification, form understanding and visual question answering.
|
||||
Donut features vision encoder ([Swin](./swin)) and a text decoder ([BART](./bart)). Swin converts document images into embeddings and BART processes them into meaningful text sequences.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original Donut checkpoints under the [Naver Clova Information Extraction](https://huggingface.co/naver-clova-ix) organization.
|
||||
|
||||
*Understanding document images (e.g., invoices) is a core but challenging task since it requires complex functions such as reading text and a holistic understanding of the document. Current Visual Document Understanding (VDU) methods outsource the task of reading text to off-the-shelf Optical Character Recognition (OCR) engines and focus on the understanding task with the OCR outputs. Although such OCR-based approaches have shown promising performance, they suffer from 1) high computational costs for using OCR; 2) inflexibility of OCR models on languages or types of document; 3) OCR error propagation to the subsequent process. To address these issues, in this paper, we introduce a novel OCR-free VDU model named Donut, which stands for Document understanding transformer. As the first step in OCR-free VDU research, we propose a simple architecture (i.e., Transformer) with a pre-training objective (i.e., cross-entropy loss). Donut is conceptually simple yet effective. Through extensive experiments and analyses, we show a simple OCR-free VDU model, Donut, achieves state-of-the-art performances on various VDU tasks in terms of both speed and accuracy. In addition, we offer a synthetic data generator that helps the model pre-training to be flexible in various languages and domains.*
|
||||
> [!TIP]
|
||||
> Click on the Donut models in the right sidebar for more examples of how to apply Donut to different language and vision tasks.
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/donut_architecture.jpg"
|
||||
alt="drawing" width="600"/>
|
||||
The examples below demonstrate how to perform document understanding tasks using Donut with [`Pipeline`] and [`AutoModel`]
|
||||
|
||||
<small> Donut high-level overview. Taken from the <a href="https://arxiv.org/abs/2111.15664">original paper</a>. </small>
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found
|
||||
[here](https://github.com/clovaai/donut).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- The quickest way to get started with Donut is by checking the [tutorial
|
||||
notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Donut), which show how to use the model
|
||||
at inference time as well as fine-tuning on custom data.
|
||||
- Donut is always used within the [VisionEncoderDecoder](vision-encoder-decoder) framework.
|
||||
|
||||
## Inference examples
|
||||
|
||||
Donut's [`VisionEncoderDecoder`] model accepts images as input and makes use of
|
||||
[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
|
||||
|
||||
The [`DonutImageProcessor`] class is responsible for preprocessing the input image and
|
||||
[`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`] decodes the generated target tokens to the target string. The
|
||||
[`DonutProcessor`] wraps [`DonutImageProcessor`] and [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]
|
||||
into a single instance to both extract the input features and decode the predicted token ids.
|
||||
|
||||
- Step-by-step Document Image Classification
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
>>> import re
|
||||
# pip install datasets
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
from PIL import Image
|
||||
|
||||
>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
|
||||
>>> from datasets import load_dataset
|
||||
>>> import torch
|
||||
pipeline = pipeline(
|
||||
task="document-question-answering",
|
||||
model="naver-clova-ix/donut-base-finetuned-docvqa",
|
||||
device=0,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[0]["image"]
|
||||
|
||||
>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
|
||||
|
||||
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||
|
||||
>>> # load document image
|
||||
>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
>>> image = dataset[1]["image"]
|
||||
|
||||
>>> # prepare decoder inputs
|
||||
>>> task_prompt = "<s_rvlcdip>"
|
||||
>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
||||
|
||||
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
>>> outputs = model.generate(
|
||||
... pixel_values.to(device),
|
||||
... decoder_input_ids=decoder_input_ids.to(device),
|
||||
... max_length=model.decoder.config.max_position_embeddings,
|
||||
... pad_token_id=processor.tokenizer.pad_token_id,
|
||||
... eos_token_id=processor.tokenizer.eos_token_id,
|
||||
... use_cache=True,
|
||||
... bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
... return_dict_in_generate=True,
|
||||
... )
|
||||
|
||||
>>> sequence = processor.batch_decode(outputs.sequences)[0]
|
||||
>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
||||
>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
||||
>>> print(processor.token2json(sequence))
|
||||
{'class': 'advertisement'}
|
||||
pipeline(image=image, question="What time is the coffee break?")
|
||||
```
|
||||
|
||||
- Step-by-step Document Parsing
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
>>> import re
|
||||
# pip install datasets
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoProcessor, AutoModelForVision2Seq
|
||||
|
||||
>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
|
||||
>>> from datasets import load_dataset
|
||||
>>> import torch
|
||||
processor = AutoProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
|
||||
model = AutoModelForVision2Seq.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
|
||||
|
||||
>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[0]["image"]
|
||||
question = "What time is the coffee break?"
|
||||
task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
|
||||
inputs = processor(image, task_prompt, return_tensors="pt")
|
||||
|
||||
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||
|
||||
>>> # load document image
|
||||
>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
>>> image = dataset[2]["image"]
|
||||
|
||||
>>> # prepare decoder inputs
|
||||
>>> task_prompt = "<s_cord-v2>"
|
||||
>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
||||
|
||||
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
>>> outputs = model.generate(
|
||||
... pixel_values.to(device),
|
||||
... decoder_input_ids=decoder_input_ids.to(device),
|
||||
... max_length=model.decoder.config.max_position_embeddings,
|
||||
... pad_token_id=processor.tokenizer.pad_token_id,
|
||||
... eos_token_id=processor.tokenizer.eos_token_id,
|
||||
... use_cache=True,
|
||||
... bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
... return_dict_in_generate=True,
|
||||
... )
|
||||
|
||||
>>> sequence = processor.batch_decode(outputs.sequences)[0]
|
||||
>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
||||
>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
||||
>>> print(processor.token2json(sequence))
|
||||
{'menu': {'nm': 'CINNAMON SUGAR', 'unitprice': '17,000', 'cnt': '1 x', 'price': '17,000'}, 'sub_total': {'subtotal_price': '17,000'}, 'total': {'total_price': '17,000', 'cashprice': '20,000', 'changeprice': '3,000'}}
|
||||
outputs = model.generate(
|
||||
input_ids=inputs.input_ids,
|
||||
pixel_values=inputs.pixel_values,
|
||||
max_length=512
|
||||
)
|
||||
answer = processor.decode(outputs[0], skip_special_tokens=True)
|
||||
print(answer)
|
||||
```
|
||||
|
||||
- Step-by-step Document Visual Question Answering (DocVQA)
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
```py
|
||||
>>> import re
|
||||
# pip install datasets torchao
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import TorchAoConfig, AutoProcessor, AutoModelForVision2Seq
|
||||
|
||||
>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
|
||||
>>> from datasets import load_dataset
|
||||
>>> import torch
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
processor = AutoProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
|
||||
model = AutoModelForVision2Seq.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa", quantization_config=quantization_config)
|
||||
|
||||
>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[0]["image"]
|
||||
question = "What time is the coffee break?"
|
||||
task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
|
||||
inputs = processor(image, task_prompt, return_tensors="pt")
|
||||
|
||||
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||
|
||||
>>> # load document image from the DocVQA dataset
|
||||
>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
>>> image = dataset[0]["image"]
|
||||
|
||||
>>> # prepare decoder inputs
|
||||
>>> task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
|
||||
>>> question = "When is the coffee break?"
|
||||
>>> prompt = task_prompt.replace("{user_input}", question)
|
||||
>>> decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
||||
|
||||
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
>>> outputs = model.generate(
|
||||
... pixel_values.to(device),
|
||||
... decoder_input_ids=decoder_input_ids.to(device),
|
||||
... max_length=model.decoder.config.max_position_embeddings,
|
||||
... pad_token_id=processor.tokenizer.pad_token_id,
|
||||
... eos_token_id=processor.tokenizer.eos_token_id,
|
||||
... use_cache=True,
|
||||
... bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
... return_dict_in_generate=True,
|
||||
... )
|
||||
|
||||
>>> sequence = processor.batch_decode(outputs.sequences)[0]
|
||||
>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
||||
>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
||||
>>> print(processor.token2json(sequence))
|
||||
{'question': 'When is the coffee break?', 'answer': '11-14 to 11:39 a.m.'}
|
||||
outputs = model.generate(
|
||||
input_ids=inputs.input_ids,
|
||||
pixel_values=inputs.pixel_values,
|
||||
max_length=512
|
||||
)
|
||||
answer = processor.decode(outputs[0], skip_special_tokens=True)
|
||||
print(answer)
|
||||
```
|
||||
|
||||
See the [model hub](https://huggingface.co/models?filter=donut) to look for Donut checkpoints.
|
||||
## Notes
|
||||
|
||||
## Training
|
||||
- Use Donut for document image classification as shown below.
|
||||
|
||||
We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Donut).
|
||||
```py
|
||||
>>> import re
|
||||
>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
|
||||
>>> from datasets import load_dataset
|
||||
>>> import torch
|
||||
|
||||
>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
|
||||
|
||||
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||
|
||||
>>> # load document image
|
||||
>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
>>> image = dataset[1]["image"]
|
||||
|
||||
>>> # prepare decoder inputs
|
||||
>>> task_prompt = "<s_rvlcdip>"
|
||||
>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
||||
|
||||
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
>>> outputs = model.generate(
|
||||
... pixel_values.to(device),
|
||||
... decoder_input_ids=decoder_input_ids.to(device),
|
||||
... max_length=model.decoder.config.max_position_embeddings,
|
||||
... pad_token_id=processor.tokenizer.pad_token_id,
|
||||
... eos_token_id=processor.tokenizer.eos_token_id,
|
||||
... use_cache=True,
|
||||
... bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
... return_dict_in_generate=True,
|
||||
... )
|
||||
|
||||
>>> sequence = processor.batch_decode(outputs.sequences)[0]
|
||||
>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
||||
>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
||||
>>> print(processor.token2json(sequence))
|
||||
{'class': 'advertisement'}
|
||||
```
|
||||
|
||||
- Use Donut for document parsing as shown below.
|
||||
|
||||
```py
|
||||
>>> import re
|
||||
>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
|
||||
>>> from datasets import load_dataset
|
||||
>>> import torch
|
||||
|
||||
>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
|
||||
|
||||
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model.to(device) # doctest: +IGNORE_RESULT
|
||||
|
||||
>>> # load document image
|
||||
>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
>>> image = dataset[2]["image"]
|
||||
|
||||
>>> # prepare decoder inputs
|
||||
>>> task_prompt = "<s_cord-v2>"
|
||||
>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
||||
|
||||
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
>>> outputs = model.generate(
|
||||
... pixel_values.to(device),
|
||||
... decoder_input_ids=decoder_input_ids.to(device),
|
||||
... max_length=model.decoder.config.max_position_embeddings,
|
||||
... pad_token_id=processor.tokenizer.pad_token_id,
|
||||
... eos_token_id=processor.tokenizer.eos_token_id,
|
||||
... use_cache=True,
|
||||
... bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
... return_dict_in_generate=True,
|
||||
... )
|
||||
|
||||
>>> sequence = processor.batch_decode(outputs.sequences)[0]
|
||||
>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
||||
>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
||||
>>> print(processor.token2json(sequence))
|
||||
{'menu': {'nm': 'CINNAMON SUGAR', 'unitprice': '17,000', 'cnt': '1 x', 'price': '17,000'}, 'sub_total': {'subtotal_price': '17,000'}, 'total':
|
||||
{'total_price': '17,000', 'cashprice': '20,000', 'changeprice': '3,000'}}
|
||||
```
|
||||
|
||||
## DonutSwinConfig
|
||||
|
||||
@ -215,3 +226,8 @@ We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers-
|
||||
|
||||
[[autodoc]] DonutSwinModel
|
||||
- forward
|
||||
|
||||
## DonutSwinForImageClassification
|
||||
|
||||
[[autodoc]] transformers.DonutSwinForImageClassification
|
||||
- forward
|
||||
@ -14,66 +14,95 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# ELECTRA
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# ELECTRA
|
||||
|
||||
The ELECTRA model was proposed in the paper [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than
|
||||
Generators](https://openreview.net/pdf?id=r1xMH1BtvB). ELECTRA is a new pretraining approach which trains two
|
||||
transformer models: the generator and the discriminator. The generator's role is to replace tokens in a sequence, and
|
||||
is therefore trained as a masked language model. The discriminator, which is the model we're interested in, tries to
|
||||
identify which tokens were replaced by the generator in the sequence.
|
||||
[ELECTRA](https://huggingface.co/papers/2003.10555) modifies the pretraining objective of traditional masked language models like BERT. Instead of just masking tokens and asking the model to predict them, ELECTRA trains two models, a generator and a discriminator. The generator replaces some tokens with plausible alternatives and the discriminator (the model you'll actually use) learns to detect which tokens are original and which were replaced. This training approach is very efficient and scales to larger models while using considerably less compute.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
This approach is super efficient because ELECTRA learns from every single token in the input, not just the masked ones. That's why even the small ELECTRA models can match or outperform much larger models while using way less computing resources.
|
||||
|
||||
*Masked language modeling (MLM) pretraining methods such as BERT corrupt the input by replacing some tokens with [MASK]
|
||||
and then train a model to reconstruct the original tokens. While they produce good results when transferred to
|
||||
downstream NLP tasks, they generally require large amounts of compute to be effective. As an alternative, we propose a
|
||||
more sample-efficient pretraining task called replaced token detection. Instead of masking the input, our approach
|
||||
corrupts it by replacing some tokens with plausible alternatives sampled from a small generator network. Then, instead
|
||||
of training a model that predicts the original identities of the corrupted tokens, we train a discriminative model that
|
||||
predicts whether each token in the corrupted input was replaced by a generator sample or not. Thorough experiments
|
||||
demonstrate this new pretraining task is more efficient than MLM because the task is defined over all input tokens
|
||||
rather than just the small subset that was masked out. As a result, the contextual representations learned by our
|
||||
approach substantially outperform the ones learned by BERT given the same model size, data, and compute. The gains are
|
||||
particularly strong for small models; for example, we train a model on one GPU for 4 days that outperforms GPT (trained
|
||||
using 30x more compute) on the GLUE natural language understanding benchmark. Our approach also works well at scale,
|
||||
where it performs comparably to RoBERTa and XLNet while using less than 1/4 of their compute and outperforms them when
|
||||
using the same amount of compute.*
|
||||
You can find all the original ELECTRA checkpoints under the [ELECTRA](https://huggingface.co/collections/google/electra-release-64ff6e8b18830fabea30a1ab) release.
|
||||
|
||||
This model was contributed by [lysandre](https://huggingface.co/lysandre). The original code can be found [here](https://github.com/google-research/electra).
|
||||
> [!TIP]
|
||||
> Click on the right sidebar for more examples of how to use ELECTRA for different language tasks like sequence classification, token classification, and question answering.
|
||||
|
||||
## Usage tips
|
||||
The example below demonstrates how to classify text with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
- ELECTRA is the pretraining approach, therefore there is nearly no changes done to the underlying model: BERT. The
|
||||
only change is the separation of the embedding size and the hidden size: the embedding size is generally smaller,
|
||||
while the hidden size is larger. An additional projection layer (linear) is used to project the embeddings from their
|
||||
embedding size to the hidden size. In the case where the embedding size is the same as the hidden size, no projection
|
||||
layer is used.
|
||||
- ELECTRA is a transformer model pretrained with the use of another (small) masked language model. The inputs are corrupted by that language model, which takes an input text that is randomly masked and outputs a text in which ELECTRA has to predict which token is an original and which one has been replaced. Like for GAN training, the small language model is trained for a few steps (but with the original texts as objective, not to fool the ELECTRA model like in a traditional GAN setting) then the ELECTRA model is trained for a few steps.
|
||||
- The ELECTRA checkpoints saved using [Google Research's implementation](https://github.com/google-research/electra)
|
||||
contain both the generator and discriminator. The conversion script requires the user to name which model to export
|
||||
into the correct architecture. Once converted to the HuggingFace format, these checkpoints may be loaded into all
|
||||
available ELECTRA models, however. This means that the discriminator may be loaded in the
|
||||
[`ElectraForMaskedLM`] model, and the generator may be loaded in the
|
||||
[`ElectraForPreTraining`] model (the classification head will be randomly initialized as it
|
||||
doesn't exist in the generator).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Resources
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
classifier = pipeline(
|
||||
task="text-classification",
|
||||
model="bhadresh-savani/electra-base-emotion",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
classifier("This restaurant has amazing food!")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"bhadresh-savani/electra-base-emotion",
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"bhadresh-savani/electra-base-emotion",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer("ELECTRA is more efficient than BERT", return_tensors="pt")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
predicted_class_id = logits.argmax(dim=-1).item()
|
||||
predicted_label = model.config.id2label[predicted_class_id]
|
||||
print(f"Predicted label: {predicted_label}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "This restaurant has amazing food." | transformers-cli run --task text-classification --model bhadresh-savani/electra-base-emotion --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Notes
|
||||
|
||||
- ELECTRA consists of two transformer models, a generator (G) and a discriminator (D). For most downstream tasks, use the discriminator model (as indicated by `*-discriminator` in the name) rather than the generator.
|
||||
- ELECTRA comes in three sizes: small (14M parameters), base (110M parameters), and large (335M parameters).
|
||||
- ELECTRA can use a smaller embedding size than the hidden size for efficiency. When `embedding_size` is smaller than `hidden_size` in the configuration, a projection layer connects them.
|
||||
- When using batched inputs with padding, make sure to use attention masks to prevent the model from attending to padding tokens.
|
||||
|
||||
```py
|
||||
# Example of properly handling padding with attention masks
|
||||
inputs = tokenizer(["Short text", "This is a much longer text that needs padding"],
|
||||
padding=True,
|
||||
return_tensors="pt")
|
||||
outputs = model(**inputs) # automatically uses the attention_mask
|
||||
```
|
||||
|
||||
- When using the discriminator for a downstream task, you can load it into any of the ELECTRA model classes ([`ElectraForSequenceClassification`], [`ElectraForTokenClassification`], etc.).
|
||||
|
||||
## ElectraConfig
|
||||
|
||||
|
||||
@ -14,48 +14,113 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Falcon
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Falcon
|
||||
|
||||
Falcon is a class of causal decoder-only models built by [TII](https://www.tii.ae/). The largest Falcon checkpoints
|
||||
have been trained on >=1T tokens of text, with a particular emphasis on the [RefinedWeb](https://arxiv.org/abs/2306.01116)
|
||||
corpus. They are made available under the Apache 2.0 license.
|
||||
[Falcon](https://huggingface.co/papers/2311.16867) is a family of large language models, available in 7B, 40B, and 180B parameters, as pretrained and instruction tuned variants. This model focuses on scaling pretraining over three categories, performance, data, and hardware. Falcon uses multigroup attention to significantly reduce inference memory requirements and rotary positional embeddings (RoPE). These models are pretrained on [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb), a high-quality and deduplicated 5T token dataset.
|
||||
|
||||
You can find all the original Falcon checkpoints under the [Falcon](https://huggingface.co/collections/tiiuae/falcon-64fb432660017eeec9837b5a) collection.
|
||||
|
||||
Falcon's architecture is modern and optimized for inference, with multi-query attention and support for efficient
|
||||
attention variants like `FlashAttention`. Both 'base' models trained only as causal language models as well as
|
||||
'instruct' models that have received further fine-tuning are available.
|
||||
> [!TIP]
|
||||
> Click on the Falcon models in the right sidebar for more examples of how to apply Falcon to different language tasks.
|
||||
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
Falcon models are (as of 2023) some of the largest and most powerful open-source language models,
|
||||
and consistently rank highly in the [OpenLLM leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Converting custom checkpoints
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
<Tip>
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="tiiuae/falcon-7b-instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=0
|
||||
)
|
||||
pipeline(
|
||||
"Write a short poem about coding",
|
||||
max_length=100,
|
||||
do_sample=True,
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
Falcon models were initially added to the Hugging Face Hub as custom code checkpoints. However, Falcon is now fully
|
||||
supported in the Transformers library. If you fine-tuned a model from a custom code checkpoint, we recommend converting
|
||||
your checkpoint to the new in-library format, as this should give significant improvements to stability and
|
||||
performance, especially for generation, as well as removing the need to use `trust_remote_code=True`!
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
</Tip>
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
You can convert custom code checkpoints to full Transformers checkpoints using the `convert_custom_code_checkpoint.py`
|
||||
script located in the
|
||||
[Falcon model directory](https://github.com/huggingface/transformers/tree/main/src/transformers/models/falcon)
|
||||
of the Transformers library. To use this script, simply call it with
|
||||
`python convert_custom_code_checkpoint.py --checkpoint_dir my_model`. This will convert your checkpoint in-place, and
|
||||
you can immediately load it from the directory afterwards with e.g. `from_pretrained()`. If your model hasn't been
|
||||
uploaded to the Hub, we recommend making a backup before attempting the conversion, just in case!
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"tiiuae/falcon-7b-instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
|
||||
input_ids = tokenizer("Write a short poem about coding", return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
transformers-cli chat --model_name_or_path tiiuae/falcon-7b-instruct --torch_dtype auto --attn_implementation flash_attention_2 --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"tiiuae/falcon-7b",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
inputs = tokenizer("In quantum physics, entanglement means", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- If you're upgrading from an older custom code checkpoint, remember to convert it to the official Transformers format for better stability and performance using the conversion script located in the [Falcon model directory](https://github.com/huggingface/transformers/tree/main/src/transformers/models/falcon).
|
||||
|
||||
```bash
|
||||
python convert_custom_code_checkpoint.py --checkpoint_dir my_model
|
||||
```
|
||||
|
||||
## FalconConfig
|
||||
|
||||
@ -85,6 +150,4 @@ uploaded to the Hub, we recommend making a backup before attempting the conversi
|
||||
## FalconForQuestionAnswering
|
||||
|
||||
[[autodoc]] FalconForQuestionAnswering
|
||||
- forward
|
||||
|
||||
|
||||
- forward
|
||||
@ -14,95 +14,100 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# FalconMamba
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# FalconMamba
|
||||
|
||||
The FalconMamba model was proposed by TII UAE (Technology Innovation Institute) in their release.
|
||||
[FalconMamba](https://huggingface.co/papers/2410.05355) is a 7B large language model, available as pretrained and instruction-tuned variants, based on the [Mamba](./mamba). This model implements a pure Mamba design that focuses on computational efficiency while maintaining strong performance. FalconMamba is significantly faster at inference and requires substantially less memory for long sequence generation. The models are pretrained on a diverse 5.8T token dataset including [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb), technical content, code, and mathematical data.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find the official FalconMamba checkpoints in the [FalconMamba 7B](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) collection.
|
||||
|
||||
*We present FalconMamba, a new base large language model based on the novel Mamba architecture. FalconMamba is trained on 5.8 trillion tokens with carefully selected data mixtures. As a pure Mamba-based model, FalconMamba surpasses leading open-weight models based on Transformers, such as Mistral 7B, Llama3 8B, and Falcon2 11B. It is on par with Gemma 7B and outperforms models with different architecture designs, such as RecurrentGemma 9B. Currently, FalconMamba is the best-performing Mamba model in the literature at this scale, surpassing both existing Mamba and hybrid Mamba-Transformer models.
|
||||
Due to its architecture, FalconMamba is significantly faster at inference and requires substantially less memory for long sequence generation. Despite recent studies suggesting that hybrid Mamba-Transformer models outperform pure architecture designs, we argue and demonstrate that the pure Mamba design can achieve similar, even superior results compared to the hybrid design. We make the weights of our implementation of FalconMamba publicly available under a permissive license.*
|
||||
> [!TIP]
|
||||
> Click on the FalconMamba models in the right sidebar for more examples of how to apply FalconMamba to different language tasks.
|
||||
|
||||
Tips:
|
||||
The examples below demonstrate how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
- FalconMamba is mostly based on Mamba architecture, the same [tips and best practices](./mamba) would be relevant here.
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
The model has been trained on approximtely 6T tokens consisting a mixture of many data sources such as RefineWeb, Cosmopedia and Math data.
|
||||
|
||||
For more details about the training procedure and the architecture, have a look at [the technical paper of FalconMamba]() (coming soon).
|
||||
|
||||
# Usage
|
||||
|
||||
Below we demonstrate how to use the model:
|
||||
|
||||
```python
|
||||
from transformers import FalconMambaForCausalLM, AutoTokenizer
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
|
||||
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")
|
||||
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||
|
||||
out = model.generate(input_ids, max_new_tokens=10)
|
||||
print(tokenizer.batch_decode(out))
|
||||
pipeline = pipeline(
|
||||
"text-generation",
|
||||
model="tiiuae/falcon-mamba-7b-instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=0
|
||||
)
|
||||
pipeline(
|
||||
"Explain the difference between transformers and SSMs",
|
||||
max_length=100,
|
||||
do_sample=True,
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
The architecture is also compatible with `torch.compile` for faster generation:
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
from transformers import FalconMambaForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
|
||||
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", torch_dtype=torch.bfloat16).to(0)
|
||||
model = torch.compile(model)
|
||||
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||
|
||||
out = model.generate(input_ids, max_new_tokens=10)
|
||||
print(tokenizer.batch_decode(out))
|
||||
```
|
||||
|
||||
If you have access to a GPU that is compatible with `bitsandbytes`, you can also quantize the model in 4-bit precision:
|
||||
|
||||
```python
|
||||
from transformers import FalconMambaForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", quantization_config=quantization_config)
|
||||
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||
|
||||
out = model.generate(input_ids, max_new_tokens=10)
|
||||
print(tokenizer.batch_decode(out))
|
||||
```
|
||||
|
||||
You can also play with the instruction fine-tuned model:
|
||||
|
||||
```python
|
||||
from transformers import FalconMambaForCausalLM, AutoTokenizer
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b-instruct")
|
||||
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b-instruct")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"tiiuae/falcon-mamba-7b-instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
|
||||
messages = [
|
||||
{"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
|
||||
]
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).input_ids
|
||||
input_ids = tokenizer("Explain the difference between transformers and SSMs", return_tensors="pt").to("cuda")
|
||||
|
||||
outputs = model.generate(input_ids)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
output = model.generate(**input_ids, max_new_tokens=100, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
transformers-cli chat --model_name_or_path tiiuae/falcon-mamba-7b-instruct --torch_dtype auto --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer, FalconMambaForCausalLM, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
|
||||
model = FalconMambaForCausalLM.from_pretrained(
|
||||
"tiiuae/falcon-mamba-7b",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
inputs = tokenizer("Explain the concept of state space models in simple terms", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## FalconMambaConfig
|
||||
|
||||
@ -14,36 +14,133 @@ specific language governing permissions and limitations under the License.
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# Gemma2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
[Gemma 2](https://huggingface.co/papers/2408.00118) is a family of language models with pretrained and instruction-tuned variants, available in 2B, 9B, 27B parameters. The architecture is similar to the previous Gemma, except it features interleaved local attention (4096 tokens) and global attention (8192 tokens) and grouped-query attention (GQA) to increase inference performance.
|
||||
|
||||
The 2B and 9B models are trained with knowledge distillation, and the instruction-tuned variant was post-trained with supervised fine-tuning and reinforcement learning.
|
||||
|
||||
You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) collection.
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma to different language tasks.
|
||||
|
||||
The example below demonstrates how to chat with the model with [`Pipeline`] or the [`AutoModel`] class, and from the command line.
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipe = pipeline(
|
||||
task="text-generation",
|
||||
model="google/gemma-2-9b",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
pipe("Explain quantum computing simply. ", max_new_tokens=50)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2-9b",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
input_text = "Explain quantum computing simply."
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```
|
||||
echo -e "Explain quantum computing simply." | transformers-cli run --task text-generation --model google/gemma-2-2b --device 0
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2-27b",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
input_text = "Explain quantum computing simply."
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
|
||||
```python
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
visualizer = AttentionMaskVisualizer("google/gemma-2b")
|
||||
visualizer("You are an assistant. Make sure you print me")
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/gemma-2-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
## Notes
|
||||
|
||||
The Gemma2 model was proposed in [Gemma2: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/google-gemma-2/) by Gemma2 Team, Google.
|
||||
Two Gemma2 models are released, with parameters sizes of 9 billion (9B) and 27 billion (27B).
|
||||
- Use a [`HybridCache`] instance to enable caching in Gemma 2. Gemma 2 doesn't support kv-caching strategies like [`DynamicCache`] or tuples of tensors because it uses sliding window attention every second layer.
|
||||
|
||||
The abstract from the blog post is the following:
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
|
||||
|
||||
*Now we’re officially releasing Gemma 2 to researchers and developers globally. Available in both 9 billion (9B) and 27 billion (27B) parameter sizes, Gemma 2 is higher-performing and more efficient at inference than the first generation, with significant safety advancements built in. In fact, at 27B, it offers competitive alternatives to models more than twice its size, delivering the kind of performance that was only possible with proprietary models as recently as December.*
|
||||
|
||||
Tips:
|
||||
|
||||
- The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py`
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
- Gemma2 uses sliding window attention every second layer, which makes it unsuitable for typical kv caching with [`~DynamicCache`] or tuples of tensors. To enable caching in Gemma2 forward call, you must initialize a [`~HybridCache`] instance and pass it as `past_key_values` to the forward call. Note, that you also have to prepare `cache_position` if the `past_key_values` already contains previous keys and values.
|
||||
|
||||
</Tip>
|
||||
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().
|
||||
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
|
||||
|
||||
inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
|
||||
max_generated_length = inputs.input_ids.shape[1] + 10
|
||||
past_key_values = HybridCache(config=model.config, max_batch_size=1,
|
||||
max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
|
||||
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
||||
```
|
||||
|
||||
## Gemma2Config
|
||||
|
||||
|
||||
45
docs/source/en/model_doc/glm4.md
Normal file
45
docs/source/en/model_doc/glm4.md
Normal file
@ -0,0 +1,45 @@
|
||||
<!--Copyright 2025 The GLM & ZhipuAI team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Glm4
|
||||
|
||||
## Overview
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Glm4Config
|
||||
|
||||
[[autodoc]] Glm4Config
|
||||
|
||||
## Glm4Model
|
||||
|
||||
[[autodoc]] Glm4Model
|
||||
- forward
|
||||
|
||||
## Glm4ForCausalLM
|
||||
|
||||
[[autodoc]] Glm4ForCausalLM
|
||||
- forward
|
||||
|
||||
## Glm4ForSequenceClassification
|
||||
|
||||
[[autodoc]] Glm4ForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Glm4ForTokenClassification
|
||||
|
||||
[[autodoc]] Glm4ForTokenClassification
|
||||
- forward
|
||||
@ -14,197 +14,97 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# OpenAI GPT2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/models?filter=gpt2">
|
||||
<img alt="Models" src="https://img.shields.io/badge/All_model_pages-gpt2-blueviolet">
|
||||
</a>
|
||||
<a href="https://huggingface.co/spaces/docs-demos/gpt2">
|
||||
<img alt="Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue">
|
||||
</a>
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
OpenAI GPT-2 model was proposed in [Language Models are Unsupervised Multitask Learners](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) by Alec
|
||||
Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei and Ilya Sutskever from [OpenAI](https://huggingface.co/openai). It's a causal (unidirectional)
|
||||
transformer pretrained using language modeling on a very large corpus of ~40 GB of text data.
|
||||
# GPT-2
|
||||
|
||||
The abstract from the paper is the following:
|
||||
[GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) is a scaled up version of GPT, a causal transformer language model, with 10x more parameters and training data. The model was pretrained on a 40GB dataset to predict the next word in a sequence based on all the previous words. This approach enabled the model to perform many downstream tasks in a zero-shot setting.
|
||||
|
||||
*GPT-2 is a large transformer-based language model with 1.5 billion parameters, trained on a dataset[1] of 8 million
|
||||
web pages. GPT-2 is trained with a simple objective: predict the next word, given all of the previous words within some
|
||||
text. The diversity of the dataset causes this simple goal to contain naturally occurring demonstrations of many tasks
|
||||
across diverse domains. GPT-2 is a direct scale-up of GPT, with more than 10X the parameters and trained on more than
|
||||
10X the amount of data.*
|
||||
The model architecture uses a unidirectional (causal) attention mechanism where each token can only attend to previous tokens, making it particularly effective for text generation tasks.
|
||||
|
||||
[Write With Transformer](https://transformer.huggingface.co/doc/gpt2-large) is a webapp created and hosted by
|
||||
Hugging Face showcasing the generative capabilities of several models. GPT-2 is one of them and is available in five
|
||||
different sizes: small, medium, large, xl and a distilled version of the small checkpoint: *distilgpt-2*.
|
||||
You can find all the original GPT-2 checkpoints under the [OpenAI community](https://huggingface.co/openai-community?search_models=gpt) organization.
|
||||
|
||||
This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The original code can be found [here](https://openai.com/blog/better-language-models/).
|
||||
> [!TIP]
|
||||
> Click on the GPT-2 models in the right sidebar for more examples of how to apply GPT-2 to different language tasks.
|
||||
|
||||
## Usage tips
|
||||
The example below demonstrates how to generate text with [`Pipeline`] or the [`AutoModel`], and from the command line.
|
||||
|
||||
- GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
|
||||
the left.
|
||||
- GPT-2 was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next
|
||||
token in a sequence. Leveraging this feature allows GPT-2 to generate syntactically coherent text as it can be
|
||||
observed in the *run_generation.py* example script.
|
||||
- The model can take the *past_key_values* (for PyTorch) or *past* (for TF) as input, which is the previously computed
|
||||
key/value attention pairs. Using this (*past_key_values* or *past*) value prevents the model from re-computing
|
||||
pre-computed values in the context of text generation. For PyTorch, see *past_key_values* argument of the
|
||||
[`GPT2Model.forward`] method, or for TF the *past* argument of the
|
||||
[`TFGPT2Model.call`] method for more information on its usage.
|
||||
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
|
||||
improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Usage example
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
The `generate()` method can be used to generate text using GPT2 model.
|
||||
pipeline = pipeline(task="text-generation", model="openai-community/gpt2", torch_dtype=torch.float16, device=0)
|
||||
pipeline("Hello, I'm a language model")
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa")
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
>>> prompt = "GPT2 is a model developed by OpenAI."
|
||||
input_ids = tokenzier("Hello, I'm a language model". return_tensors="pt").to("cuda")
|
||||
|
||||
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
>>> gen_tokens = model.generate(
|
||||
... input_ids,
|
||||
... do_sample=True,
|
||||
... temperature=0.9,
|
||||
... max_length=100,
|
||||
... )
|
||||
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Using Flash Attention 2
|
||||
|
||||
Flash Attention 2 is a faster, optimized version of the attention scores computation which relies on `cuda` kernels.
|
||||
|
||||
### Installation
|
||||
|
||||
First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [above](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
|
||||
|
||||
Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
echo -e "Hello, I'm a language model" | transformers-cli run --task text-generation --model openai-community/gpt2 --device 0
|
||||
```
|
||||
|
||||
### Usage
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
>>> device = "cuda" # the device to load the model onto
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to 4-bits.
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
|
||||
|
||||
>>> prompt = "def hello_world():"
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype="float16",
|
||||
bnb_4bit_use_double_quant=True
|
||||
)
|
||||
|
||||
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
|
||||
>>> model.to(device)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"openai-community/gpt2-xl",
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
|
||||
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2-xl")
|
||||
inputs = tokenizer("Once upon a time, there was a magical forest", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
### Expected speedups
|
||||
|
||||
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `gpt2` checkpoint and the Flash Attention 2 version of the model using a sequence length of 512.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
|
||||
</div>
|
||||
|
||||
|
||||
## Using Scaled Dot Product Attention (SDPA)
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="sdpa")
|
||||
...
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
On a local benchmark (rtx3080ti-16GB, PyTorch 2.2.1, OS Ubuntu 22.04) using `float16` with
|
||||
[gpt2-large](https://huggingface.co/openai-community/gpt2-large), we saw the
|
||||
following speedups during training and inference.
|
||||
|
||||
### Training
|
||||
| Batch size | Seq len | Time per batch (Eager - s) | Time per batch (SDPA - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|
||||
|-----------:|--------:|----------------------------:|--------------------------:|------------:|--------------------:|-------------------:|------------------:|
|
||||
| 1 | 128 | 0.039 | 0.032 | 23.042 | 3482.32 | 3494.62 | -0.352 |
|
||||
| 1 | 256 | 0.073 | 0.059 | 25.15 | 3546.66 | 3552.6 | -0.167 |
|
||||
| 1 | 512 | 0.155 | 0.118 | 30.96 | 4230.1 | 3665.59 | 15.4 |
|
||||
| 1 | 1024 | 0.316 | 0.209 | 50.839 | 8682.26 | 4881.09 | 77.875 |
|
||||
| 2 | 128 | 0.07 | 0.06 | 15.324 | 3557.8 | 3545.91 | 0.335 |
|
||||
| 2 | 256 | 0.143 | 0.122 | 16.53 | 3901.5 | 3657.68 | 6.666 |
|
||||
| 2 | 512 | 0.267 | 0.213 | 25.626 | 7062.21 | 4876.47 | 44.822 |
|
||||
| 2 | 1024 | OOM | 0.404 | / | OOM | 8096.35 | SDPA does not OOM |
|
||||
| 4 | 128 | 0.134 | 0.128 | 4.412 | 3675.79 | 3648.72 | 0.742 |
|
||||
| 4 | 256 | 0.243 | 0.217 | 12.292 | 6129.76 | 4871.12 | 25.839 |
|
||||
| 4 | 512 | 0.494 | 0.406 | 21.687 | 12466.6 | 8102.64 | 53.858 |
|
||||
| 4 | 1024 | OOM | 0.795 | / | OOM | 14568.2 | SDPA does not OOM |
|
||||
|
||||
### Inference
|
||||
| Batch size | Seq len | Per token latency Eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem Eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|
||||
|-----------:|--------:|-----------------------------:|----------------------------:|------------:|---------------:|--------------:|--------------:|
|
||||
| 1 | 128 | 7.991 | 6.968 | 14.681 | 1685.2 | 1701.32 | -0.947 |
|
||||
| 1 | 256 | 8.462 | 7.199 | 17.536 | 1745.49 | 1770.78 | -1.428 |
|
||||
| 1 | 512 | 8.68 | 7.853 | 10.529 | 1907.69 | 1921.29 | -0.708 |
|
||||
| 1 | 768 | 9.101 | 8.365 | 8.791 | 2032.93 | 2068.12 | -1.701 |
|
||||
| 2 | 128 | 9.169 | 9.001 | 1.861 | 1803.84 | 1811.4 | -0.418 |
|
||||
| 2 | 256 | 9.907 | 9.78 | 1.294 | 1907.72 | 1921.44 | -0.714 |
|
||||
| 2 | 512 | 11.519 | 11.644 | -1.071 | 2176.86 | 2197.75 | -0.951 |
|
||||
| 2 | 768 | 13.022 | 13.407 | -2.873 | 2464.3 | 2491.06 | -1.074 |
|
||||
| 4 | 128 | 10.097 | 9.831 | 2.709 | 1942.25 | 1985.13 | -2.16 |
|
||||
| 4 | 256 | 11.599 | 11.398 | 1.764 | 2177.28 | 2197.86 | -0.937 |
|
||||
| 4 | 512 | 14.653 | 14.45 | 1.411 | 2753.16 | 2772.57 | -0.7 |
|
||||
| 4 | 768 | 17.846 | 17.617 | 1.299 | 3327.04 | 3343.97 | -0.506 |
|
||||
|
||||
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
|
||||
- A blog on how to [Finetune a non-English GPT-2 Model with Hugging Face](https://www.philschmid.de/fine-tune-a-non-english-gpt-2-model-with-huggingface).
|
||||
- A blog on [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate) with GPT-2.
|
||||
- A blog on [Training CodeParrot 🦜 from Scratch](https://huggingface.co/blog/codeparrot), a large GPT-2 model.
|
||||
- A blog on [Faster Text Generation with TensorFlow and XLA](https://huggingface.co/blog/tf-xla-generate) with GPT-2.
|
||||
- A blog on [How to train a Language Model with Megatron-LM](https://huggingface.co/blog/megatron-training) with a GPT-2 model.
|
||||
- A notebook on how to [finetune GPT2 to generate lyrics in the style of your favorite artist](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb). 🌎
|
||||
- A notebook on how to [finetune GPT2 to generate tweets in the style of your favorite Twitter user](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb). 🌎
|
||||
- [Causal language modeling](https://huggingface.co/course/en/chapter7/6?fw=pt#training-a-causal-language-model-from-scratch) chapter of the 🤗 Hugging Face Course.
|
||||
- [`GPT2LMHeadModel`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling), [text generation example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-generation), and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
|
||||
- [`TFGPT2LMHeadModel`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_clmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
|
||||
- [`FlaxGPT2LMHeadModel`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#causal-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/causal_language_modeling_flax.ipynb).
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
- Pad inputs on the right because GPT-2 uses absolute position embeddings.
|
||||
- GPT-2 can reuse previously computed key-value attention pairs. Access this feature with the [past_key_values](https://huggingface.co/docs/transformers//en/model_doc/gpt2#transformers.GPT2Model.forward.past_key_values) parameter in [`GPT2Model.forward`].
|
||||
- Enable the [scale_attn_by_inverse_layer_idx](https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config.scale_attn_by_inverse_layer_idx) and [reorder_and_upcast_attn](https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config.reorder_and_upcast_attn) parameters to apply the training stability improvements from [Mistral](./mistral).
|
||||
|
||||
## GPT2Config
|
||||
|
||||
|
||||
@ -14,96 +14,126 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Jamba
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Jamba
|
||||
|
||||
Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It is the first production-scale Mamba implementation, which opens up interesting research and application opportunities. While this initial experimentation shows encouraging gains, we expect these to be further enhanced with future optimizations and explorations.
|
||||
[Jamba](https://huggingface.co/papers/2403.19887) is a hybrid Transformer-Mamba mixture-of-experts (MoE) language model ranging from 52B to 398B total parameters. This model aims to combine the advantages of both model families, the performance of transformer models and the efficiency and longer context (256K tokens) of state space models (SSMs) like Mamba.
|
||||
|
||||
For full details of this model please read the [release blog post](https://www.ai21.com/blog/announcing-jamba).
|
||||
Jamba's architecture features a blocks-and-layers approach that allows Jamba to successfully integrate Transformer and Mamba architectures altogether. Each Jamba block contains either an attention or a Mamba layer, followed by a multi-layer perceptron (MLP), producing an overall ratio of one Transformer layer out of every eight total layers. MoE layers are mixed in to increase model capacity.
|
||||
|
||||
### Model Details
|
||||
You can find all the original Jamba checkpoints under the [AI21](https://huggingface.co/ai21labs) organization.
|
||||
|
||||
Jamba is a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and an overall of 52B parameters across all experts. It supports a 256K context length, and can fit up to 140K tokens on a single 80GB GPU.
|
||||
> [!TIP]
|
||||
> Click on the Jamba models in the right sidebar for more examples of how to apply Jamba to different language tasks.
|
||||
|
||||
As depicted in the diagram below, Jamba's architecture features a blocks-and-layers approach that allows Jamba to successfully integrate Transformer and Mamba architectures altogether. Each Jamba block contains either an attention or a Mamba layer, followed by a multi-layer perceptron (MLP), producing an overall ratio of one Transformer layer out of every eight total layers.
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/jamba_architecture.png"
|
||||
alt="drawing" width="600"/>
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Usage
|
||||
```py
|
||||
# install optimized Mamba implementations
|
||||
# !pip install mamba-ssm causal-conv1d>=1.2.0
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Jamba requires you use `transformers` version 4.39.0 or higher:
|
||||
```bash
|
||||
pip install transformers>=4.39.0
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="ai21labs/AI21-Jamba-Mini-1.6",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
In order to run optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`:
|
||||
```bash
|
||||
pip install mamba-ssm causal-conv1d>=1.2.0
|
||||
```
|
||||
You also have to have the model on a CUDA device.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
You can run the model not using the optimized Mamba kernels, but it is **not** recommended as it will result in significantly lower latencies. In order to do that, you'll need to specify `use_mamba_kernels=False` when loading the model.
|
||||
|
||||
### Run the model
|
||||
```python
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
|
||||
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"ai21labs/AI21-Jamba-Large-1.6",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"ai21labs/AI21-Jamba-Large-1.6",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "Plants create energy through a process known as" | transformers-cli run --task text-generation --model ai21labs/AI21-Jamba-Mini-1.6 --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to 8-bits.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
|
||||
llm_int8_skip_modules=["mamba"])
|
||||
|
||||
# a device map to distribute the model evenly across 8 GPUs
|
||||
device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 2, 'model.layers.26': 2, 'model.layers.27': 3, 'model.layers.28': 3, 'model.layers.29': 3, 'model.layers.30': 3, 'model.layers.31': 3, 'model.layers.32': 3, 'model.layers.33': 3, 'model.layers.34': 3, 'model.layers.35': 3, 'model.layers.36': 4, 'model.layers.37': 4, 'model.layers.38': 4, 'model.layers.39': 4, 'model.layers.40': 4, 'model.layers.41': 4, 'model.layers.42': 4, 'model.layers.43': 4, 'model.layers.44': 4, 'model.layers.45': 5, 'model.layers.46': 5, 'model.layers.47': 5, 'model.layers.48': 5, 'model.layers.49': 5, 'model.layers.50': 5, 'model.layers.51': 5, 'model.layers.52': 5, 'model.layers.53': 5, 'model.layers.54': 6, 'model.layers.55': 6, 'model.layers.56': 6, 'model.layers.57': 6, 'model.layers.58': 6, 'model.layers.59': 6, 'model.layers.60': 6, 'model.layers.61': 6, 'model.layers.62': 6, 'model.layers.63': 7, 'model.layers.64': 7, 'model.layers.65': 7, 'model.layers.66': 7, 'model.layers.67': 7, 'model.layers.68': 7, 'model.layers.69': 7, 'model.layers.70': 7, 'model.layers.71': 7, 'model.final_layernorm': 7, 'lm_head': 7}
|
||||
model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-Large-1.6",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Large-1.6")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
]
|
||||
|
||||
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to(model.device)
|
||||
|
||||
outputs = model.generate(input_ids, max_new_tokens=216)
|
||||
|
||||
print(tokenizer.batch_decode(outputs))
|
||||
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
|
||||
# Decode the output
|
||||
conversation = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
# Split the conversation to get only the assistant's response
|
||||
assistant_response = conversation.split(messages[-1]['content'])[1].strip()
|
||||
print(assistant_response)
|
||||
# Output: Seek and you shall find. The path is winding, but the journey is enlightening. What wisdom do you seek from the ancient echoes?
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>Loading the model in half precision</strong></summary>
|
||||
## Notes
|
||||
|
||||
The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you need to specify `torch_dtype`:
|
||||
- Don't quantize the Mamba blocks to prevent model performance degradation.
|
||||
- It is not recommended to use Mamba without the optimized Mamba kernels as it results in significantly lower latencies. If you still want to use Mamba without the kernels, then set `use_mamba_kernels=False` in [`~AutoModel.from_pretrained`].
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16)
|
||||
# you can also use torch_dtype=torch.float16
|
||||
```
|
||||
|
||||
When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. In order to use it, you also need the model on a CUDA device. Since in this precision the model is to big to fit on a single 80GB GPU, you'll also need to parallelize it using [accelerate](https://huggingface.co/docs/accelerate/index):
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto")
|
||||
```
|
||||
|
||||
</details>
|
||||
<details><summary><strong>Load the model in 8-bit</strong></summary>
|
||||
|
||||
**Using 8-bit precision, it is possible to fit up to 140K sequence lengths on a single 80GB GPU.** You can easily quantize the model to 8-bit using [bitsandbytes](https://huggingface.co/docs/bitsandbytes/index). In order to not degrade model quality, we recommend to exclude the Mamba blocks from the quantization:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["mamba"])
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", quantization_config=quantization_config
|
||||
)
|
||||
```
|
||||
</details>
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Large",
|
||||
use_mamba_kernels=False)
|
||||
```
|
||||
|
||||
## JambaConfig
|
||||
|
||||
|
||||
442
docs/source/en/model_doc/llama4.md
Normal file
442
docs/source/en/model_doc/llama4.md
Normal file
@ -0,0 +1,442 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Llama4
|
||||
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
|
||||
This generation includes two models:
|
||||
- The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, with 128 experts.
|
||||
- The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts.
|
||||
-
|
||||
Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
|
||||
Maverick and Scout are both trained on up to 40 trillion tokens on data encompassing 200 languages
|
||||
(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).
|
||||
|
||||
For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
|
||||
on-the-fly 4-bit or 8-bitint4 quantization, while Maverick is available in BF16 and FP8 formats.
|
||||
These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.
|
||||
|
||||
You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.
|
||||
|
||||
> [!TIP]
|
||||
> The Llama 4 family of models comes in two flavors: 109B, and 402B parameters. Both of these flavors are extremely
|
||||
> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
|
||||
> model.
|
||||
>
|
||||
> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as followed:
|
||||
> `pip install transformers[hf_xet]`
|
||||
|
||||
The examples below demonstrates how to generate with [`Pipeline`] or the [`AutoModel`]. We additionally add an example
|
||||
showcasing how to toggle the right attributes to enable very long-context generations, as some flavors of Llama 4
|
||||
have context lengths going up to 10 million tokens.
|
||||
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "what is the recipe of mayonnaise?"},
|
||||
]
|
||||
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model=model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
output = pipe(messages, do_sample=False, max_new_tokens=200)
|
||||
print(output[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Text only">
|
||||
|
||||
```py
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Multimodal">
|
||||
|
||||
```py
|
||||
from transformers import AutoProcessor, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": img_url},
|
||||
{"type": "text", "text": "Describe this image in two sentences."},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Multimodal with multiple images">
|
||||
|
||||
```py
|
||||
from transformers import AutoProcessor, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
|
||||
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": url1},
|
||||
{"type": "image", "url": url2},
|
||||
{"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel - Long context">
|
||||
|
||||
Beware: the example below uses both `device_map="auto"` and flex-attention.
|
||||
Please use `torchrun` to run this example in tensor-parallel mode.
|
||||
|
||||
We will work to enable running with `device_map="auto"` and flex-attention without
|
||||
tensor-parallel in the future.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration, AutoTokenizer
|
||||
import torch
|
||||
import time
|
||||
|
||||
file = "very_long_context_prompt.txt"
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
with open(file, "r") as f:
|
||||
very_long_text = "\n".join(f.readlines())
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
attn_implementation="flex_attention",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
|
||||
]
|
||||
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out = model.generate(
|
||||
input_ids.to(model.device),
|
||||
prefill_chunk_size=2048*8,
|
||||
max_new_tokens=300,
|
||||
cache_implementation="hybrid",
|
||||
)
|
||||
print(time.time()-start)
|
||||
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
|
||||
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Efficiency; how to get the best out of llama 4
|
||||
|
||||
### The Attention methods
|
||||
|
||||
Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.
|
||||
|
||||
As of release, the Llama 4 model supports the following attention methods: `eager`, `flex_attention`, `sdpa`. We recommend using `flex_attention` for best results.
|
||||
Switching attention mechanism is done at the model initialization step:
|
||||
|
||||
|
||||
<hfoptions id="Attention">
|
||||
<hfoption id="Flex Attention">
|
||||
|
||||
Setting Flex Attention ensures the best results with the very long context the model can handle.
|
||||
|
||||
> [!TIP] Beware: the example below uses both `device_map="auto"` and flex-attention.
|
||||
> Please use `torchrun` to run this example in tensor-parallel mode.
|
||||
>
|
||||
> We will work to enable running with `device_map="auto"` and flex-attention without
|
||||
> tensor-parallel in the future.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="flex_attention",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="SDPA">
|
||||
The `sdpa` attention method is generally more compute-efficient than the `eager` method.
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="sdpa",
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Eager">
|
||||
The `eager` attention method is set by default, so no need for anything different when loading the model:
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Quantization
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
|
||||
At time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be supported in the days that follow the release.
|
||||
|
||||
See below for examples using both:
|
||||
|
||||
|
||||
|
||||
Here is an example loading an BF16 model in FP8 using the FBGEMM approach:
|
||||
|
||||
<hfoptions id="Quantization">
|
||||
<hfoption id="FBGEMM">
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=FbgemmFp8Config()
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="LLM-Compressor">
|
||||
|
||||
To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Who are you?"},
|
||||
]
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
tp_plan="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
|
||||
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
|
||||
print(outputs[0])
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Offloading
|
||||
|
||||
Enabling CPU-offloading means that components of the model might be moved to CPU instead of GPU in case the GPU-memory available isn't sufficient to load the entire model.
|
||||
At inference, different components will be loaded/unloaded from/to the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as the CPU-memory is sufficient.
|
||||
However, this also slows down inference as it adds communication overhead.
|
||||
|
||||
In order to enable CPU-offloading, you simply need to specify the `device_map` to `auto` at model load:
|
||||
|
||||
```py
|
||||
from transformers import Llama4ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
model = Llama4ForConditionalGeneration.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
```
|
||||
|
||||
## Llama4Config
|
||||
|
||||
[[autodoc]] Llama4Config
|
||||
|
||||
## Llama4TextConfig
|
||||
|
||||
[[autodoc]] Llama4TextConfig
|
||||
|
||||
## Llama4VisionConfig
|
||||
|
||||
[[autodoc]] Llama4VisionConfig
|
||||
|
||||
## Llama4Processor
|
||||
|
||||
[[autodoc]] Llama4Processor
|
||||
|
||||
## Llama4ImageProcessorFast
|
||||
|
||||
[[autodoc]] Llama4ImageProcessorFast
|
||||
|
||||
## Llama4ForConditionalGeneration
|
||||
|
||||
[[autodoc]] Llama4ForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## Llama4ForCausalLM
|
||||
|
||||
[[autodoc]] Llama4ForCausalLM
|
||||
- forward
|
||||
|
||||
## Llama4TextModel
|
||||
|
||||
[[autodoc]] Llama4TextModel
|
||||
- forward
|
||||
|
||||
## Llama4ForCausalLM
|
||||
|
||||
[[autodoc]] Llama4ForCausalLM
|
||||
- forward
|
||||
|
||||
## Llama4VisionModel
|
||||
|
||||
[[autodoc]] Llama4VisionModel
|
||||
- forward
|
||||
@ -14,74 +14,55 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Mistral
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Mistral
|
||||
|
||||
Mistral was introduced in the [this blogpost](https://mistral.ai/news/announcing-mistral-7b/) by Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
|
||||
[Mistral](https://huggingface.co/papers/2310.06825) is a 7B parameter language model, available as a pretrained and instruction-tuned variant, focused on balancing
|
||||
the scaling costs of large models with performance and efficient inference. This model uses sliding window attention (SWA) trained with a 8K context length and a fixed cache size to handle longer sequences more effectively. Grouped-query attention (GQA) speeds up inference and reduces memory requirements. Mistral also features a byte-fallback BPE tokenizer to improve token handling and efficiency by ensuring characters are never mapped to out-of-vocabulary tokens.
|
||||
|
||||
The introduction of the blog post says:
|
||||
You can find all the original Mistral checkpoints under the [Mistral AI_](https://huggingface.co/mistralai) organization.
|
||||
|
||||
*Mistral AI team is proud to release Mistral 7B, the most powerful language model for its size to date.*
|
||||
> [!TIP]
|
||||
> Click on the Mistral models in the right sidebar for more examples of how to apply Mistral to different language tasks.
|
||||
|
||||
Mistral-7B is the first large language model (LLM) released by [mistral.ai](https://mistral.ai/).
|
||||
The example below demonstrates how to chat with [`Pipeline`] or the [`AutoModel`], and from the command line.
|
||||
|
||||
### Architectural details
|
||||
|
||||
Mistral-7B is a decoder-only Transformer with the following architectural choices:
|
||||
|
||||
- Sliding Window Attention - Trained with 8k context length and fixed cache size, with a theoretical attention span of 128K tokens
|
||||
- GQA (Grouped Query Attention) - allowing faster inference and lower cache size.
|
||||
- Byte-fallback BPE tokenizer - ensures that characters are never mapped to out of vocabulary tokens.
|
||||
|
||||
For more details refer to the [release blog post](https://mistral.ai/news/announcing-mistral-7b/).
|
||||
|
||||
### License
|
||||
|
||||
`Mistral-7B` is released under the Apache 2.0 license.
|
||||
|
||||
## Usage tips
|
||||
|
||||
The Mistral team has released 3 checkpoints:
|
||||
|
||||
- a base model, [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1), which has been pre-trained to predict the next token on internet-scale data.
|
||||
- an instruction tuned model, [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1), which is the base model optimized for chat purposes using supervised fine-tuning (SFT) and direct preference optimization (DPO).
|
||||
- an improved instruction tuned model, [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2), which improves upon v1.
|
||||
|
||||
The base model can be used as follows:
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
>>> import torch
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
>>> messages = [
|
||||
... {"role": "user", "content": "What is your favourite condiment?"},
|
||||
... {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
|
||||
... {"role": "user", "content": "Do you have mayonnaise recipes?"}
|
||||
... ]
|
||||
|
||||
>>> prompt = "My favourite condiment is"
|
||||
|
||||
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
|
||||
>>> model.to(device)
|
||||
|
||||
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
|
||||
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||
"My favourite condiment is to ..."
|
||||
>>> chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", torch_dtype=torch.bfloat16, device=0)
|
||||
>>> chatbot(messages)
|
||||
```
|
||||
|
||||
The instruction tuned model can be used as follows:
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", device_map="auto")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", torch_dtype=torch.bfloat16, attn_implementation="sdpa", device_map="auto")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
|
||||
>>> messages = [
|
||||
... {"role": "user", "content": "What is your favourite condiment?"},
|
||||
@ -96,59 +77,20 @@ The instruction tuned model can be used as follows:
|
||||
"Mayonnaise can be made as follows: (...)"
|
||||
```
|
||||
|
||||
As can be seen, the instruction-tuned model requires a [chat template](../chat_templating) to be applied to make sure the inputs are prepared in the right format.
|
||||
|
||||
## Speeding up Mistral by using Flash Attention
|
||||
|
||||
The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). Make also sure to load your model in half-precision (e.g. `torch.float16`)
|
||||
|
||||
To load and run a model using Flash Attention-2, refer to the snippet below:
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
|
||||
>>> prompt = "My favourite condiment is"
|
||||
|
||||
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
|
||||
>>> model.to(device)
|
||||
|
||||
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
|
||||
>>> tokenizer.batch_decode(generated_ids)[0]
|
||||
"My favourite condiment is to (...)"
|
||||
echo -e "My favorite condiment is" | transformers-cli chat --model_name_or_path mistralai/Mistral-7B-v0.3 --torch_dtype auto --device 0 --attn_implementation flash_attention_2
|
||||
```
|
||||
|
||||
### Expected speedups
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using `mistralai/Mistral-7B-v0.1` checkpoint and the Flash Attention 2 version of the model.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/mistral-7b-inference-large-seqlen.png">
|
||||
</div>
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
### Sliding window Attention
|
||||
|
||||
The current implementation supports the sliding window attention mechanism and memory efficient cache management.
|
||||
To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).
|
||||
|
||||
The Flash Attention-2 model uses also a more memory efficient cache slicing mechanism - as recommended per the official implementation of Mistral model that use rolling cache mechanism we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"` and use the absolute position of the current token to compute the positional embedding.
|
||||
|
||||
## Shrinking down Mistral using quantization
|
||||
|
||||
As the Mistral model has 7 billion parameters, that would require about 14GB of GPU RAM in half precision (float16), since each parameter is stored in 2 bytes. However, one can shrink down the size of the model using [quantization](../quantization.md). If the model is quantized to 4 bits (or half a byte per parameter),that requires only about 3.5GB of RAM.
|
||||
|
||||
Quantizing a model is as simple as passing a `quantization_config` to the model. Below, we'll leverage the BitsAndyBytes quantization (but refer to [this page](../quantization.md) for other quantization methods):
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
@ -161,8 +103,8 @@ Quantizing a model is as simple as passing a `quantization_config` to the model.
|
||||
... bnb_4bit_compute_dtype="torch.float16",
|
||||
... )
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", quantization_config=True, device_map="auto")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", quantization_config=True, torch_dtype=torch.bfloat16, device_map="auto")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
|
||||
>>> prompt = "My favourite condiment is"
|
||||
|
||||
@ -179,19 +121,18 @@ Quantizing a model is as simple as passing a `quantization_config` to the model.
|
||||
"The expected output"
|
||||
```
|
||||
|
||||
This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada) and [Arthur Zucker](https://huggingface.co/ArthurZ) .
|
||||
The original code can be found [here](https://github.com/mistralai/mistral-src).
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
## Resources
|
||||
```py
|
||||
>>> from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Mistral. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
>>> visualizer = AttentionMaskVisualizer("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
>>> visualizer("Do you have mayonnaise recipes?")
|
||||
```
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
|
||||
- A demo notebook to perform supervised fine-tuning (SFT) of Mistral-7B can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Mistral/Supervised_fine_tuning_(SFT)_of_an_LLM_using_Hugging_Face_tooling.ipynb). 🌎
|
||||
- A [blog post](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl) on how to fine-tune LLMs in 2024 using Hugging Face tooling. 🌎
|
||||
- The [Alignment Handbook](https://github.com/huggingface/alignment-handbook) by Hugging Face includes scripts and recipes to perform supervised fine-tuning (SFT) and direct preference optimization with Mistral-7B. This includes scripts for full fine-tuning, QLoRa on a single GPU as well as multi-GPU fine-tuning.
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/mistral-attn-mask.png"/>
|
||||
</div>
|
||||
|
||||
## MistralConfig
|
||||
|
||||
@ -245,4 +186,4 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
## TFMistralForSequenceClassification
|
||||
|
||||
[[autodoc]] TFMistralForSequenceClassification
|
||||
- call
|
||||
- call
|
||||
|
||||
@ -14,52 +14,81 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# MobileBERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# MobileBERT
|
||||
|
||||
The MobileBERT model was proposed in [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) by Zhiqing Sun, Hongkun Yu, Xiaodan Song, Renjie Liu, Yiming Yang, and Denny
|
||||
Zhou. It's a bidirectional transformer based on the BERT model, which is compressed and accelerated using several
|
||||
approaches.
|
||||
[MobileBERT](https://huggingface.co/papers/2004.02984) is a lightweight and efficient variant of BERT, specifically designed for resource-limited devices such as mobile phones. It retains BERT's architecture but significantly reduces model size and inference latency while maintaining strong performance on NLP tasks. MobileBERT achieves this through a bottleneck structure and carefully balanced self-attention and feedforward networks. The model is trained by knowledge transfer from a large BERT model with an inverted bottleneck structure.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find the original MobileBERT checkpoint under the [Google](https://huggingface.co/google/mobilebert-uncased) organization.
|
||||
> [!TIP]
|
||||
> Click on the MobileBERT models in the right sidebar for more examples of how to apply MobileBERT to different language tasks.
|
||||
|
||||
*Natural Language Processing (NLP) has recently achieved great success by using huge pre-trained models with hundreds
|
||||
of millions of parameters. However, these models suffer from heavy model sizes and high latency such that they cannot
|
||||
be deployed to resource-limited mobile devices. In this paper, we propose MobileBERT for compressing and accelerating
|
||||
the popular BERT model. Like the original BERT, MobileBERT is task-agnostic, that is, it can be generically applied to
|
||||
various downstream NLP tasks via simple fine-tuning. Basically, MobileBERT is a thin version of BERT_LARGE, while
|
||||
equipped with bottleneck structures and a carefully designed balance between self-attentions and feed-forward networks.
|
||||
To train MobileBERT, we first train a specially designed teacher model, an inverted-bottleneck incorporated BERT_LARGE
|
||||
model. Then, we conduct knowledge transfer from this teacher to MobileBERT. Empirical studies show that MobileBERT is
|
||||
4.3x smaller and 5.5x faster than BERT_BASE while achieving competitive results on well-known benchmarks. On the
|
||||
natural language inference tasks of GLUE, MobileBERT achieves a GLUEscore o 77.7 (0.6 lower than BERT_BASE), and 62 ms
|
||||
latency on a Pixel 4 phone. On the SQuAD v1.1/v2.0 question answering task, MobileBERT achieves a dev F1 score of
|
||||
90.0/79.2 (1.5/2.1 higher than BERT_BASE).*
|
||||
The example below demonstrates how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
This model was contributed by [vshampor](https://huggingface.co/vshampor). The original code can be found [here](https://github.com/google-research/google-research/tree/master/mobilebert).
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
## Usage tips
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
- MobileBERT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather
|
||||
than the left.
|
||||
- MobileBERT is similar to BERT and therefore relies on the masked language modeling (MLM) objective. It is therefore
|
||||
efficient at predicting masked tokens and at NLU in general, but is not optimal for text generation. Models trained
|
||||
with a causal language modeling (CLM) objective are better in that regard.
|
||||
pipeline = pipeline(
|
||||
task="fill-mask",
|
||||
model="google/mobilebert-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("The capital of France is [MASK].")
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/mobilebert-uncased",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"google/mobilebert-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits
|
||||
|
||||
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
|
||||
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
|
||||
predicted_token = tokenizer.decode(predicted_token_id)
|
||||
|
||||
print(f"The predicted token is: {predicted_token}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "The capital of France is [MASK]." | transformers-cli run --task fill-mask --model google/mobilebert-uncased --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
## Resources
|
||||
## Notes
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
- Inputs should be padded on the right because BERT uses absolute position embeddings.
|
||||
|
||||
## MobileBertConfig
|
||||
|
||||
|
||||
@ -14,55 +14,79 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# ModernBERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# ModernBERT
|
||||
|
||||
The ModernBERT model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli.
|
||||
[ModernBERT](https://huggingface.co/papers/2412.13663) is a modernized version of [`BERT`] trained on 2T tokens. It brings many improvements to the original architecture such as rotary positional embeddings to support sequences of up to 8192 tokens, unpadding to avoid wasting compute on padding tokens, GeGLU layers, and alternating attention.
|
||||
|
||||
It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
|
||||
You can find all the original ModernBERT checkpoints under the [ModernBERT](https://huggingface.co/collections/answerdotai/modernbert-67627ad707a4acbf33c41deb) collection.
|
||||
|
||||
It builds on BERT and implements many modern architectural improvements which have been developed since its original release, such as:
|
||||
- [Rotary Positional Embeddings](https://huggingface.co/blog/designing-positional-encoding) to support sequences of up to 8192 tokens.
|
||||
- [Unpadding](https://arxiv.org/abs/2208.08124) to ensure no compute is wasted on padding tokens, speeding up processing time for batches with mixed-length sequences.
|
||||
- [GeGLU](https://arxiv.org/abs/2002.05202) Replacing the original MLP layers with GeGLU layers, shown to improve performance.
|
||||
- [Alternating Attention](https://arxiv.org/abs/2004.05150v2) where most attention layers employ a sliding window of 128 tokens, with Global Attention only used every 3 layers.
|
||||
- [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up processing.
|
||||
- A model designed following recent [The Case for Co-Designing Model Architectures with Hardware](https://arxiv.org/abs/2401.14489), ensuring maximum efficiency across inference GPUs.
|
||||
- Modern training data scales (2 trillion tokens) and mixtures (including code ande math data)
|
||||
> [!TIP]
|
||||
> Click on the ModernBERT models in the right sidebar for more examples of how to apply ModernBERT to different language tasks.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
The example below demonstrates how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
*Encoder-only transformer models such as BERT offer a great performance-size tradeoff for retrieval and classification tasks with respect to larger decoder-only models. Despite being the workhorse of numerous production pipelines, there have been limited Pareto improvements to BERT since its release. In this paper, we introduce ModernBERT, bringing modern model optimizations to encoder-only models and representing a major Pareto improvement over older encoders. Trained on 2 trillion tokens with a native 8192 sequence length, ModernBERT models exhibit state-of-the-art results on a large pool of evaluations encompassing diverse classification tasks and both single and multi-vector retrieval on different domains (including code). In addition to strong downstream performance, ModernBERT is also the most speed and memory efficient encoder and is designed for inference on common GPUs.*
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
The original code can be found [here](https://github.com/answerdotai/modernbert).
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
## Resources
|
||||
pipeline = pipeline(
|
||||
task="fill-mask",
|
||||
model="answerdotai/ModernBERT-base",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create [MASK] through a process known as photosynthesis.")
|
||||
```
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ModernBert.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
- A notebook on how to [finetune for General Language Understanding Evaluation (GLUE) with Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb), also available as a Google Colab [notebook](https://colab.research.google.com/github/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb). 🌎
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"answerdotai/ModernBERT-base",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"answerdotai/ModernBERT-base",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to("cuda")
|
||||
|
||||
<PipelineTag pipeline="sentence-similarity"/>
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits
|
||||
|
||||
- A script on how to [finetune for text similarity or information retrieval with Sentence Transformers](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_st.py). 🌎
|
||||
- A script on how to [finetune for information retrieval with PyLate](https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/train_pylate.py). 🌎
|
||||
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
|
||||
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
|
||||
predicted_token = tokenizer.decode(predicted_token_id)
|
||||
|
||||
<PipelineTag pipeline="fill-mask"/>
|
||||
print(f"The predicted token is: {predicted_token}")
|
||||
```
|
||||
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
<PipelineTag pipeline="question-answering"/>
|
||||
```bash
|
||||
echo -e "Plants create [MASK] through a process known as photosynthesis." | transformers-cli run --task fill-mask --model answerdotai/ModernBERT-base --device 0
|
||||
```
|
||||
|
||||
- [`ModernBertForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [colab notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb).
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## ModernBertConfig
|
||||
|
||||
|
||||
@ -14,154 +14,123 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# OpenAI GPT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,...">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
OpenAI GPT model was proposed in [Improving Language Understanding by Generative Pre-Training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
|
||||
by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. It's a causal (unidirectional) transformer
|
||||
pre-trained using language modeling on a large corpus with long range dependencies, the Toronto Book Corpus.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Natural language understanding comprises a wide range of diverse tasks such as textual entailment, question answering,
|
||||
semantic similarity assessment, and document classification. Although large unlabeled text corpora are abundant,
|
||||
labeled data for learning these specific tasks is scarce, making it challenging for discriminatively trained models to
|
||||
perform adequately. We demonstrate that large gains on these tasks can be realized by generative pretraining of a
|
||||
language model on a diverse corpus of unlabeled text, followed by discriminative fine-tuning on each specific task. In
|
||||
contrast to previous approaches, we make use of task-aware input transformations during fine-tuning to achieve
|
||||
effective transfer while requiring minimal changes to the model architecture. We demonstrate the effectiveness of our
|
||||
approach on a wide range of benchmarks for natural language understanding. Our general task-agnostic model outperforms
|
||||
discriminatively trained models that use architectures specifically crafted for each task, significantly improving upon
|
||||
the state of the art in 9 out of the 12 tasks studied.*
|
||||
|
||||
[Write With Transformer](https://transformer.huggingface.co/doc/gpt) is a webapp created and hosted by Hugging Face
|
||||
showcasing the generative capabilities of several models. GPT is one of them.
|
||||
|
||||
This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The original code can be found [here](https://github.com/openai/finetune-transformer-lm).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
|
||||
the left.
|
||||
- GPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next
|
||||
token in a sequence. Leveraging this feature allows GPT-2 to generate syntactically coherent text as it can be
|
||||
observed in the *run_generation.py* example script.
|
||||
|
||||
|
||||
Note:
|
||||
# GPT
|
||||
|
||||
If you want to reproduce the original tokenization process of the *OpenAI GPT* paper, you will need to install `ftfy`
|
||||
and `SpaCy`:
|
||||
[GPT (Generative Pre-trained Transformer)](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf) focuses on effectively learning text representations and transferring them to tasks. This model trains the Transformer decoder to predict the next word, and then fine-tuned on labeled data.
|
||||
|
||||
```bash
|
||||
pip install spacy ftfy==4.4.3
|
||||
python -m spacy download en
|
||||
GPT can generate high-quality text, making it well-suited for a variety of natural language understanding tasks such as textual entailment, question answering, semantic similarity, and document classification.
|
||||
|
||||
You can find all the original GPT checkpoints under the [OpenAI community](https://huggingface.co/openai-community/openai-gpt) organization.
|
||||
|
||||
> [!TIP]
|
||||
> Click on the GPT models in the right sidebar for more examples of how to apply GPT to different language tasks.
|
||||
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline(task="text-generation", model="openai-community/gpt", torch_dtype=torch.float16, device=0)
|
||||
output = generator("The future of AI is", max_length=50, do_sample=True)
|
||||
print(output[0]["generated_text"])
|
||||
```
|
||||
|
||||
If you don't install `ftfy` and `SpaCy`, the [`OpenAIGPTTokenizer`] will default to tokenize
|
||||
using BERT's `BasicTokenizer` followed by Byte-Pair Encoding (which should be fine for most usage, don't worry).
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
## Resources
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with OpenAI GPT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt")
|
||||
model = AutoModelForCausalLM.from_pretrained("openai-community/openai-gpt", torch_dtype=torch.float16)
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
inputs = tokenizer("The future of AI is", return_tensors="pt")
|
||||
outputs = model.generate(**inputs, max_length=50)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
- A blog post on [outperforming OpenAI GPT-3 with SetFit for text-classification](https://www.philschmid.de/getting-started-setfit).
|
||||
- See also: [Text classification task guide](../tasks/sequence_classification)
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
```bash
|
||||
echo -e "The future of AI is" | transformers-cli run --task text-generation --model openai-community/openai-gpt --device 0
|
||||
|
||||
- A blog on how to [Finetune a non-English GPT-2 Model with Hugging Face](https://www.philschmid.de/fine-tune-a-non-english-gpt-2-model-with-huggingface).
|
||||
- A blog on [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate) with GPT-2.
|
||||
- A blog on [Training CodeParrot 🦜 from Scratch](https://huggingface.co/blog/codeparrot), a large GPT-2 model.
|
||||
- A blog on [Faster Text Generation with TensorFlow and XLA](https://huggingface.co/blog/tf-xla-generate) with GPT-2.
|
||||
- A blog on [How to train a Language Model with Megatron-LM](https://huggingface.co/blog/megatron-training) with a GPT-2 model.
|
||||
- A notebook on how to [finetune GPT2 to generate lyrics in the style of your favorite artist](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb). 🌎
|
||||
- A notebook on how to [finetune GPT2 to generate tweets in the style of your favorite Twitter user](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb). 🌎
|
||||
- [Causal language modeling](https://huggingface.co/course/en/chapter7/6?fw=pt#training-a-causal-language-model-from-scratch) chapter of the 🤗 Hugging Face Course.
|
||||
- [`OpenAIGPTLMHeadModel`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling), [text generation example script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
|
||||
- [`TFOpenAIGPTLMHeadModel`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_clmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
|
||||
- See also: [Causal language modeling task guide](../tasks/language_modeling)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
## Notes
|
||||
|
||||
- A course material on [Byte-Pair Encoding tokenization](https://huggingface.co/course/en/chapter6/5).
|
||||
- Inputs should be padded on the right because GPT uses absolute position embeddings.
|
||||
|
||||
## OpenAIGPTConfig
|
||||
|
||||
[[autodoc]] OpenAIGPTConfig
|
||||
|
||||
## OpenAIGPTModel
|
||||
|
||||
[[autodoc]] OpenAIGPTModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTLMHeadModel
|
||||
|
||||
[[autodoc]] OpenAIGPTLMHeadModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTDoubleHeadsModel
|
||||
|
||||
[[autodoc]] OpenAIGPTDoubleHeadsModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTForSequenceClassification
|
||||
|
||||
[[autodoc]] OpenAIGPTForSequenceClassification
|
||||
- forward
|
||||
|
||||
## OpenAIGPTTokenizer
|
||||
|
||||
[[autodoc]] OpenAIGPTTokenizer
|
||||
- save_vocabulary
|
||||
|
||||
## OpenAIGPTTokenizerFast
|
||||
|
||||
[[autodoc]] OpenAIGPTTokenizerFast
|
||||
|
||||
## OpenAI specific outputs
|
||||
|
||||
[[autodoc]] models.openai.modeling_openai.OpenAIGPTDoubleHeadsModelOutput
|
||||
|
||||
[[autodoc]] models.openai.modeling_tf_openai.TFOpenAIGPTDoubleHeadsModelOutput
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
## OpenAIGPTModel
|
||||
|
||||
[[autodoc]] OpenAIGPTModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTLMHeadModel
|
||||
|
||||
[[autodoc]] OpenAIGPTLMHeadModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTDoubleHeadsModel
|
||||
|
||||
[[autodoc]] OpenAIGPTDoubleHeadsModel
|
||||
- forward
|
||||
|
||||
## OpenAIGPTForSequenceClassification
|
||||
|
||||
[[autodoc]] OpenAIGPTForSequenceClassification
|
||||
- forward
|
||||
|
||||
</pt>
|
||||
<tf>
|
||||
|
||||
## TFOpenAIGPTModel
|
||||
|
||||
[[autodoc]] TFOpenAIGPTModel
|
||||
- call
|
||||
- call
|
||||
|
||||
## TFOpenAIGPTLMHeadModel
|
||||
|
||||
[[autodoc]] TFOpenAIGPTLMHeadModel
|
||||
- call
|
||||
- call
|
||||
|
||||
## TFOpenAIGPTDoubleHeadsModel
|
||||
|
||||
[[autodoc]] TFOpenAIGPTDoubleHeadsModel
|
||||
- call
|
||||
- call
|
||||
|
||||
## TFOpenAIGPTForSequenceClassification
|
||||
|
||||
[[autodoc]] TFOpenAIGPTForSequenceClassification
|
||||
- call
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
- call
|
||||
|
||||
@ -30,14 +30,8 @@ found [here](https://github.com/huggingface/transformers/blob/main/src/transform
|
||||
In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio).
|
||||
|
||||
```python
|
||||
import requests
|
||||
import torch
|
||||
import os
|
||||
import io
|
||||
from PIL import Image
|
||||
import soundfile as sf
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
|
||||
from urllib.request import urlopen
|
||||
|
||||
|
||||
# Define model path
|
||||
@ -52,21 +46,25 @@ model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, tor
|
||||
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
|
||||
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})
|
||||
|
||||
# Define prompt structure
|
||||
user_prompt = '<|user|>'
|
||||
assistant_prompt = '<|assistant|>'
|
||||
prompt_suffix = '<|end|>'
|
||||
# Part : Image Processing
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Part 1: Image Processing
|
||||
model.set_adapter("vision") # if loaded, activate the vision adapter
|
||||
print("\n--- IMAGE PROCESSING ---")
|
||||
image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
|
||||
prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}'
|
||||
print(f'>>> Prompt\n{prompt}')
|
||||
|
||||
# Download and open image
|
||||
image = Image.open(requests.get(image_url, stream=True).raw)
|
||||
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(device, torch.float16)
|
||||
|
||||
# Generate response
|
||||
generate_ids = model.generate(
|
||||
@ -80,19 +78,28 @@ response = processor.batch_decode(
|
||||
)[0]
|
||||
print(f'>>> Response\n{response}')
|
||||
|
||||
|
||||
# Part 2: Audio Processing
|
||||
model.set_adapter("speech") # if loaded, activate the speech adapter
|
||||
print("\n--- AUDIO PROCESSING ---")
|
||||
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
|
||||
speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation."
|
||||
prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
|
||||
print(f'>>> Prompt\n{prompt}')
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "url": audio_url},
|
||||
{"type": "text", "text": "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the origina transcript and the translation."},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
# Downlowd and open audio file
|
||||
audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read()))
|
||||
|
||||
# Process with the model
|
||||
inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device)
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
sample_rate=sample_rate,
|
||||
).to(device, torch.float16)
|
||||
|
||||
generate_ids = model.generate(
|
||||
**inputs,
|
||||
|
||||
@ -14,51 +14,137 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Qwen2
|
||||
|
||||
Qwen2 is the new model series of large language models from the Qwen team. Previously, we released the Qwen series, including Qwen2-0.5B, Qwen2-1.5B, Qwen2-7B, Qwen2-57B-A14B, Qwen2-72B, Qwen2-Audio, etc.
|
||||
[Qwen2](https://huggingface.co/papers/2407.10671) is a family of large language models (pretrained, instruction-tuned and mixture-of-experts) available in sizes from 0.5B to 72B parameters. The models are built on the Transformer architecture featuring enhancements like group query attention (GQA), rotary positional embeddings (RoPE), a mix of sliding window and full attention, and dual chunk attention with YARN for training stability. Qwen2 models support multiple languages and context lengths up to 131,072 tokens.
|
||||
|
||||
### Model Details
|
||||
You can find all the official Qwen2 checkpoints under the [Qwen2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f) collection.
|
||||
|
||||
Qwen2 is a language model series including decoder language models of different model sizes. For each size, we release the base language model and the aligned chat model. It is based on the Transformer architecture with SwiGLU activation, attention QKV bias, group query attention, mixture of sliding window attention and full attention, etc. Additionally, we have an improved tokenizer adaptive to multiple natural languages and codes.
|
||||
> [!TIP]
|
||||
> Click on the Qwen2 models in the right sidebar for more examples of how to apply Qwen2 to different language tasks.
|
||||
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line using the instruction-tuned models.
|
||||
|
||||
## Usage tips
|
||||
|
||||
`Qwen2-7B` and `Qwen2-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
|
||||
|
||||
In the following, we demonstrate how to use `Qwen2-7B-Instruct` for the inference. Note that we have used the ChatML format for dialog, in this demo we show how to leverage `apply_chat_template` for this purpose.
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
>>> device = "cuda" # the device to load the model onto
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct", device_map="auto")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
|
||||
pipe = pipeline(
|
||||
task="text-generation",
|
||||
model="Qwen/Qwen2-1.5B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map=0
|
||||
)
|
||||
|
||||
>>> prompt = "Give me a short introduction to large language model."
|
||||
|
||||
>>> messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
>>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
>>> model_inputs = tokenizer([text], return_tensors="pt").to(device)
|
||||
|
||||
>>> generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512, do_sample=True)
|
||||
|
||||
>>> generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
|
||||
|
||||
>>> response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Tell me about the Qwen2 model family."},
|
||||
]
|
||||
outputs = pipe(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
|
||||
print(outputs[0]["generated_text"][-1]['content'])
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen2-1.5B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
|
||||
|
||||
prompt = "Give me a short introduction to large language models."
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to("cuda")
|
||||
|
||||
generated_ids = model.generate(
|
||||
model_inputs.input_ids,
|
||||
cache_implementation="static",
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_k=50,
|
||||
top_p=0.95
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
print(response)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
transformers-cli chat --model_name_or_path Qwen/Qwen2-7B-Instruct --torch_dtype auto --attn_implementation flash_attention_2 --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to 4-bits.
|
||||
|
||||
```python
|
||||
# pip install -U flash-attn --no-build-isolation
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"Qwen/Qwen2-7B",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config,
|
||||
attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
inputs = tokenizer("The Qwen2 model family is", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Ensure your Transformers library version is up-to-date. Qwen2 requires Transformers>=4.37.0 for full support.
|
||||
|
||||
## Qwen2Config
|
||||
|
||||
[[autodoc]] Qwen2Config
|
||||
|
||||
@ -14,45 +14,74 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen2.5-VL
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white"> </div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Qwen2.5-VL
|
||||
|
||||
The [Qwen2.5-VL](https://qwenlm.github.io/blog/qwen2_5-vl/) model is an update to [Qwen2-VL](https://arxiv.org/abs/2409.12191) from Qwen team, Alibaba Group.
|
||||
[Qwen2.5-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model, available in 3B, 7B, and 72B parameters, pretrained on 4.1T tokens. The model introduces window attention in the ViT encoder to accelerate training and inference, dynamic FPS sampling on the spatial and temporal dimensions for better video understanding across different sampling rates, and an upgraded MRoPE (multi-resolutional rotary positional encoding) mechanism to better capture and learn temporal dynamics.
|
||||
|
||||
The abstract from this update is the following:
|
||||
|
||||
*Qwen2.5-VL marks a major step forward from Qwen2-VL, built upon the latest Qwen2.5 LLM. We've accelerated training and testing through the strategic implementation of window attention within the ViT. The ViT architecture itself has been refined with SwiGLU and RMSNorm, aligning it more closely with the LLM's structure. A key innovation is the expansion of native dynamic resolution to encompass the temporal dimension, in addition to spatial aspects. Furthermore, we've upgraded MRoPE, incorporating absolute time alignment on the time axis to allow the model to effectively capture temporal dynamics, regardless of frame rate, leading to superior video understanding.*
|
||||
You can find all the original Qwen2.5-VL checkpoints under the [Qwen2.5-VL](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5) collection.
|
||||
|
||||
## Usage example
|
||||
> [!TIP]
|
||||
> Click on the Qwen2.5-VL models in the right sidebar for more examples of how to apply Qwen2.5-VL to different vision and language tasks.
|
||||
|
||||
### Single Media inference
|
||||
The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
The model can accept both images and videos as input. Here's an example code for inference.
|
||||
|
||||
```python
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from transformers import pipeline
|
||||
pipe = pipeline(
|
||||
task="image-text-to-text",
|
||||
model="Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
device=0,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
|
||||
},
|
||||
{ "type": "text", "text": "Describe this image."},
|
||||
]
|
||||
}
|
||||
]
|
||||
pipe(text=messages,max_new_tokens=20, return_full_text=False)
|
||||
|
||||
# Load the model in half-precision on the available device(s)
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", device_map="auto")
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
||||
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
|
||||
|
||||
conversation = [
|
||||
messages = [
|
||||
{
|
||||
"role":"user",
|
||||
"content":[
|
||||
{
|
||||
"type":"image",
|
||||
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
|
||||
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
},
|
||||
{
|
||||
"type":"text",
|
||||
@ -60,212 +89,145 @@ conversation = [
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
).to("cuda")
|
||||
|
||||
|
||||
# Inference: Generation of the output
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
|
||||
# Video
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": "/path/to/video.mp4"},
|
||||
{"type": "text", "text": "What happened in the video?"},
|
||||
],
|
||||
}
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
video_fps=1,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
# Inference: Generation of the output
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
### Batch Mixed Media Inference
|
||||
|
||||
The model can batch inputs composed of mixed samples of various types such as images, videos, and text. Here is an example.
|
||||
|
||||
```python
|
||||
# Conversation for the first image
|
||||
conversation1 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image1.jpg"},
|
||||
{"type": "text", "text": "Describe this image."}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Conversation with two images
|
||||
conversation2 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image2.jpg"},
|
||||
{"type": "image", "path": "/path/to/image3.jpg"},
|
||||
{"type": "text", "text": "What is written in the pictures?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Conversation with pure text
|
||||
conversation3 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "who are you?"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# Conversation with mixed midia
|
||||
conversation4 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image3.jpg"},
|
||||
{"type": "image", "path": "/path/to/image4.jpg"},
|
||||
{"type": "video", "path": "/path/to/video.jpg"},
|
||||
{"type": "text", "text": "What are the common elements in these medias?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
conversations = [conversation1, conversation2, conversation3, conversation4]
|
||||
# Preparation for batch inference
|
||||
ipnuts = processor.apply_chat_template(
|
||||
conversations,
|
||||
video_fps=1,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
|
||||
# Batch Inference
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
### Usage Tips
|
||||
|
||||
#### Image Resolution trade-off
|
||||
|
||||
The model supports a wide range of resolution inputs. By default, it uses the native resolution for input, but higher resolutions can enhance performance at the cost of more computation. Users can set the minimum and maximum number of pixels to achieve an optimal configuration for their needs.
|
||||
|
||||
```python
|
||||
min_pixels = 224*224
|
||||
max_pixels = 2048*2048
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
|
||||
In case of limited GPU RAM, one can reduce the resolution as follows:
|
||||
|
||||
```python
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 1024*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
This ensures each image gets encoded using a number between 256-1024 tokens. The 28 comes from the fact that the model uses a patch size of 14 and a temporal patch size of 2 (14 x 2 = 28).
|
||||
|
||||
#### Multiple Image Inputs
|
||||
|
||||
By default, images and video content are directly included in the conversation. When handling multiple images, it's helpful to add labels to the images and videos for better reference. Users can control this behavior with the following settings:
|
||||
|
||||
```python
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "Hello, how are you?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well, thank you for asking. How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you describe these images and video?"},
|
||||
{"type": "image"},
|
||||
{"type": "image"},
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "These are from my vacation."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'd be happy to describe the images and video for you. Could you please provide more context about your vacation?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "It was a trip to the mountains. Can you see the details in the images and video?"
|
||||
}
|
||||
]
|
||||
|
||||
# default:
|
||||
prompt_without_id = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|><|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
|
||||
|
||||
# add ids
|
||||
prompt_with_id = processor.apply_chat_template(conversation, add_generation_prompt=True, add_vision_id=True)
|
||||
# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?Picture 2: <|vision_start|><|image_pad|><|vision_end|>Picture 3: <|vision_start|><|image_pad|><|vision_end|>Video 1: <|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
|
||||
```
|
||||
|
||||
#### Flash-Attention 2 to speed up generation
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Also, you should have hardware that is compatible with FlashAttention 2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`.
|
||||
|
||||
To load and run a model using FlashAttention-2, add `attn_implementation="flash_attention_2"` when loading the model:
|
||||
|
||||
```python
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import TorchAoConfig, Gemma3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
```
|
||||
### Notes
|
||||
|
||||
- Use Qwen2.5-VL for video inputs by setting `"type": "video"` as shown below.
|
||||
```python
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": "/path/to/video.mp4"},
|
||||
{"type": "text", "text": "What happened in the video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
video_fps=1,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
# Inference: Generation of the output
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
```
|
||||
- Use Qwen2.5-VL for a mixed batch of inputs (images, videos, text). Add labels when handling multiple images or videos for better reference
|
||||
as show below.
|
||||
```python
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
||||
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "Hello, how are you?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well, thank you for asking. How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you describe these images and video?"},
|
||||
{"type": "image"},
|
||||
{"type": "image"},
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "These are from my vacation."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'd be happy to describe the images and video for you. Could you please provide more context about your vacation?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "It was a trip to the mountains. Can you see the details in the images and video?"
|
||||
}
|
||||
]
|
||||
|
||||
# default:
|
||||
prompt_without_id = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|><|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
|
||||
|
||||
# add ids
|
||||
prompt_with_id = processor.apply_chat_template(conversation, add_generation_prompt=True, add_vision_id=True)
|
||||
# Excepted output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?Picture 2: <|vision_start|><|image_pad|><|vision_end|>Picture 3: <|vision_start|><|image_pad|><|vision_end|>Video 1: <|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
```
|
||||
|
||||
- Use the `min_pixels` and `max_pixels` parameters in [`AutoProcessor`] to set the resolution.
|
||||
|
||||
```python
|
||||
min_pixels = 224*224
|
||||
max_pixels = 2048*2048
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
|
||||
Higher resolution can require more compute whereas reducing the resolution can save memory as follows:
|
||||
|
||||
```python
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 1024*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
## Qwen2_5_VLConfig
|
||||
|
||||
[[autodoc]] Qwen2_5_VLConfig
|
||||
|
||||
59
docs/source/en/model_doc/qwen3.md
Normal file
59
docs/source/en/model_doc/qwen3.md
Normal file
@ -0,0 +1,59 @@
|
||||
<!--Copyright 2024 The Qwen Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen3
|
||||
|
||||
## Overview
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
### Model Details
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
|
||||
## Usage tips
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Qwen3Config
|
||||
|
||||
[[autodoc]] Qwen3Config
|
||||
|
||||
## Qwen3Model
|
||||
|
||||
[[autodoc]] Qwen3Model
|
||||
- forward
|
||||
|
||||
## Qwen3ForCausalLM
|
||||
|
||||
[[autodoc]] Qwen3ForCausalLM
|
||||
- forward
|
||||
|
||||
## Qwen3ForSequenceClassification
|
||||
|
||||
[[autodoc]] Qwen3ForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Qwen3ForTokenClassification
|
||||
|
||||
[[autodoc]] Qwen3ForTokenClassification
|
||||
- forward
|
||||
|
||||
## Qwen3ForQuestionAnswering
|
||||
|
||||
[[autodoc]] Qwen3ForQuestionAnswering
|
||||
- forward
|
||||
58
docs/source/en/model_doc/qwen3_moe.md
Normal file
58
docs/source/en/model_doc/qwen3_moe.md
Normal file
@ -0,0 +1,58 @@
|
||||
<!--Copyright 2024 The Qwen Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen3MoE
|
||||
|
||||
## Overview
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
### Model Details
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Usage tips
|
||||
|
||||
To be released with the official model launch.
|
||||
|
||||
## Qwen3MoeConfig
|
||||
|
||||
[[autodoc]] Qwen3MoeConfig
|
||||
|
||||
## Qwen3MoeModel
|
||||
|
||||
[[autodoc]] Qwen3MoeModel
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForCausalLM
|
||||
|
||||
[[autodoc]] Qwen3MoeForCausalLM
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForSequenceClassification
|
||||
|
||||
[[autodoc]] Qwen3MoeForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForTokenClassification
|
||||
|
||||
[[autodoc]] Qwen3MoeForTokenClassification
|
||||
- forward
|
||||
|
||||
## Qwen3MoeForQuestionAnswering
|
||||
|
||||
[[autodoc]] Qwen3MoeForQuestionAnswering
|
||||
- forward
|
||||
@ -149,12 +149,24 @@ alt="drawing" width="900"/>
|
||||
[[autodoc]] SamImageProcessor
|
||||
|
||||
|
||||
## SamVisionModel
|
||||
|
||||
[[autodoc]] SamVisionModel
|
||||
- forward
|
||||
|
||||
|
||||
## SamModel
|
||||
|
||||
[[autodoc]] SamModel
|
||||
- forward
|
||||
|
||||
|
||||
## TFSamVisionModel
|
||||
|
||||
[[autodoc]] TFSamVisionModel
|
||||
- call
|
||||
|
||||
|
||||
## TFSamModel
|
||||
|
||||
[[autodoc]] TFSamModel
|
||||
|
||||
@ -19,7 +19,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
## Overview
|
||||
|
||||
The ShieldGemma 2 model was proposed in a forthcoming technical report by Google. ShieldGemma 2 is built on [Gemma 3](https://ai.google.dev/gemma/docs/core/model_card_3), is a 4 billion (4B) parameter model that checks the safety of both synthetic and natural images against key categories to help you build robust datasets and models. With this addition to the Gemma family of models, researchers and developers can now easily minimize the risk of harmful content in their models across key areas of harm as defined below:
|
||||
The ShieldGemma 2 model was proposed in a [technical report](https://arxiv.org/abs/2504.01081) by Google. ShieldGemma 2, built on [Gemma 3](https://ai.google.dev/gemma/docs/core/model_card_3), is a 4 billion (4B) parameter model that checks the safety of both synthetic and natural images against key categories to help you build robust datasets and models. With this addition to the Gemma family of models, researchers and developers can now easily minimize the risk of harmful content in their models across key areas of harm as defined below:
|
||||
|
||||
- No Sexually Explicit content: The image shall not contain content that depicts explicit or graphic sexual acts (e.g., pornography, erotic nudity, depictions of rape or sexual assault).
|
||||
- No Dangerous Content: The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide).
|
||||
|
||||
@ -14,349 +14,104 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# T5
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
|
||||
">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# T5
|
||||
|
||||
The T5 model was presented in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/pdf/1910.10683.pdf) by [Colin Raffel](https://huggingface.co/craffel), Noam Shazeer, [Adam Roberts](https://huggingface.co/adarob), Katherine Lee, Sharan Narang,
|
||||
Michael Matena, Yanqi Zhou, Wei Li, [Peter J. Liu](https://huggingface.co/peterjliu).
|
||||
[T5](https://huggingface.co/papers/1910.10683) is a encoder-decoder transformer available in a range of sizes from 60M to 11B parameters. It is designed to handle a wide range of NLP tasks by treating them all as text-to-text problems. This eliminates the need for task-specific architectures because T5 converts every NLP task into a text generation task.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
To formulate every task as text generation, each task is prepended with a task-specific prefix (e.g., translate English to German: ..., summarize: ...). This enables T5 to handle tasks like translation, summarization, question answering, and more.
|
||||
|
||||
*Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream
|
||||
task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning
|
||||
has given rise to a diversity of approaches, methodology, and practice. In this paper, we explore the landscape of
|
||||
transfer learning techniques for NLP by introducing a unified framework that converts every language problem into a
|
||||
text-to-text format. Our systematic study compares pretraining objectives, architectures, unlabeled datasets, transfer
|
||||
approaches, and other factors on dozens of language understanding tasks. By combining the insights from our exploration
|
||||
with scale and our new "Colossal Clean Crawled Corpus", we achieve state-of-the-art results on many benchmarks covering
|
||||
summarization, question answering, text classification, and more. To facilitate future work on transfer learning for
|
||||
NLP, we release our dataset, pre-trained models, and code.*
|
||||
You can find all official T5 checkpoints under the [T5](https://huggingface.co/collections/google/t5-release-65005e7c520f8d7b4d037918) collection.
|
||||
|
||||
All checkpoints can be found on the [hub](https://huggingface.co/models?search=t5).
|
||||
> [!TIP]
|
||||
> Click on the T5 models in the right sidebar for more examples of how to apply T5 to different language tasks.
|
||||
|
||||
This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The original code can be found [here](https://github.com/google-research/text-to-text-transfer-transformer).
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and how to translate with T5 from the command line.
|
||||
|
||||
## Usage tips
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
|
||||
each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
|
||||
different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
|
||||
for summarization: *summarize: ...*.
|
||||
- The pretraining includes both supervised and self-supervised training. Supervised training is conducted on downstream tasks provided by the GLUE and SuperGLUE benchmarks (converting them into text-to-text tasks as explained above).
|
||||
- Self-supervised training uses corrupted tokens, by randomly removing 15% of the tokens and replacing them with individual sentinel tokens (if several consecutive tokens are marked for removal, the whole group is replaced with a single sentinel token). The input of the encoder is the corrupted sentence, the input of the decoder is the original sentence and the target is then the dropped out tokens delimited by their sentinel tokens.
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
|
||||
|
||||
- See the [training](#training), [inference](#inference) and [resources](#resources) sections below for all details regarding usage.
|
||||
|
||||
T5 comes in different sizes:
|
||||
|
||||
- [google-t5/t5-small](https://huggingface.co/google-t5/t5-small)
|
||||
|
||||
- [google-t5/t5-base](https://huggingface.co/google-t5/t5-base)
|
||||
|
||||
- [google-t5/t5-large](https://huggingface.co/google-t5/t5-large)
|
||||
|
||||
- [google-t5/t5-3b](https://huggingface.co/google-t5/t5-3b)
|
||||
|
||||
- [google-t5/t5-11b](https://huggingface.co/google-t5/t5-11b).
|
||||
|
||||
Based on the original T5 model, Google has released some follow-up works:
|
||||
|
||||
- **T5v1.1**: T5v1.1 is an improved version of T5 with some architectural tweaks, and is pre-trained on C4 only without
|
||||
mixing in the supervised tasks. Refer to the documentation of T5v1.1 which can be found [here](t5v1.1).
|
||||
|
||||
- **mT5**: mT5 is a multilingual T5 model. It is pre-trained on the mC4 corpus, which includes 101 languages. Refer to
|
||||
the documentation of mT5 which can be found [here](mt5).
|
||||
|
||||
- **byT5**: byT5 is a T5 model pre-trained on byte sequences rather than SentencePiece subword token sequences. Refer
|
||||
to the documentation of byT5 which can be found [here](byt5).
|
||||
|
||||
- **UL2**: UL2 is a T5 like model pretrained on various denoising objectives
|
||||
|
||||
- **Flan-T5**: Flan is a pretraining methods that is based on prompting. The Flan-T5 are T5 models trained on the Flan collection of
|
||||
datasets which include: `taskmaster2`, `djaym7/wiki_dialog`, `deepmind/code_contests`, `lambada`, `gsm8k`, `aqua_rat`, `esnli`, `quasc` and `qed`.
|
||||
|
||||
- **FLan-UL2** : the UL2 model finetuned using the "Flan" prompt tuning and dataset collection.
|
||||
|
||||
- **UMT5**: UmT5 is a multilingual T5 model trained on an improved and refreshed mC4 multilingual corpus, 29 trillion characters across 107 language, using a new sampling method, UniMax. Refer to
|
||||
the documentation of mT5 which can be found [here](umt5).
|
||||
|
||||
## Training
|
||||
|
||||
T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher
|
||||
forcing. This means that for training, we always need an input sequence and a corresponding target sequence. The input
|
||||
sequence is fed to the model using `input_ids`. The target sequence is shifted to the right, i.e., prepended by a
|
||||
start-sequence token and fed to the decoder using the `decoder_input_ids`. In teacher-forcing style, the target
|
||||
sequence is then appended by the EOS token and corresponds to the `labels`. The PAD token is hereby used as the
|
||||
start-sequence token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
|
||||
|
||||
One can use [`T5ForConditionalGeneration`] (or the Tensorflow/Flax variant), which includes the
|
||||
language modeling head on top of the decoder.
|
||||
|
||||
- Unsupervised denoising training
|
||||
|
||||
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
|
||||
the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
|
||||
sentinel token represents a unique mask token for this sentence and should start with `<extra_id_0>`,
|
||||
`<extra_id_1>`, ... up to `<extra_id_99>`. As a default, 100 sentinel tokens are available in
|
||||
[`T5Tokenizer`].
|
||||
|
||||
For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
|
||||
processed as follows:
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
|
||||
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
|
||||
|
||||
>>> # the forward function automatically creates the correct decoder_input_ids
|
||||
>>> loss = model(input_ids=input_ids, labels=labels).loss
|
||||
>>> loss.item()
|
||||
3.7837
|
||||
pipeline = pipeline(
|
||||
task="text2text-generation",
|
||||
model="google-t5/t5-base",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("translate English to French: The weather is nice today.")
|
||||
```
|
||||
|
||||
If you're interested in pre-training T5 on a new corpus, check out the [run_t5_mlm_flax.py](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) script in the Examples
|
||||
directory.
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
- Supervised training
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
|
||||
Suppose that we want to fine-tune the model for translation for example, and we have a training example: the input
|
||||
sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
|
||||
the model as follows:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google-t5/t5-base"
|
||||
)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"google-t5/t5-base",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
input_ids = tokenizer("translate English to French: The weather is nice today.", return_tensors="pt").to("cuda")
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
|
||||
>>> labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
|
||||
|
||||
>>> # the forward function automatically creates the correct decoder_input_ids
|
||||
>>> loss = model(input_ids=input_ids, labels=labels).loss
|
||||
>>> loss.item()
|
||||
0.2542
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
As you can see, only 2 inputs are required for the model in order to compute a loss: `input_ids` (which are the
|
||||
`input_ids` of the encoded input sequence) and `labels` (which are the `input_ids` of the encoded
|
||||
target sequence). The model will automatically create the `decoder_input_ids` based on the `labels`, by
|
||||
shifting them one position to the right and prepending the `config.decoder_start_token_id`, which for T5 is
|
||||
equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
|
||||
English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
|
||||
during T5's pre-training.
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
However, the example above only shows a single training example. In practice, one trains deep learning models in
|
||||
batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
|
||||
typically defines a `max_source_length` and `max_target_length`, which determine the maximum length of the
|
||||
input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
|
||||
the task.
|
||||
|
||||
In addition, we must make sure that padding token id's of the `labels` are not taken into account by the loss
|
||||
function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the `ignore_index`
|
||||
of the `CrossEntropyLoss`. In Flax, one can use the `decoder_attention_mask` to ignore padded tokens from
|
||||
the loss (see the [Flax summarization script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization) for details). We also pass
|
||||
`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
|
||||
ignored. The code example below illustrates all of this.
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> # the following 2 hyperparameters are task-specific
|
||||
>>> max_source_length = 512
|
||||
>>> max_target_length = 128
|
||||
|
||||
>>> # Suppose we have the following 2 training examples:
|
||||
>>> input_sequence_1 = "Welcome to NYC"
|
||||
>>> output_sequence_1 = "Bienvenue à NYC"
|
||||
|
||||
>>> input_sequence_2 = "HuggingFace is a company"
|
||||
>>> output_sequence_2 = "HuggingFace est une entreprise"
|
||||
|
||||
>>> # encode the inputs
|
||||
>>> task_prefix = "translate English to French: "
|
||||
>>> input_sequences = [input_sequence_1, input_sequence_2]
|
||||
|
||||
>>> encoding = tokenizer(
|
||||
... [task_prefix + sequence for sequence in input_sequences],
|
||||
... padding="longest",
|
||||
... max_length=max_source_length,
|
||||
... truncation=True,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
|
||||
>>> input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
|
||||
|
||||
>>> # encode the targets
|
||||
>>> target_encoding = tokenizer(
|
||||
... [output_sequence_1, output_sequence_2],
|
||||
... padding="longest",
|
||||
... max_length=max_target_length,
|
||||
... truncation=True,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
>>> labels = target_encoding.input_ids
|
||||
|
||||
>>> # replace padding token id's of the labels by -100 so it's ignored by the loss
|
||||
>>> labels[labels == tokenizer.pad_token_id] = -100
|
||||
|
||||
>>> # forward pass
|
||||
>>> loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
|
||||
>>> loss.item()
|
||||
0.188
|
||||
```bash
|
||||
echo -e "translate English to French: The weather is nice today." | transformers-cli run --task text2text-generation --model google-t5/t5-base --device 0
|
||||
```
|
||||
|
||||
Additional training tips:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
- T5 models need a slightly higher learning rate than the default one set in the `Trainer` when using the AdamW
|
||||
optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
|
||||
answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
According to [this forum post](https://discuss.huggingface.co/t/t5-finetuning-tips/684), task prefixes matter when
|
||||
(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
|
||||
pre-training mixture (see Appendix D of the [paper](https://arxiv.org/pdf/1910.10683.pdf) for the task prefixes
|
||||
used).
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to int4.
|
||||
|
||||
If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
|
||||
*pad_to_multiple_of* to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
|
||||
batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
|
||||
encountered during training thus significantly slowing down the training. only padding up to the longest example in a
|
||||
batch) leads to very slow training on TPU.
|
||||
```py
|
||||
# pip install torchao
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
## Inference
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"google/t5-v1_1-xl",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
At inference time, it is recommended to use [`~generation.GenerationMixin.generate`]. This
|
||||
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
|
||||
and auto-regressively generates the decoder output. Check out [this blog post](https://huggingface.co/blog/how-to-generate) to know all the details about generating text with Transformers.
|
||||
There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encoder-decoder) which explains how
|
||||
generation works in general in encoder-decoder models.
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
|
||||
input_ids = tokenizer("translate English to French: The weather is nice today.", return_tensors="pt").to("cuda")
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
|
||||
>>> outputs = model.generate(input_ids)
|
||||
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
Das Haus ist wunderbar.
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
|
||||
[`~generation.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
|
||||
## Notes
|
||||
|
||||
The example above only shows a single example. You can also do batched inference, like so:
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> task_prefix = "translate English to German: "
|
||||
>>> # use different length sentences to test batching
|
||||
>>> sentences = ["The house is wonderful.", "I like to work in NYC."]
|
||||
|
||||
>>> inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)
|
||||
|
||||
>>> output_sequences = model.generate(
|
||||
... input_ids=inputs["input_ids"],
|
||||
... attention_mask=inputs["attention_mask"],
|
||||
... do_sample=False, # disable sampling to test if batching affects output
|
||||
... )
|
||||
|
||||
>>> print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
|
||||
['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
|
||||
```
|
||||
|
||||
Because T5 has been trained with the span-mask denoising objective,
|
||||
it can be used to predict the sentinel (masked-out) tokens during inference.
|
||||
The predicted tokens will then be placed between the sentinel tokens.
|
||||
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
|
||||
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
|
||||
|
||||
>>> sequence_ids = model.generate(input_ids)
|
||||
>>> sequences = tokenizer.batch_decode(sequence_ids)
|
||||
>>> sequences
|
||||
['<pad> <extra_id_0> park offers <extra_id_1> the <extra_id_2> park.</s>']
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
If you'd like a faster training and inference performance, install [NVIDIA APEX](https://github.com/NVIDIA/apex#quick-start) for NVIDIA GPUs, or [ROCm APEX](https://github.com/ROCmSoftwarePlatform/apex) for AMD GPUs and then the model will automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`. The former uses an optimized fused kernel which is several times faster than the latter.
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with T5. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
<PipelineTag pipeline="text-classification"/>
|
||||
|
||||
- A notebook for how to [finetune T5 for classification and multiple choice](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb).
|
||||
- A notebook for how to [finetune T5 for sentiment span extraction](https://colab.research.google.com/github/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb). 🌎
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
|
||||
- A notebook for how to [finetune T5 for named entity recognition](https://colab.research.google.com/drive/1obr78FY_cBmWY5ODViCmzdY6O1KB65Vc?usp=sharing). 🌎
|
||||
|
||||
<PipelineTag pipeline="text-generation"/>
|
||||
|
||||
- A notebook for [Finetuning CodeT5 for generating docstrings from Ruby code](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/T5/Fine_tune_CodeT5_for_generating_docstrings_from_Ruby_code.ipynb).
|
||||
|
||||
<PipelineTag pipeline="summarization"/>
|
||||
|
||||
- A notebook to [Finetune T5-base-dutch to perform Dutch abstractive summarization on a TPU](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/T5/Fine_tuning_Dutch_T5_base_on_CNN_Daily_Mail_for_summarization_(on_TPU_using_HuggingFace_Accelerate).ipynb).
|
||||
- A notebook for how to [finetune T5 for summarization in PyTorch and track experiments with WandB](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb#scrollTo=OKRpFvYhBauC). 🌎
|
||||
- A blog post on [Distributed Training: Train BART/T5 for Summarization using 🤗 Transformers and Amazon SageMaker](https://huggingface.co/blog/sagemaker-distributed-training-seq2seq).
|
||||
- [`T5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization.ipynb).
|
||||
- [`TFT5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/summarization) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization-tf.ipynb).
|
||||
- [`FlaxT5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization).
|
||||
- [Summarization](https://huggingface.co/course/chapter7/5?fw=pt#summarization) chapter of the 🤗 Hugging Face course.
|
||||
- [Summarization task guide](../tasks/summarization)
|
||||
|
||||
<PipelineTag pipeline="fill-mask"/>
|
||||
|
||||
- [`FlaxT5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#t5-like-span-masked-language-modeling) for training T5 with a span-masked language model objective. The script also shows how to train a T5 tokenizer. [`FlaxT5ForConditionalGeneration`] is also supported by this [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb).
|
||||
|
||||
<PipelineTag pipeline="translation"/>
|
||||
|
||||
- [`T5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/translation) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation.ipynb).
|
||||
- [`TFT5ForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/translation) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation-tf.ipynb).
|
||||
- [Translation task guide](../tasks/translation)
|
||||
|
||||
<PipelineTag pipeline="question-answering"/>
|
||||
|
||||
- A notebook on how to [finetune T5 for question answering with TensorFlow 2](https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb). 🌎
|
||||
- A notebook on how to [finetune T5 for question answering on a TPU](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb#scrollTo=QLGiFCDqvuil).
|
||||
|
||||
🚀 **Deploy**
|
||||
- A blog post on how to deploy [T5 11B for inference for less than $500](https://www.philschmid.de/deploy-t5-11b).
|
||||
- You can pad the encoder inputs on the left or right because T5 uses relative scalar embeddings.
|
||||
- T5 models need a slightly higher learning rate than the default used in [`Trainer`]. Typically, values of `1e-4` and `3e-4` work well for most tasks.
|
||||
|
||||
## T5Config
|
||||
|
||||
|
||||
@ -44,11 +44,6 @@ import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# initialize distributed environment
|
||||
rank = int(os.environ["RANK"])
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.distributed.init_process_group("nccl", device_id=device)
|
||||
|
||||
# enable tensor parallelism
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@ -59,7 +54,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
# prepare input tokens
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
prompt = "Can I help"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
||||
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
|
||||
|
||||
# distributed run
|
||||
outputs = model(inputs)
|
||||
@ -71,6 +66,13 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
|
||||
torchrun --nproc-per-node 4 demo.py
|
||||
```
|
||||
|
||||
For CPU, please binding different socket on each rank. For example, if you are using Intel 4th Gen Xeon:
|
||||
```bash
|
||||
export OMP_NUM_THREADS=56
|
||||
numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait
|
||||
```
|
||||
The CPU benchmark data will be released soon.
|
||||
|
||||
You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.
|
||||
|
||||
For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.
|
||||
|
||||
@ -40,7 +40,7 @@ Use the Space below to help you pick a quantization method depending on your har
|
||||
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
|
||||
| [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
|
||||
| [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
|
||||
| [Quark](./quark.md) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |
|
||||
| [Quark](./quark) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |
|
||||
|
||||
## Resources
|
||||
|
||||
|
||||
135
docs/source/en/quantization/selecting.md
Normal file
135
docs/source/en/quantization/selecting.md
Normal file
@ -0,0 +1,135 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Selecting a quantization method
|
||||
|
||||
There are many quantization methods available in Transformers for inference and fine-tuning. This guide helps you choose the most common and production-ready quantization techniques depending on your use case, and presents the advantages and disadvantages of each technique.
|
||||
|
||||
For a comprehensive overview of all supported methods and their features, refer back to the table in the [Overview](./overview).
|
||||
|
||||
## Inference
|
||||
|
||||
Consider the quantization methods below for inference.
|
||||
|
||||
| quantization method | use case |
|
||||
|---|---|
|
||||
| bitsandbytes | ease of use and QLoRA fine-tuning on NVIDIA GPUs |
|
||||
| compressed-tensors | loading specific quantized formats (FP8, Sparse) |
|
||||
| GPTQModel or AWQ | good 4-bit accuracy with upfront calibration |
|
||||
| HQQ | fast on the fly quantization without calibration |
|
||||
| torchao | flexibility and fast inference with torch.compile |
|
||||
|
||||
### No Calibration Required (On-the-fly Quantization)
|
||||
|
||||
These methods are generally easier to use as they don't need a separate calibration dataset or step.
|
||||
|
||||
#### bitsandbytes
|
||||
|
||||
| Pros | Cons |
|
||||
|--------------------------------------------------------------|---------------------------------------------------------|
|
||||
| Very simple, no calibration dataset required for inference. | Primarily optimized for NVIDIA GPUs (CUDA). |
|
||||
| Good community support and widely adopted. | Inference speedup isn't guaranteed. |
|
||||
|
||||
See the [bitsandbytes documentation](./bitsandbytes) for more details.
|
||||
|
||||
#### HQQ (Half-Quadratic Quantization)
|
||||
|
||||
| Pros | Cons |
|
||||
|----------------------------------------------------------------------|----------------------------------------------------------------------------|
|
||||
| Fast quantization process, no calibration data needed. | Accuracy can degrade significantly at bit depths <4-bit. |
|
||||
| Multiple backends for fast inference. | Inference speed may not match others unless using `torch.compile` or backends. |
|
||||
| Compatible with `torch.compile`. | |
|
||||
| Supports wide range of bit depths (8, 4, 3, 2, 1-bit). | |
|
||||
|
||||
See the [HQQ documentation](./hqq) for more details.
|
||||
|
||||
#### torchao
|
||||
|
||||
| Pros | Cons |
|
||||
|----------------------------------------------------------------------|----------------------------------------------------------------------|
|
||||
| Strong integration with `torch.compile` for potential speedups. | Newer library, ecosystem still evolving. |
|
||||
| Offers decent CPU quantization support. | Performance depends on `torch.compile` working well. |
|
||||
| Flexibility in quantization schemes (int8, int4, fp8). | 4-bit quantization (int4wo) may not match GPTQ/AWQ in accuracy. |
|
||||
|
||||
See the [torchao documentation](./torchao) for more details.
|
||||
|
||||
### Calibration-based Quantization
|
||||
|
||||
These methods require an upfront calibration step using a dataset to potentially achieve higher accuracy.
|
||||
|
||||
#### GPTQ/GPTQModel
|
||||
|
||||
Calibration for 8B model takes ~20 minutes on one A100 gpu.
|
||||
|
||||
| Pros | Cons |
|
||||
|----------------------------------------------------------------------|----------------------------------------------------------------------|
|
||||
| Often achieves high accuracy. | Requires a calibration dataset and a separate calibration step. |
|
||||
| Can lead to inference speedups. | Possible to overfit on calibration data. |
|
||||
| Many pre-quantized GPTQ models on [Hugging Face Hub](https://huggingface.co/models?other=gptq). | |
|
||||
|
||||
See the [GPTQ documentation](./gptq) for more details.
|
||||
|
||||
#### AWQ (Activation-aware Weight Quantization)
|
||||
|
||||
Calibration for 8B model takes ~10 minutes on one A100 gpu.
|
||||
|
||||
| Pros | Cons |
|
||||
|----------------------------------------------------------------------|-----------------------------------------------------|
|
||||
| Often achieves high accuracy at 4-bit. (Sometimes surpasses GPTQ on specific tasks.) | Requires calibration if quantizing yourself. |
|
||||
| Can lead to inference speedups. | |
|
||||
| Shorter calibration time than GPTQ. | |
|
||||
| Many pre-quantized AWQ models on [Hugging Face Hub](https://huggingface.co/models?other=awq). | |
|
||||
|
||||
See the [AWQ documentation](./awq) for more details.
|
||||
|
||||
### Loading Specific Formats
|
||||
|
||||
#### compressed-tensors
|
||||
|
||||
| Pros | Cons |
|
||||
|--------------------------------------------------------------|-------------------------------------------------------------|
|
||||
| Supports flexible formats including FP8 and sparsity. | Primarily for loading pre-quantized models. |
|
||||
| | Doesn't perform quantization within Transformers directly. |
|
||||
|
||||
See the [compressed-tensors documentation](./compressed_tensors) for more details.
|
||||
|
||||
## Fine-tuning
|
||||
|
||||
Consider the quantization method below during fine-tuning to save memory.
|
||||
|
||||
### bitsandbytes[[training]]
|
||||
|
||||
* **Description:** The standard method for QLoRA fine-tuning via PEFT.
|
||||
* **Pros:** Enables fine-tuning large models on consumer GPUs; widely supported and documented for PEFT.
|
||||
* **Cons:** Primarily for NVIDIA GPUs.
|
||||
|
||||
Other methods offer PEFT compatibility, though bitsandbytes is the most established and straightforward path for QLoRA.
|
||||
|
||||
See the [bitsandbytes documentation](./bitsandbytes#qlora) and [PEFT Docs](https://huggingface.co/docs/peft/developer_guides/quantization#aqlm-quantization) for more details.
|
||||
|
||||
## Research
|
||||
|
||||
Methods like [AQLM](./aqlm), [SpQR](./spqr), [VPTQ](./vptq), [HIGGS](./higgs), etc., push the boundaries of compression (< 2-bit) or explore novel techniques.
|
||||
|
||||
* Consider these if:
|
||||
* You need extreme compression (sub-4-bit).
|
||||
* You are conducting research or require state-of-the-art results from their respective papers.
|
||||
* You have significant compute resources available for potentially complex quantization procedures.
|
||||
We recommend consulting each methods documentation and associated papers carefully before choosing one for use in production.
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Always benchmark the performance (accuracy and speed) of the quantized model on your specific task and hardware to ensure it meets your requirements. Refer to the individual documentation pages linked above for detailed usage instructions.
|
||||
@ -341,29 +341,9 @@ use_cpu: false
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Tensor parallelism with PyTorch 2">
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
tp_config:
|
||||
tp_size: 4
|
||||
distributed_type: TP
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
</hfoptions>
|
||||
|
||||
|
||||
Run [accelerate_launch](https://hf.co/docs/accelerate/package_reference/cli#accelerate-launch) to start training with the configurations set in `config_file.yaml`. This file is saved to the Accelerate cache folder and automatically loaded when you run `accelerate_launch`.
|
||||
|
||||
The example below launches the [run_glue.py](../../../examples/pytorch/text-classification/run_glue) script with the FSDP configuration shown earlier. Parameters from the `config_file.yaml` file can also be directly set in the command line.
|
||||
|
||||
@ -363,29 +363,6 @@ use_cpu: false
|
||||
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="Tensor Parallelism with PyTorch 2">
|
||||
|
||||
```yml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
tp_config:
|
||||
tp_size: 4
|
||||
distributed_type: TP
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
El comando [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) es la forma recomendada de lanzar tu script de entrenamiento en un sistema distribuido con Accelerate y [`Trainer`] con los parámetros especificados en `config_file.yaml`. Este archivo se guarda en la carpeta de caché de Accelerate y se carga automáticamente cuando ejecutas `accelerate_launch`.
|
||||
|
||||
@ -720,6 +720,8 @@
|
||||
title: (번역중) Perceiver
|
||||
- local: in_translation
|
||||
title: (번역중) Pix2Struct
|
||||
- local: model_doc/qwen2_vl
|
||||
title: Qwen2VL
|
||||
- local: in_translation
|
||||
title: (번역중) Segment Anything
|
||||
- local: in_translation
|
||||
|
||||
@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
|
||||
|
||||
past_key_values = StaticCache(
|
||||
config=model.config,
|
||||
batch_size=1,
|
||||
max_batch_size=1,
|
||||
# 캐시를 재사용할 계획이 있는 경우, 모든 경우에 충분한 캐시 길이를 설정해야 합니다
|
||||
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
|
||||
device=model.device,
|
||||
@ -109,7 +109,7 @@ outputs = model.generate(**input_ids, past_key_values=past_key_values)
|
||||
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
['The theory of special relativity states 1. The speed of light is constant in all inertial reference frames. 2']
|
||||
|
||||
# 생성된 텍스트와 동일한 캐시 객체를 전달하여, 중단한 곳에서 생성을 계속합니다.
|
||||
# 생성된 텍스트와 동일한 캐시 객체를 전달하여, 중단한 곳에서 생성을 계속합니다.
|
||||
# 다중 턴 대화의 경우, 생성된 텍스트에 새로운 사용자 입력을 추가할 수 있습니다.
|
||||
new_input_ids = outputs
|
||||
outputs = model.generate(new_input_ids, past_key_values=past_key_values)
|
||||
|
||||
303
docs/source/ko/model_doc/qwen2_vl.md
Normal file
303
docs/source/ko/model_doc/qwen2_vl.md
Normal file
@ -0,0 +1,303 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Qwen2-VL[[Qwen2-VL]]
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
</div>
|
||||
|
||||
## Overview[[Overview]]
|
||||
|
||||
[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/) 모델은 알리바바 리서치의 Qwen팀에서 개발한 [Qwen-VL](https://arxiv.org/pdf/2308.12966) 모델의 주요 업데이트 버전입니다.
|
||||
|
||||
블로그의 요약은 다음과 같습니다:
|
||||
|
||||
*이 블로그는 지난 몇 년간 Qwen-VL에서 중대한 개선을 거쳐 발전된 Qwen2-VL 모델을 소개합니다. 중요 개선 사항은 향상된 이미지 이해, 고급 비디오 이해, 통합 시각 에이전트 기능, 확장된 다언어 지원을 포함하고 있습니다.모델 아키텍처는 Naive Dynamic Resolution 지원을 통해 임의의 이미지 해상도를 처리할 수 있도록 최적화되었으며, 멀티모달 회전 위치 임베딩(M-ROPE)을 활용하여 1D 텍스트와 다차원 시각 데이터를 효과적으로 처리합니다. 이 업데이트된 모델은 시각 관련 작업에서 GPT-4o와 Claude 3.5 Sonnet 같은 선도적인 AI 시스템과 경쟁력 있는 성능을 보여주며, 텍스트 능력에서는 오픈소스 모델 중 상위권에 랭크되어 있습니다. 이러한 발전은 Qwen2-VL을 강력한 멀티모달 처리 및 추론 능력이 필요한 다양한 응용 분야에서 활용할 수 있는 다재다능한 도구로 만들어줍니다.*
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/qwen2_vl_architecture.jpeg"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Qwen2-VL 구조. 출처: <a href="https://qwenlm.github.io/blog/qwen2-vl/">블로그 게시글</a> </small>
|
||||
|
||||
이 모델은 [simonJJJ](https://huggingface.co/simonJJJ)에 의해 기여되었습니다.
|
||||
|
||||
## 사용 예시[[Usage example]]
|
||||
|
||||
### 단일 미디어 추론[[Single Media inference]]
|
||||
|
||||
이 모델은 이미지와 비디오를 모두 인풋으로 받을 수 있습니다. 다음은 추론을 위한 예제 코드입니다.
|
||||
|
||||
```python
|
||||
|
||||
import torch
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
|
||||
# 사용 가능한 장치에서 모델을 반 정밀도(half-precision)로 로드
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role":"user",
|
||||
"content":[
|
||||
{
|
||||
"type":"image",
|
||||
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
|
||||
},
|
||||
{
|
||||
"type":"text",
|
||||
"text":"Describe this image."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
# 추론: 아웃풋 생성
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
|
||||
|
||||
|
||||
# 비디오
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "path": "/path/to/video.mp4"},
|
||||
{"type": "text", "text": "What happened in the video?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
video_fps=1,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
|
||||
# 추론: 아웃풋 생성
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
### 배치 혼합 미디어 추론[[Batch Mixed Media Inference]]
|
||||
|
||||
이 모델은 이미지, 비디오, 텍스트 등 다양한 유형의 데이터를 혼합하여 배치 입력으로 처리할 수 있습니다. 다음은 예제입니다.
|
||||
|
||||
```python
|
||||
|
||||
# 첫번째 이미지에 대한 대화
|
||||
conversation1 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image1.jpg"},
|
||||
{"type": "text", "text": "Describe this image."}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# 두 개의 이미지에 대한 대화
|
||||
conversation2 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image2.jpg"},
|
||||
{"type": "image", "path": "/path/to/image3.jpg"},
|
||||
{"type": "text", "text": "What is written in the pictures?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# 순수 텍스트로만 이루어진 대화
|
||||
conversation3 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "who are you?"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# 혼합된 미디어로 이루어진 대화
|
||||
conversation4 = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "path": "/path/to/image3.jpg"},
|
||||
{"type": "image", "path": "/path/to/image4.jpg"},
|
||||
{"type": "video", "path": "/path/to/video.jpg"},
|
||||
{"type": "text", "text": "What are the common elements in these medias?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
conversations = [conversation1, conversation2, conversation3, conversation4]
|
||||
# 배치 추론을 위한 준비
|
||||
ipnuts = processor.apply_chat_template(
|
||||
conversations,
|
||||
video_fps=1,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
|
||||
# 배치 추론
|
||||
output_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
|
||||
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
print(output_text)
|
||||
```
|
||||
|
||||
### 사용 팁[[Usage Tips]]
|
||||
|
||||
#### 이미지 해상도 트레이드오프[[Image Resolution trade-off]]
|
||||
|
||||
이 모델은 다양한 해상도의 입력을 지원합니다. 디폴트로 입력에 대해 네이티브(native) 해상도를 사용하지만, 더 높은 해상도를 적용하면 성능이 향상될 수 있습니다. 다만, 이는 더 많은 연산 비용을 초래합니다. 사용자는 최적의 설정을 위해 최소 및 최대 픽셀 수를 조정할 수 있습니다.
|
||||
|
||||
```python
|
||||
min_pixels = 224*224
|
||||
max_pixels = 2048*2048
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
|
||||
제한된 GPU RAM의 경우, 다음과 같이 해상도를 줄일 수 있습니다:
|
||||
|
||||
```python
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 1024*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
```
|
||||
이렇게 하면 각 이미지가 256~1024개의 토큰으로 인코딩됩니다. 여기서 28은 모델이 14 크기의 패치(patch)와 2의 시간 패치(temporal patch size)를 사용하기 때문에 나온 값입니다 (14 × 2 = 28).
|
||||
|
||||
|
||||
#### 다중 이미지 인풋[[Multiple Image Inputs]]
|
||||
|
||||
기본적으로 이미지와 비디오 콘텐츠는 대화에 직접 포함됩니다. 여러 개의 이미지를 처리할 때는 이미지 및 비디오에 라벨을 추가하면 참조하기가 더 쉬워집니다. 사용자는 다음 설정을 통해 이 동작을 제어할 수 있습니다:
|
||||
|
||||
```python
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "Hello, how are you?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well, thank you for asking. How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Can you describe these images and video?"},
|
||||
{"type": "image"},
|
||||
{"type": "image"},
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "These are from my vacation."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'd be happy to describe the images and video for you. Could you please provide more context about your vacation?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "It was a trip to the mountains. Can you see the details in the images and video?"
|
||||
}
|
||||
]
|
||||
|
||||
# 디폴트:
|
||||
prompt_without_id = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
||||
# 예상 아웃풋: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|><|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
|
||||
|
||||
# id 추가
|
||||
prompt_with_id = processor.apply_chat_template(conversation, add_generation_prompt=True, add_vision_id=True)
|
||||
# 예상 아웃풋: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?Picture 2: <|vision_start|><|image_pad|><|vision_end|>Picture 3: <|vision_start|><|image_pad|><|vision_end|>Video 1: <|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'
|
||||
|
||||
```
|
||||
|
||||
#### 빠른 생성을 위한 Flash-Attention 2[[Flash-Attention 2 to speed up generation]]
|
||||
|
||||
첫번째로, Flash Attention 2의 최신 버전을 설치합니다:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
또한, Flash-Attention 2를 지원하는 하드웨어가 필요합니다. 자세한 내용은 공식 문서인 [flash attention repository](https://github.com/Dao-AILab/flash-attention)에서 확인할 수 있습니다. FlashAttention-2는 모델이 `torch.float16` 또는 `torch.bfloat16` 형식으로 로드된 경우에만 사용할 수 있습니다.
|
||||
|
||||
Flash Attention-2를 사용하여 모델을 로드하고 실행하려면, 다음과 같이 모델을 로드할 때 `attn_implementation="flash_attention_2"` 옵션을 추가하면 됩니다:
|
||||
|
||||
```python
|
||||
from transformers import Qwen2VLForConditionalGeneration
|
||||
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-7B-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
```
|
||||
|
||||
## Qwen2VLConfig
|
||||
|
||||
[[autodoc]] Qwen2VLConfig
|
||||
|
||||
## Qwen2VLImageProcessor
|
||||
|
||||
[[autodoc]] Qwen2VLImageProcessor
|
||||
- preprocess
|
||||
|
||||
## Qwen2VLImageProcessorFast
|
||||
|
||||
[[autodoc]] Qwen2VLImageProcessorFast
|
||||
- preprocess
|
||||
|
||||
## Qwen2VLProcessor
|
||||
|
||||
[[autodoc]] Qwen2VLProcessor
|
||||
|
||||
## Qwen2VLModel
|
||||
|
||||
[[autodoc]] Qwen2VLModel
|
||||
- forward
|
||||
|
||||
## Qwen2VLForConditionalGeneration
|
||||
|
||||
[[autodoc]] Qwen2VLForConditionalGeneration
|
||||
- forward
|
||||
@ -549,29 +549,7 @@ use_cpu: false
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Tensor Parallelism with PyTorch 2">
|
||||
|
||||
```yml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
tp_config:
|
||||
tp_size: 4
|
||||
distributed_type: TP
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
[`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) 명령은 Accelerate와 [`Trainer`]를 사용하여 분산 시스템에서 훈련 스크립트를 실행하는 권장 방법이며, `config_file.yaml`에 지정된 매개변수를 사용합니다. 이 파일은 Accelerate 캐시 폴더에 저장되며 `accelerate_launch`를 실행할 때 자동으로 로드됩니다.
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -833,8 +832,7 @@ def main():
|
||||
# No need to shuffle here
|
||||
loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
|
||||
|
||||
for batch in loader:
|
||||
yield batch
|
||||
yield from loader
|
||||
|
||||
# Metric
|
||||
metric = evaluate.load("rouge", cache_dir=model_args.cache_dir)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -30,7 +29,7 @@ from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import flax
|
||||
import jax
|
||||
@ -294,7 +293,7 @@ class FlaxDataCollatorForBartDenoisingLM:
|
||||
" language modeling. "
|
||||
)
|
||||
|
||||
def __call__(self, examples: List[Dict[str, List[int]]]) -> BatchEncoding:
|
||||
def __call__(self, examples: list[dict[str, list[int]]]) -> BatchEncoding:
|
||||
# convert list to dict and tensorize input
|
||||
batch = BatchEncoding(
|
||||
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -33,7 +32,7 @@ from itertools import chain
|
||||
|
||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import flax
|
||||
import jax
|
||||
@ -302,7 +301,7 @@ class FlaxDataCollatorForLanguageModeling:
|
||||
"You should pass `mlm=False` to train on causal language modeling instead."
|
||||
)
|
||||
|
||||
def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
|
||||
def __call__(self, examples: list[dict[str, np.ndarray]], pad_to_multiple_of: int) -> dict[str, np.ndarray]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
|
||||
|
||||
@ -316,7 +315,7 @@ class FlaxDataCollatorForLanguageModeling:
|
||||
|
||||
def mask_tokens(
|
||||
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
||||
"""
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -32,7 +31,7 @@ from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import flax
|
||||
import jax
|
||||
@ -338,7 +337,7 @@ class FlaxDataCollatorForT5MLM:
|
||||
pad_token_id: int
|
||||
decoder_start_token_id: int
|
||||
|
||||
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
|
||||
def __call__(self, examples: list[dict[str, np.ndarray]]) -> BatchEncoding:
|
||||
# convert list to dict and tensorize input
|
||||
batch = BatchEncoding(
|
||||
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
from typing import Iterator, List, Union
|
||||
from collections.abc import Iterator
|
||||
from typing import Union
|
||||
|
||||
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers
|
||||
from tokenizers.implementations.base_tokenizer import BaseTokenizer
|
||||
@ -72,7 +73,7 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
|
||||
|
||||
def train(
|
||||
self,
|
||||
files: Union[str, List[str]],
|
||||
files: Union[str, list[str]],
|
||||
vocab_size: int = 8000,
|
||||
show_progress: bool = True,
|
||||
):
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -28,7 +27,7 @@ import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@ -61,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.52.0.dev0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
@ -908,8 +907,8 @@ def main():
|
||||
|
||||
# region Define train step functions
|
||||
def train_step(
|
||||
state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
|
||||
) -> Tuple[train_state.TrainState, float]:
|
||||
state: train_state.TrainState, batch: dict[str, Array], dropout_rng: PRNGKey
|
||||
) -> tuple[train_state.TrainState, float]:
|
||||
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
|
||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||
start_positions = batch.pop("start_positions")
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -20,7 +19,7 @@ import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
@ -32,7 +31,7 @@ logger = logging.getLogger(__name__)
|
||||
def postprocess_qa_predictions(
|
||||
examples,
|
||||
features,
|
||||
predictions: Tuple[np.ndarray, np.ndarray],
|
||||
predictions: tuple[np.ndarray, np.ndarray],
|
||||
version_2_with_negative: bool = False,
|
||||
n_best_size: int = 20,
|
||||
max_answer_length: int = 30,
|
||||
@ -223,7 +222,7 @@ def postprocess_qa_predictions(
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
if not os.path.isdir(output_dir):
|
||||
raise EnvironmentError(f"{output_dir} is not a directory.")
|
||||
raise OSError(f"{output_dir} is not a directory.")
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
@ -253,7 +252,7 @@ def postprocess_qa_predictions(
|
||||
def postprocess_qa_predictions_with_beam_search(
|
||||
examples,
|
||||
features,
|
||||
predictions: Tuple[np.ndarray, np.ndarray],
|
||||
predictions: tuple[np.ndarray, np.ndarray],
|
||||
version_2_with_negative: bool = False,
|
||||
n_best_size: int = 20,
|
||||
max_answer_length: int = 30,
|
||||
@ -417,7 +416,7 @@ def postprocess_qa_predictions_with_beam_search(
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
if not os.path.isdir(output_dir):
|
||||
raise EnvironmentError(f"{output_dir} is not a directory.")
|
||||
raise OSError(f"{output_dir} is not a directory.")
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -25,7 +24,7 @@ import time
|
||||
from dataclasses import field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@ -60,7 +59,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.52.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
@ -303,7 +302,7 @@ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
|
||||
pad_input_to_multiple_of: Optional[int] = None
|
||||
pad_target_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
|
||||
def __call__(self, features: list[dict[str, Union[list[int], np.ndarray]]]) -> dict[str, np.ndarray]:
|
||||
# split inputs and labels since they have to be of different lengths and need
|
||||
# different padding methods
|
||||
model_input_name = self.processor.model_input_names[0]
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -64,7 +63,7 @@ def get_setup_file():
|
||||
def get_results(output_dir, split="eval"):
|
||||
path = os.path.join(output_dir, f"{split}_results.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
raise ValueError(f"can't find {path}")
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -25,7 +24,7 @@ import time
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@ -56,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.52.0.dev0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
@ -572,8 +571,8 @@ def main():
|
||||
|
||||
# define step functions
|
||||
def train_step(
|
||||
state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
|
||||
) -> Tuple[train_state.TrainState, float]:
|
||||
state: train_state.TrainState, batch: dict[str, Array], dropout_rng: PRNGKey
|
||||
) -> tuple[train_state.TrainState, float]:
|
||||
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
|
||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||
targets = batch.pop("labels")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -27,7 +26,7 @@ from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
@ -57,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.51.0.dev0")
|
||||
check_min_version("4.52.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
@ -651,8 +650,8 @@ def main():
|
||||
|
||||
# define step functions
|
||||
def train_step(
|
||||
state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
|
||||
) -> Tuple[train_state.TrainState, float]:
|
||||
state: train_state.TrainState, batch: dict[str, Array], dropout_rng: PRNGKey
|
||||
) -> tuple[train_state.TrainState, float]:
|
||||
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
|
||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||
targets = batch.pop("labels")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@ -59,7 +59,7 @@ class PlotArguments:
|
||||
default=None,
|
||||
metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
|
||||
)
|
||||
short_model_names: Optional[List[str]] = list_field(
|
||||
short_model_names: Optional[list[str]] = list_field(
|
||||
default=None, metadata={"help": "List of model names that are used instead of the ones in the csv file."}
|
||||
)
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
@ -18,7 +17,7 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from utils_multiple_choice import MultipleChoiceDataset, Split, processors
|
||||
@ -187,7 +186,7 @@ def main():
|
||||
else None
|
||||
)
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||
def compute_metrics(p: EvalPrediction) -> dict:
|
||||
preds = np.argmax(p.predictions, axis=1)
|
||||
return {"acc": simple_accuracy(preds, p.label_ids)}
|
||||
|
||||
@ -228,7 +227,7 @@ def main():
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
writer.write("%s = %s\n" % (key, value))
|
||||
writer.write("{} = {}\n".format(key, value))
|
||||
|
||||
results.update(result)
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
@ -22,7 +21,7 @@ import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import tqdm
|
||||
from filelock import FileLock
|
||||
@ -49,8 +48,8 @@ class InputExample:
|
||||
|
||||
example_id: str
|
||||
question: str
|
||||
contexts: List[str]
|
||||
endings: List[str]
|
||||
contexts: list[str]
|
||||
endings: list[str]
|
||||
label: Optional[str]
|
||||
|
||||
|
||||
@ -62,9 +61,9 @@ class InputFeatures:
|
||||
"""
|
||||
|
||||
example_id: str
|
||||
input_ids: List[List[int]]
|
||||
attention_mask: Optional[List[List[int]]]
|
||||
token_type_ids: Optional[List[List[int]]]
|
||||
input_ids: list[list[int]]
|
||||
attention_mask: Optional[list[list[int]]]
|
||||
token_type_ids: Optional[list[list[int]]]
|
||||
label: Optional[int]
|
||||
|
||||
|
||||
@ -84,7 +83,7 @@ if is_torch_available():
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
features: list[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -149,7 +148,7 @@ if is_tf_available():
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
features: list[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -253,7 +252,7 @@ class RaceProcessor(DataProcessor):
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} train")
|
||||
high = os.path.join(data_dir, "train/high")
|
||||
middle = os.path.join(data_dir, "train/middle")
|
||||
high = self._read_txt(high)
|
||||
@ -262,7 +261,7 @@ class RaceProcessor(DataProcessor):
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} dev")
|
||||
high = os.path.join(data_dir, "dev/high")
|
||||
middle = os.path.join(data_dir, "dev/middle")
|
||||
high = self._read_txt(high)
|
||||
@ -271,7 +270,7 @@ class RaceProcessor(DataProcessor):
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} test".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} test")
|
||||
high = os.path.join(data_dir, "test/high")
|
||||
middle = os.path.join(data_dir, "test/middle")
|
||||
high = self._read_txt(high)
|
||||
@ -286,7 +285,7 @@ class RaceProcessor(DataProcessor):
|
||||
lines = []
|
||||
files = glob.glob(input_dir + "/*txt")
|
||||
for file in tqdm.tqdm(files, desc="read files"):
|
||||
with open(file, "r", encoding="utf-8") as fin:
|
||||
with open(file, encoding="utf-8") as fin:
|
||||
data_raw = json.load(fin)
|
||||
data_raw["race_id"] = file
|
||||
lines.append(data_raw)
|
||||
@ -296,7 +295,7 @@ class RaceProcessor(DataProcessor):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for _, data_raw in enumerate(lines):
|
||||
race_id = "%s-%s" % (set_type, data_raw["race_id"])
|
||||
race_id = "{}-{}".format(set_type, data_raw["race_id"])
|
||||
article = data_raw["article"]
|
||||
for i in range(len(data_raw["answers"])):
|
||||
truth = str(ord(data_raw["answers"][i]) - ord("A"))
|
||||
@ -320,17 +319,17 @@ class SynonymProcessor(DataProcessor):
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} train")
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "mctrain.csv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} dev")
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "mchp.csv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} dev")
|
||||
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "mctest.csv")), "test")
|
||||
|
||||
@ -339,10 +338,10 @@ class SynonymProcessor(DataProcessor):
|
||||
return ["0", "1", "2", "3", "4"]
|
||||
|
||||
def _read_csv(self, input_file):
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
with open(input_file, encoding="utf-8") as f:
|
||||
return list(csv.reader(f))
|
||||
|
||||
def _create_examples(self, lines: List[List[str]], type: str):
|
||||
def _create_examples(self, lines: list[list[str]], type: str):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
|
||||
examples = [
|
||||
@ -366,17 +365,17 @@ class SwagProcessor(DataProcessor):
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} train")
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} dev")
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} dev")
|
||||
raise ValueError(
|
||||
"For swag testing, the input file does not contain a label column. It can not be tested in current code "
|
||||
"setting!"
|
||||
@ -388,10 +387,10 @@ class SwagProcessor(DataProcessor):
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_csv(self, input_file):
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
with open(input_file, encoding="utf-8") as f:
|
||||
return list(csv.reader(f))
|
||||
|
||||
def _create_examples(self, lines: List[List[str]], type: str):
|
||||
def _create_examples(self, lines: list[list[str]], type: str):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
if type == "train" and lines[0][-1] != "label":
|
||||
raise ValueError("For training, the input file must contain a label column.")
|
||||
@ -417,16 +416,16 @@ class ArcProcessor(DataProcessor):
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} train")
|
||||
return self._create_examples(self._read_json(os.path.join(data_dir, "train.jsonl")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} dev")
|
||||
return self._create_examples(self._read_json(os.path.join(data_dir, "dev.jsonl")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
logger.info("LOOKING AT {} test".format(data_dir))
|
||||
logger.info(f"LOOKING AT {data_dir} test")
|
||||
return self._create_examples(self._read_json(os.path.join(data_dir, "test.jsonl")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
@ -434,7 +433,7 @@ class ArcProcessor(DataProcessor):
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_json(self, input_file):
|
||||
with open(input_file, "r", encoding="utf-8") as fin:
|
||||
with open(input_file, encoding="utf-8") as fin:
|
||||
lines = fin.readlines()
|
||||
return lines
|
||||
|
||||
@ -504,11 +503,11 @@ class ArcProcessor(DataProcessor):
|
||||
|
||||
|
||||
def convert_examples_to_features(
|
||||
examples: List[InputExample],
|
||||
label_list: List[str],
|
||||
examples: list[InputExample],
|
||||
label_list: list[str],
|
||||
max_length: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> List[InputFeatures]:
|
||||
) -> list[InputFeatures]:
|
||||
"""
|
||||
Loads a data file into a list of `InputFeatures`
|
||||
"""
|
||||
|
||||
@ -2,7 +2,7 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
@ -201,7 +201,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("best_tfmr")
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
@ -282,7 +282,7 @@ class LoggingCallback(pl.Callback):
|
||||
# Log results
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
rank_zero_info(f"{key} = {str(metrics[key])}\n")
|
||||
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
rank_zero_info("***** Test results *****")
|
||||
@ -292,8 +292,8 @@ class LoggingCallback(pl.Callback):
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
||||
rank_zero_info(f"{key} = {str(metrics[key])}\n")
|
||||
writer.write(f"{key} = {str(metrics[key])}\n")
|
||||
|
||||
|
||||
def add_generic_args(parser, root_dir) -> None:
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
@ -68,7 +67,7 @@ def set_seed(args):
|
||||
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
return tensor.tolist()
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
@ -231,14 +230,14 @@ def train(args, train_dataset, model, tokenizer):
|
||||
if args.local_rank == -1 and args.evaluate_during_training:
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar(f"eval_{key}", value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
# Save model checkpoint
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
# Take care of distributed/parallel training
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_to_save.save_pretrained(output_dir)
|
||||
@ -281,7 +280,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(f"***** Running evaluation {prefix} *****")
|
||||
logger.info(" Num examples = %d", len(dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
|
||||
@ -348,11 +347,11 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
|
||||
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
output_prediction_file = os.path.join(args.output_dir, f"predictions_{prefix}.json")
|
||||
output_nbest_file = os.path.join(args.output_dir, f"nbest_predictions_{prefix}.json")
|
||||
|
||||
if args.version_2_with_negative:
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, f"null_odds_{prefix}.json")
|
||||
else:
|
||||
output_null_log_odds_file = None
|
||||
|
||||
@ -828,10 +827,10 @@ def main():
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
|
||||
result = {k + (f"_{global_step}" if global_step else ""): v for k, v in result.items()}
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
logger.info(f"Results: {results}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
|
||||
@ -20,10 +20,10 @@ def fill_mask(masked_input, model, tokenizer, topk=5):
|
||||
topk_filled_outputs = []
|
||||
for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
|
||||
predicted_token = predicted_token_bpe.replace("\u2581", " ")
|
||||
if " {0}".format(masked_token) in masked_input:
|
||||
if f" {masked_token}" in masked_input:
|
||||
topk_filled_outputs.append(
|
||||
(
|
||||
masked_input.replace(" {0}".format(masked_token), predicted_token),
|
||||
masked_input.replace(f" {masked_token}", predicted_token),
|
||||
values[index].item(),
|
||||
predicted_token,
|
||||
)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
import argparse
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from ltp import LTP
|
||||
|
||||
@ -42,7 +41,7 @@ def is_chinese(word: str):
|
||||
return 1
|
||||
|
||||
|
||||
def get_chinese_word(tokens: List[str]):
|
||||
def get_chinese_word(tokens: list[str]):
|
||||
word_set = set()
|
||||
|
||||
for token in tokens:
|
||||
@ -53,7 +52,7 @@ def get_chinese_word(tokens: List[str]):
|
||||
return word_list
|
||||
|
||||
|
||||
def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set()):
|
||||
def add_sub_symbol(bert_tokens: list[str], chinese_word_set: set()):
|
||||
if not chinese_word_set:
|
||||
return bert_tokens
|
||||
max_word_len = max([len(w) for w in chinese_word_set])
|
||||
@ -77,7 +76,7 @@ def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set()):
|
||||
return bert_word
|
||||
|
||||
|
||||
def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
|
||||
def prepare_ref(lines: list[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
|
||||
ltp_res = []
|
||||
|
||||
for i in range(0, len(lines), 100):
|
||||
@ -117,7 +116,7 @@ def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokeni
|
||||
def main(args):
|
||||
# For Chinese (Ro)Bert, the best result is from : RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
|
||||
# If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
|
||||
with open(args.file_name, "r", encoding="utf-8") as f:
|
||||
with open(args.file_name, encoding="utf-8") as f:
|
||||
data = f.readlines()
|
||||
data = [line.strip() for line in data if len(line) > 0 and not line.isspace()] # avoid delimiter like '\u2029'
|
||||
ltp_tokenizer = LTP(args.ltp) # faster in GPU device
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
@ -358,7 +357,7 @@ def main():
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
writer.write("{} = {}\n".format(key, str(result[key])))
|
||||
|
||||
results.update(result)
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
@ -163,7 +162,7 @@ def main():
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
logger.info("device: {}, n_gpu {}".format(device, n_gpu))
|
||||
logger.info(f"device: {device}, n_gpu {n_gpu}")
|
||||
|
||||
if not args.do_train and not args.do_eval:
|
||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||
@ -261,7 +260,7 @@ def main():
|
||||
loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
|
||||
)
|
||||
nb_tr_steps += 1
|
||||
tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])
|
||||
tqdm_bar.desc = f"Training loss: {exp_average_loss:.2e} lr: {scheduler.get_lr()[0]:.2e}"
|
||||
|
||||
# Save a trained model
|
||||
if args.do_train:
|
||||
@ -313,7 +312,7 @@ def main():
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
writer.write("{} = {}\n".format(key, str(result[key])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
@ -51,7 +50,7 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SwagExample(object):
|
||||
class SwagExample:
|
||||
"""A single training/test example for the SWAG dataset."""
|
||||
|
||||
def __init__(self, swag_id, context_sentence, start_ending, ending_0, ending_1, ending_2, ending_3, label=None):
|
||||
@ -71,22 +70,22 @@ class SwagExample(object):
|
||||
|
||||
def __repr__(self):
|
||||
attributes = [
|
||||
"swag_id: {}".format(self.swag_id),
|
||||
"context_sentence: {}".format(self.context_sentence),
|
||||
"start_ending: {}".format(self.start_ending),
|
||||
"ending_0: {}".format(self.endings[0]),
|
||||
"ending_1: {}".format(self.endings[1]),
|
||||
"ending_2: {}".format(self.endings[2]),
|
||||
"ending_3: {}".format(self.endings[3]),
|
||||
f"swag_id: {self.swag_id}",
|
||||
f"context_sentence: {self.context_sentence}",
|
||||
f"start_ending: {self.start_ending}",
|
||||
f"ending_0: {self.endings[0]}",
|
||||
f"ending_1: {self.endings[1]}",
|
||||
f"ending_2: {self.endings[2]}",
|
||||
f"ending_3: {self.endings[3]}",
|
||||
]
|
||||
|
||||
if self.label is not None:
|
||||
attributes.append("label: {}".format(self.label))
|
||||
attributes.append(f"label: {self.label}")
|
||||
|
||||
return ", ".join(attributes)
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
class InputFeatures:
|
||||
def __init__(self, example_id, choices_features, label):
|
||||
self.example_id = example_id
|
||||
self.choices_features = [
|
||||
@ -97,7 +96,7 @@ class InputFeatures(object):
|
||||
|
||||
|
||||
def read_swag_examples(input_file, is_training=True):
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
with open(input_file, encoding="utf-8") as f:
|
||||
lines = list(csv.reader(f))
|
||||
|
||||
if is_training and lines[0][-1] != "label":
|
||||
@ -179,15 +178,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, is_trainin
|
||||
label = example.label
|
||||
if example_index < 5:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("swag_id: {}".format(example.swag_id))
|
||||
logger.info(f"swag_id: {example.swag_id}")
|
||||
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||
logger.info("choice: {}".format(choice_idx))
|
||||
logger.info(f"choice: {choice_idx}")
|
||||
logger.info("tokens: {}".format(" ".join(tokens)))
|
||||
logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
|
||||
logger.info("input_mask: {}".format(" ".join(map(str, input_mask))))
|
||||
logger.info("segment_ids: {}".format(" ".join(map(str, segment_ids))))
|
||||
if is_training:
|
||||
logger.info("label: {}".format(label))
|
||||
logger.info(f"label: {label}")
|
||||
|
||||
features.append(InputFeatures(example_id=example.swag_id, choices_features=choices_features, label=label))
|
||||
|
||||
@ -382,14 +381,14 @@ def train(args, train_dataset, model, tokenizer):
|
||||
): # Only evaluate when single GPU otherwise metrics may not average well
|
||||
results = evaluate(args, model, tokenizer)
|
||||
for key, value in results.items():
|
||||
tb_writer.add_scalar("eval_{}".format(key), value, global_step)
|
||||
tb_writer.add_scalar(f"eval_{key}", value, global_step)
|
||||
tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
||||
# Save model checkpoint
|
||||
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
||||
output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
@ -423,7 +422,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(f"***** Running evaluation {prefix} *****")
|
||||
logger.info(" Num examples = %d", len(dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
|
||||
@ -466,7 +465,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info("%s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
writer.write("{} = {}\n".format(key, str(result[key])))
|
||||
|
||||
return result
|
||||
|
||||
@ -710,10 +709,10 @@ def main():
|
||||
# Evaluate
|
||||
result = evaluate(args, model, tokenizer, prefix=global_step)
|
||||
|
||||
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
|
||||
result = {k + (f"_{global_step}" if global_step else ""): v for k, v in result.items()}
|
||||
results.update(result)
|
||||
|
||||
logger.info("Results: {}".format(results))
|
||||
logger.info(f"Results: {results}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
@ -66,7 +65,7 @@ def main():
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
logger.info("device: {}".format(device))
|
||||
logger.info(f"device: {device}")
|
||||
|
||||
# Load a pre-processed dataset
|
||||
# You can also build the corpus yourself using TransfoXLCorpus methods
|
||||
@ -111,7 +110,7 @@ def main():
|
||||
total_loss += seq_len * loss.item()
|
||||
total_len += seq_len
|
||||
total_time = time.time() - start_time
|
||||
logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / (idx + 1)))
|
||||
logger.info(f"Time : {total_time:.2f}s, {1000 * total_time / (idx + 1):.2f}ms/segment")
|
||||
return total_loss / total_len
|
||||
|
||||
# Run on test data.
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Huggingface
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -13,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import json
|
||||
import unittest
|
||||
|
||||
@ -25,7 +23,7 @@ from utils import calculate_bleu
|
||||
|
||||
|
||||
filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json"
|
||||
with io.open(filename, "r", encoding="utf-8") as f:
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
bleu_data = json.load(f)
|
||||
|
||||
|
||||
|
||||
@ -19,7 +19,6 @@ import time
|
||||
from json import JSONDecodeError
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
@ -55,10 +54,10 @@ def eval_data_dir(
|
||||
task="summarization",
|
||||
local_rank=None,
|
||||
num_return_sequences=1,
|
||||
dataset_kwargs: Dict = None,
|
||||
dataset_kwargs: dict = None,
|
||||
prefix="",
|
||||
**generate_kwargs,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
|
||||
model_name = str(model_name)
|
||||
assert local_rank is not None
|
||||
@ -211,7 +210,7 @@ def run_generate():
|
||||
calc_bleu = "translation" in args.task
|
||||
score_fn = calculate_bleu if calc_bleu else calculate_rouge
|
||||
metric_name = "bleu" if calc_bleu else "rouge"
|
||||
metrics: Dict = score_fn(preds, labels)
|
||||
metrics: dict = score_fn(preds, labels)
|
||||
metrics["n_obs"] = len(preds)
|
||||
runtime = time.time() - start_time
|
||||
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4)
|
||||
@ -227,7 +226,7 @@ def run_generate():
|
||||
shutil.rmtree(json_save_dir)
|
||||
|
||||
|
||||
def combine_partial_results(partial_results) -> List:
|
||||
def combine_partial_results(partial_results) -> list:
|
||||
"""Concatenate partial results into one file, then sort it by id."""
|
||||
records = []
|
||||
for partial_result in partial_results:
|
||||
@ -237,7 +236,7 @@ def combine_partial_results(partial_results) -> List:
|
||||
return preds
|
||||
|
||||
|
||||
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
|
||||
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> list[dict[str, list]]:
|
||||
# WAIT FOR lots of .json files
|
||||
start_wait = time.time()
|
||||
logger.info("waiting for all nodes to finish")
|
||||
|
||||
@ -20,7 +20,6 @@ import time
|
||||
import warnings
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
@ -36,7 +35,7 @@ DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def generate_summaries_or_translations(
|
||||
examples: List[str],
|
||||
examples: list[str],
|
||||
out_file: str,
|
||||
model_name: str,
|
||||
batch_size: int = 8,
|
||||
@ -45,7 +44,7 @@ def generate_summaries_or_translations(
|
||||
task="summarization",
|
||||
prefix=None,
|
||||
**generate_kwargs,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""Save model.generate results to <out_file>, and return how long it took."""
|
||||
fout = Path(out_file).open("w", encoding="utf-8")
|
||||
model_name = str(model_name)
|
||||
|
||||
@ -34,7 +34,7 @@ task_score_names = {
|
||||
|
||||
def parse_search_arg(search):
|
||||
groups = search.split()
|
||||
entries = dict((g.split("=") for g in groups))
|
||||
entries = dict(g.split("=") for g in groups)
|
||||
entry_names = list(entries.keys())
|
||||
sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
|
||||
matrix = [list(x) for x in itertools.product(*sets)]
|
||||
@ -105,7 +105,7 @@ def run_search():
|
||||
col_widths = {col: len(str(col)) for col in col_names}
|
||||
results = []
|
||||
for r in matrix:
|
||||
hparams = dict((x.replace("--", "").split() for x in r))
|
||||
hparams = dict(x.replace("--", "").split() for x in r)
|
||||
args_exp = " ".join(r).split()
|
||||
args_exp.extend(["--bs", str(args.bs)]) # in case we need to reduce its size due to CUDA OOM
|
||||
sys.argv = args_normal + args_exp
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user