mirror of
https://github.com/huggingface/transformers.git
synced 2025-11-19 17:54:41 +08:00
Compare commits
96 Commits
check_comm
...
fix-quants
| Author | SHA1 | Date | |
|---|---|---|---|
| b6b93d77dd | |||
| a5c903f877 | |||
| 67302b043e | |||
| 9f31104786 | |||
| d372b82754 | |||
| b2feaa215f | |||
| 1acbd0b327 | |||
| c40b370bd0 | |||
| b1bdf9cb39 | |||
| cd416f3c5c | |||
| 1742d1198d | |||
| 16924cd33a | |||
| 266d3b0568 | |||
| 8d6c4583bf | |||
| 2cc9152da0 | |||
| 8637f6e7ae | |||
| 0e74a71c03 | |||
| 47227f4775 | |||
| 7f9f4d9cc6 | |||
| 462beff5c3 | |||
| 66d57110f0 | |||
| 8598421b51 | |||
| 16c7afd06f | |||
| 309180f93a | |||
| 8976ceb051 | |||
| c01e711ee5 | |||
| 082e3ff4a3 | |||
| c0678c81b9 | |||
| f78cadfc97 | |||
| eddd51ec3d | |||
| 7607d80f7e | |||
| 32a58e3146 | |||
| 6f6095e0cf | |||
| c4cfc2e023 | |||
| 5c6d6bed4d | |||
| 80134e6e66 | |||
| ce40ca0d4c | |||
| 6408d3b01a | |||
| f40ef03214 | |||
| 5150dac727 | |||
| 27c3807991 | |||
| ffb35fe142 | |||
| 1fd63dd532 | |||
| 240d19f4a3 | |||
| ba938fa590 | |||
| 6744ebe745 | |||
| 1709ed96e4 | |||
| fd36275be2 | |||
| 922e85487b | |||
| f9e668abf3 | |||
| 7951105d69 | |||
| 58a3f8caac | |||
| fcea1e1fe0 | |||
| 563f2ffb21 | |||
| 6f479d5d75 | |||
| d012f34e0d | |||
| e76364d5c1 | |||
| 2b8068c306 | |||
| 33c60a5254 | |||
| fa22b56903 | |||
| f30c22500b | |||
| 496c283615 | |||
| df45a92cea | |||
| 3ff0e69f84 | |||
| 31839d741a | |||
| 2072f3059e | |||
| 3760afb21c | |||
| 3c0b2b101e | |||
| e869e9df54 | |||
| 37d48bbb48 | |||
| 21913b2e10 | |||
| f028e9340c | |||
| 4dd4a8fafe | |||
| 03538a80be | |||
| 700c48a29f | |||
| 18a19dea61 | |||
| dba6aeb1e3 | |||
| 1c9077f66d | |||
| 756742354b | |||
| 926c37aaf4 | |||
| f5630f9b1a | |||
| e8a6eb3304 | |||
| 370fc65ee5 | |||
| f065e402fc | |||
| 91d250efb1 | |||
| 7cb4280112 | |||
| 144c8ce280 | |||
| 069684ef87 | |||
| a127710b3a | |||
| 08f52e2178 | |||
| c790403039 | |||
| 8012f80f72 | |||
| 7b325cd573 | |||
| a9e2b80c71 | |||
| bc8b0b0541 | |||
| cbd83bf161 |
@ -46,8 +46,8 @@ jobs:
|
||||
- run: uv pip install -U -e .
|
||||
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
|
||||
- run: mkdir -p test_preparation
|
||||
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt
|
||||
- run: python utils/tests_fetcher.py --filter_tests
|
||||
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt || true
|
||||
- run: python utils/tests_fetcher.py --filter_tests || true
|
||||
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
|
||||
- run: |
|
||||
if [ ! -s test_preparation/generated_config.yml ]; then
|
||||
@ -98,8 +98,8 @@ jobs:
|
||||
- run: uv pip install -U -e .
|
||||
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
|
||||
- run: mkdir -p test_preparation
|
||||
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt
|
||||
- run: python utils/tests_fetcher.py --filter_tests
|
||||
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt || true
|
||||
- run: python utils/tests_fetcher.py --filter_tests || true
|
||||
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
|
||||
- run: |
|
||||
if [ ! -s test_preparation/generated_config.yml ]; then
|
||||
|
||||
7
.github/workflows/benchmark.yml
vendored
7
.github/workflows/benchmark.yml
vendored
@ -32,16 +32,15 @@ jobs:
|
||||
options: --gpus all --privileged --ipc host
|
||||
steps:
|
||||
- name: Get repo
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Install benchmark script dependencies
|
||||
run: python3 -m pip install -r benchmark_v2/requirements.txt kernels
|
||||
|
||||
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
|
||||
working-directory: /transformers
|
||||
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]" && python3 -m pip uninstall -y torchvision # temp fix
|
||||
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]"
|
||||
|
||||
- name: Run benchmark
|
||||
run: |
|
||||
|
||||
23
.github/workflows/check-workflow-permissions.yml
vendored
Normal file
23
.github/workflows/check-workflow-permissions.yml
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
---
|
||||
name: Check Permissions Advisor
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
workflow_name:
|
||||
description: 'Workflow file name'
|
||||
type: string
|
||||
run_count:
|
||||
description: 'Number of runs to analyze'
|
||||
type: string
|
||||
default: "10"
|
||||
|
||||
jobs:
|
||||
advisor:
|
||||
uses: huggingface/security-workflows/.github/workflows/permissions-advisor-reusable.yml@main
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
with:
|
||||
workflow_name: ${{ inputs.workflow_name }}
|
||||
run_count: ${{ fromJSON(inputs.run_count) }}
|
||||
19
.github/workflows/check_failed_tests.yml
vendored
19
.github/workflows/check_failed_tests.yml
vendored
@ -116,7 +116,6 @@ jobs:
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
@ -131,24 +130,6 @@ jobs:
|
||||
|
||||
core.setOutput('merge_commit_base_sha', merge_commit.parents[0].sha);
|
||||
|
||||
- name: Update clone 2
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
env:
|
||||
commit_sha: ${{ inputs.commit_sha || github.sha }}
|
||||
run: |
|
||||
git fetch origin ${{ steps.pr_info.outputs.merge_commit_base_sha }} && git checkout ${{ steps.pr_info.outputs.merge_commit_base_sha }}
|
||||
git log -n 3
|
||||
|
||||
- name: Update clone 3
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
env:
|
||||
commit_sha: ${{ inputs.commit_sha || github.sha }}
|
||||
run: |
|
||||
git fetch origin "$commit_sha" && git checkout "$commit_sha"
|
||||
git log -n 3
|
||||
|
||||
# Usually, `END_SHA` should be the commit of the last previous workflow run of the **SAME** (scheduled) workflow.
|
||||
# (This is why we don't need to specify `workflow_id` which would be fetched automatically in the python script.)
|
||||
- name: Get `END_SHA` from previous CI runs of the same workflow
|
||||
|
||||
46
.github/workflows/self-comment-ci.yml
vendored
46
.github/workflows/self-comment-ci.yml
vendored
@ -6,10 +6,9 @@ on:
|
||||
- created
|
||||
branches-ignore:
|
||||
- main
|
||||
pull_request:
|
||||
#concurrency:
|
||||
# group: ${{ github.workflow }}-${{ github.event.issue.number }}-${{ startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow') }}
|
||||
# cancel-in-progress: true
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.issue.number }}-${{ startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow') }}
|
||||
cancel-in-progress: true
|
||||
permissions: read-all
|
||||
|
||||
env:
|
||||
@ -28,7 +27,7 @@ env:
|
||||
jobs:
|
||||
get-pr-number:
|
||||
name: Get PR number
|
||||
# if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "eustlb", "MekkCyber", "vasqu", "ivarflakstad", "stevhliu", "ebezzam", "remi-or", "itazap"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
|
||||
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "eustlb", "MekkCyber", "vasqu", "ivarflakstad", "stevhliu", "ebezzam", "remi-or", "itazap"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
|
||||
uses: ./.github/workflows/get-pr-number.yml
|
||||
|
||||
get-pr-info:
|
||||
@ -52,14 +51,13 @@ jobs:
|
||||
COMMENT_DATE: ${{ github.event.comment.created_at }}
|
||||
PR_MERGE_COMMIT_TIMESTAMP: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_TIMESTAMP }}
|
||||
run: |
|
||||
echo "bon"
|
||||
# COMMENT_TIMESTAMP=$(date -d "${COMMENT_DATE}" +"%s")
|
||||
# echo "COMMENT_DATE: $COMMENT_DATE"
|
||||
# echo "COMMENT_TIMESTAMP: $COMMENT_TIMESTAMP"
|
||||
# if [ $COMMENT_TIMESTAMP -le $PR_MERGE_COMMIT_TIMESTAMP ]; then
|
||||
# echo "Last commit on the pull request is newer than the issue comment triggering this run! Abort!";
|
||||
# exit -1;
|
||||
# fi
|
||||
COMMENT_TIMESTAMP=$(date -d "${COMMENT_DATE}" +"%s")
|
||||
echo "COMMENT_DATE: $COMMENT_DATE"
|
||||
echo "COMMENT_TIMESTAMP: $COMMENT_TIMESTAMP"
|
||||
if [ $COMMENT_TIMESTAMP -le $PR_MERGE_COMMIT_TIMESTAMP ]; then
|
||||
echo "Last commit on the pull request is newer than the issue comment triggering this run! Abort!";
|
||||
exit -1;
|
||||
fi
|
||||
|
||||
# use a python script to handle this complex logic.
|
||||
get-tests:
|
||||
@ -72,21 +70,21 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: "0"
|
||||
ref: ${{ needs.check-timestamps.outputs.PR_MERGE_SHA }}
|
||||
ref: "refs/pull/${{ needs.get-pr-number.outputs.PR_NUMBER }}/merge"
|
||||
|
||||
# - name: Verify merge commit SHA
|
||||
# env:
|
||||
# VERIFIED_PR_MERGE_SHA: ${{ needs.check-timestamps.outputs.PR_MERGE_SHA }}
|
||||
# run: |
|
||||
# PR_MERGE_SHA=$(git log -1 --format=%H)
|
||||
# if [ $PR_MERGE_SHA != $VERIFIED_PR_MERGE_SHA ]; then
|
||||
# echo "The merged commit SHA is not the same as the verified one! Security issue detected, abort the workflow!";
|
||||
# exit -1;
|
||||
# fi
|
||||
- name: Verify merge commit SHA
|
||||
env:
|
||||
VERIFIED_PR_MERGE_SHA: ${{ needs.check-timestamps.outputs.PR_MERGE_SHA }}
|
||||
run: |
|
||||
PR_MERGE_SHA=$(git log -1 --format=%H)
|
||||
if [ $PR_MERGE_SHA != $VERIFIED_PR_MERGE_SHA ]; then
|
||||
echo "The merged commit SHA is not the same as the verified one! Security issue detected, abort the workflow!";
|
||||
exit -1;
|
||||
fi
|
||||
|
||||
- name: Get models to test
|
||||
env:
|
||||
PR_COMMENT: "run-slow: vit"
|
||||
PR_COMMENT: ${{ github.event.comment.body }}
|
||||
run: |
|
||||
python -m pip install GitPython
|
||||
python utils/pr_slow_ci_models.py --message "$PR_COMMENT" | tee output.txt
|
||||
|
||||
80
.github/workflows/self-scheduled-caller.yml
vendored
80
.github/workflows/self-scheduled-caller.yml
vendored
@ -6,7 +6,7 @@ on:
|
||||
- cron: "17 2 * * *"
|
||||
push:
|
||||
branches:
|
||||
- check_commitxxx
|
||||
- run_nvidia_ci*
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
prev_workflow_run_id:
|
||||
@ -23,7 +23,7 @@ on:
|
||||
|
||||
# Used for `push` to easily modify the target workflow runs to compare against
|
||||
env:
|
||||
prev_workflow_run_id: "19089641651"
|
||||
prev_workflow_run_id: ""
|
||||
other_workflow_run_id: ""
|
||||
|
||||
|
||||
@ -52,10 +52,84 @@ jobs:
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_models_gpu
|
||||
slack_report_channel: "#transformers-ci-dummy"
|
||||
slack_report_channel: "#transformers-ci-daily-models"
|
||||
docker: huggingface/transformers-all-latest-gpu
|
||||
ci_event: Daily CI
|
||||
runner_type: "a10"
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
torch-pipeline:
|
||||
name: Torch pipeline CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_pipelines_torch_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-pipeline-torch"
|
||||
docker: huggingface/transformers-all-latest-gpu
|
||||
ci_event: Daily CI
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
example-ci:
|
||||
name: Example CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_examples_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-examples"
|
||||
docker: huggingface/transformers-all-latest-gpu
|
||||
ci_event: Daily CI
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
trainer-fsdp-ci:
|
||||
name: Trainer/FSDP CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_trainer_and_fsdp_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-training"
|
||||
docker: huggingface/transformers-all-latest-gpu
|
||||
runner_type: "a10"
|
||||
ci_event: Daily CI
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
deepspeed-ci:
|
||||
name: DeepSpeed CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_torch_cuda_extensions_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-training"
|
||||
docker: huggingface/transformers-pytorch-deepspeed-latest-gpu
|
||||
ci_event: Daily CI
|
||||
working-directory-prefix: /workspace
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
quantization-ci:
|
||||
name: Quantization CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_quantization_torch_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-quantization"
|
||||
docker: huggingface/transformers-quantization-latest-gpu
|
||||
ci_event: Daily CI
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
kernels-ci:
|
||||
name: Kernels CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_kernels_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-kernels"
|
||||
docker: huggingface/transformers-all-latest-gpu
|
||||
ci_event: Daily CI
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
12
.github/workflows/self-scheduled.yml
vendored
12
.github/workflows/self-scheduled.yml
vendored
@ -67,7 +67,7 @@ jobs:
|
||||
if: contains(fromJSON('["run_models_gpu", "run_trainer_and_fsdp_gpu", "run_quantization_torch_gpu"]'), inputs.job)
|
||||
strategy:
|
||||
matrix:
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -136,7 +136,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
slice_id: ${{ fromJSON(needs.setup.outputs.slice_ids) }}
|
||||
uses: ./.github/workflows/model_jobs.yml
|
||||
with:
|
||||
@ -157,7 +157,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
slice_id: [0, 1]
|
||||
uses: ./.github/workflows/model_jobs.yml
|
||||
with:
|
||||
@ -177,7 +177,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -322,7 +322,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -427,7 +427,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
folders: ${{ fromJson(needs.setup.outputs.quantization_matrix) }}
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
|
||||
@ -125,8 +125,9 @@ If you're contributing a **vision-language model** (or any multimodal model that
|
||||
All new models should use the modular architecture pattern. Create a `modular_<model_name>.py` file using the modular model converter:
|
||||
|
||||
- Use the CLI, [`transformers add-new-model-like`](https://github.com/huggingface/transformers/blob/main/src/transformers/cli/add_new_model_like.py) to generate a modular skeleton and get started
|
||||
- All code should be in the modular file if possible. Modeling must be in it, it's better if configuration is in it as well.
|
||||
- All code should be in the modular file if possible. Modeling must be in it, it's better if configuration is in it as well. [Modular guide](./modular_transformers#implementing-a-modular-file) shows a quick way to set up a modular file.
|
||||
- Reuse existing patterns from similar models as much as possible
|
||||
- You can make the model compatible with inference engines such as vLLM or SGLang, and enable zero-effort integration. See specific requirements for model implementation in ["Transformers modeling backend"](./transformers_as_backend#multimodal-models)
|
||||
|
||||
To verify your modular file is correct, run:
|
||||
|
||||
|
||||
1
Makefile
1
Makefile
@ -45,6 +45,7 @@ repo-consistency:
|
||||
python utils/check_modular_conversion.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_repo.py
|
||||
python utils/check_init_weights_data.py
|
||||
python utils/check_inits.py
|
||||
python utils/check_pipeline_typing.py
|
||||
python utils/check_config_docstrings.py
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
gpustat==1.1.1
|
||||
psutil==6.0.0
|
||||
psycopg2==2.9.9
|
||||
torch>=2.4.0
|
||||
hf_xet
|
||||
pandas>=1.5.0
|
||||
pandas>=1.5.0
|
||||
|
||||
@ -36,6 +36,7 @@ class BenchmarkConfig:
|
||||
warmup_iterations: int = 5,
|
||||
measurement_iterations: int = 20,
|
||||
gpu_monitoring: bool = True, # NOTE: you may want to disable this at times as we have obsvered it could heavily slow down benchmarks on AMD
|
||||
continuous_batching: bool = False,
|
||||
batch_size: int = 1,
|
||||
sequence_length: int = 128,
|
||||
num_tokens_to_generate: int = 128,
|
||||
@ -51,6 +52,7 @@ class BenchmarkConfig:
|
||||
self.warmup_iterations = warmup_iterations
|
||||
self.measurement_iterations = measurement_iterations
|
||||
self.gpu_monitoring = gpu_monitoring
|
||||
self.continuous_batching = continuous_batching
|
||||
# Input parameters
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
@ -85,6 +87,22 @@ class BenchmarkConfig:
|
||||
if is_fa:
|
||||
logger.warning("Flash attention does not support compile mode. Turning off compile mode.")
|
||||
self.compile_mode = None
|
||||
# Handle SDPA backend if not determined by the config (needs to be done before skipping duplicates)
|
||||
if self.attn_implementation == "sdpa" and self.sdpa_backend is None:
|
||||
default_backend = "flash_attention" # FIXME: torch has a _cur_sdpa_kernel_backends but it fails
|
||||
logger.warning(f"No SDPA backend provided, using {default_backend} instead.")
|
||||
self.sdpa_backend = default_backend
|
||||
if self.continuous_batching:
|
||||
if self.attn_implementation == "flex_attention":
|
||||
logger.error(
|
||||
"disabling continuous batching because of invalid configuration: flex attention is not supported"
|
||||
)
|
||||
self.continuous_batching = False
|
||||
elif self.attn_implementation == "sdpa" and self.sdpa_backend is not None:
|
||||
logger.warning(
|
||||
"when continuous batching is enabled, sdpa_backend must be None because of the attention mask, setting it to None"
|
||||
)
|
||||
self.sdpa_backend = "math"
|
||||
|
||||
@property
|
||||
def hash(self) -> str:
|
||||
@ -100,6 +118,7 @@ class BenchmarkConfig:
|
||||
attn_code += f"_{self.sdpa_backend}" if self.attn_implementation == "sdpa" else ""
|
||||
compile_str = f"compiled_{self.compile_mode}" if self.compile_mode is not None else "uncompiled"
|
||||
kernelize_str = "kernelized" if self.kernelize else "unkernelized"
|
||||
continuous_batching_str = "cb" if self.continuous_batching else "generate"
|
||||
sep = "-"
|
||||
else:
|
||||
iter_str = f"{self.warmup_iterations} warmup, {self.measurement_iterations} iterations"
|
||||
@ -109,8 +128,11 @@ class BenchmarkConfig:
|
||||
attn_code += f" with {self.sdpa_backend} backend" if self.attn_implementation == "sdpa" else ""
|
||||
compile_str = "compiled" if self.compile_mode is not None else "not compiled"
|
||||
kernelize_str = "kernelized" if self.kernelize else "not kernelized"
|
||||
continuous_batching_str = "continuous batching" if self.continuous_batching else "regular generate"
|
||||
sep = ", "
|
||||
return sep.join([iter_str, gpu_monitor_str, dimensions_str, attn_code, compile_str, kernelize_str])
|
||||
return sep.join(
|
||||
[iter_str, gpu_monitor_str, dimensions_str, attn_code, compile_str, kernelize_str, continuous_batching_str]
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
@ -118,6 +140,7 @@ class BenchmarkConfig:
|
||||
"warmup_iterations": self.warmup_iterations,
|
||||
"measurement_iterations": self.measurement_iterations,
|
||||
"gpu_monitoring": self.gpu_monitoring,
|
||||
"continuous_batching": self.continuous_batching,
|
||||
"batch_size": self.batch_size,
|
||||
"sequence_length": self.sequence_length,
|
||||
"num_tokens_to_generate": self.num_tokens_to_generate,
|
||||
@ -134,6 +157,7 @@ class BenchmarkConfig:
|
||||
warmup_iterations=data.get("warmup_iterations", 5),
|
||||
measurement_iterations=data.get("measurement_iterations", 20),
|
||||
gpu_monitoring=data.get("gpu_monitoring", False),
|
||||
continuous_batching=data.get("continuous_batching", False),
|
||||
batch_size=data.get("batch_size", 1),
|
||||
sequence_length=data.get("sequence_length", 128),
|
||||
num_tokens_to_generate=data.get("num_tokens_to_generate", 128),
|
||||
@ -191,15 +215,17 @@ def get_config_by_level(level: int) -> list[BenchmarkConfig]:
|
||||
# Usually there is not much to gain by compiling with other modes, but we allow it for level 4
|
||||
compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
|
||||
for cm in compile_modes:
|
||||
for kernelize_on in [False, KERNELIZATION_AVAILABLE]:
|
||||
configs.append(
|
||||
BenchmarkConfig(
|
||||
attn_implementation=attn_implementation,
|
||||
sdpa_backend=sdpa_backend,
|
||||
compile_mode=cm,
|
||||
kernelize=kernelize_on,
|
||||
for kernelize_on in {False, KERNELIZATION_AVAILABLE}:
|
||||
for cb_on in [False, True]:
|
||||
configs.append(
|
||||
BenchmarkConfig(
|
||||
attn_implementation=attn_implementation,
|
||||
sdpa_backend=sdpa_backend,
|
||||
compile_mode=cm,
|
||||
kernelize=kernelize_on,
|
||||
continuous_batching=cb_on,
|
||||
)
|
||||
)
|
||||
)
|
||||
return configs
|
||||
# Otherwise, we add the configs for the given level
|
||||
if level >= 0:
|
||||
@ -207,8 +233,10 @@ def get_config_by_level(level: int) -> list[BenchmarkConfig]:
|
||||
if level >= 1:
|
||||
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
|
||||
configs.append(BenchmarkConfig(attn_implementation="eager", compile_mode="default"))
|
||||
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", continuous_batching=True))
|
||||
if level >= 2:
|
||||
configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
|
||||
configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
|
||||
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
|
||||
configs.append(BenchmarkConfig(attn_implementation="paged|sdpa", continuous_batching=True))
|
||||
return configs
|
||||
|
||||
@ -117,8 +117,6 @@ def flush_memory():
|
||||
# Clear CUDA cache
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
|
||||
@ -234,8 +232,9 @@ class BenchmarkRunner:
|
||||
self.logger.info(f"Running benchmark scenario: {config.name}")
|
||||
|
||||
# Quick validation: try one measurement first to see if this scenario works
|
||||
generate_fn = self.time_generate_batch if config.continuous_batching else self.time_generate
|
||||
flush_memory()
|
||||
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
|
||||
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = generate_fn(
|
||||
max_new_tokens=1, gpu_monitor=None
|
||||
)
|
||||
if e2e_latency < 0:
|
||||
@ -245,14 +244,14 @@ class BenchmarkRunner:
|
||||
# Warmup runs
|
||||
self.logger.info(f"Warming up with {config.warmup_iterations} iterations...")
|
||||
for _ in trange(config.warmup_iterations):
|
||||
_ = self.time_generate(max_new_tokens=config.num_tokens_to_generate)
|
||||
_ = generate_fn(max_new_tokens=config.num_tokens_to_generate)
|
||||
self.logger.info("Warmup over.")
|
||||
|
||||
# Measurement runs
|
||||
result = BenchmarkResult()
|
||||
self.logger.info(f"Benchmarking with {config.measurement_iterations} iterations.")
|
||||
for _ in trange(config.measurement_iterations):
|
||||
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
|
||||
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = generate_fn(
|
||||
max_new_tokens=config.num_tokens_to_generate,
|
||||
gpu_monitor=(GPUMonitor(logger=self.logger) if config.gpu_monitoring else None),
|
||||
)
|
||||
@ -274,6 +273,58 @@ class BenchmarkRunner:
|
||||
"config": config,
|
||||
}
|
||||
|
||||
# TODO: refactor `generate_batch` to handle streaming so we can use it here
|
||||
def time_generate_batch(
|
||||
self,
|
||||
max_new_tokens: int,
|
||||
gpu_monitor: GPUMonitor | None = None,
|
||||
) -> tuple[float, list[float], str, GPURawMetrics | None]:
|
||||
if gpu_monitor is not None:
|
||||
gpu_monitor.start()
|
||||
config = GenerationConfig(
|
||||
max_new_tokens=max_new_tokens,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
)
|
||||
manager = self.model.init_continuous_batching(config)
|
||||
manager.start()
|
||||
try:
|
||||
first_req_results = []
|
||||
timestamps = []
|
||||
wall_time_0 = time.perf_counter()
|
||||
inputs = self.inputs["input_ids"].tolist()
|
||||
manager.add_requests(inputs, max_new_tokens=max_new_tokens, streaming=True)
|
||||
first_req_id = None
|
||||
num_requests = len(inputs)
|
||||
finished_requests = 0
|
||||
while finished_requests < num_requests:
|
||||
# NOTE: I don't like having the extra if stmt here, but hopefully won't degrade perf too much
|
||||
result = manager.get_result()
|
||||
if result:
|
||||
timestamps.append(time.perf_counter() - wall_time_0)
|
||||
if result.is_finished():
|
||||
finished_requests += 1
|
||||
if first_req_id is None:
|
||||
first_req_id = result.request_id
|
||||
if result.request_id == first_req_id:
|
||||
first_req_results.append(result)
|
||||
else:
|
||||
if not manager.is_running():
|
||||
raise RuntimeError("Generation thread exited unexpectedly")
|
||||
wall_time_1 = time.perf_counter()
|
||||
gpu_metrics = gpu_monitor.stop_and_collect() if gpu_monitor is not None else None
|
||||
decoded_output = self.tokenizer.decode(
|
||||
[res.generated_tokens[0] for res in first_req_results], skip_special_tokens=True
|
||||
)
|
||||
shape_and_decoded_output = f"{(1, len(first_req_results))} | {decoded_output}"
|
||||
e2e_latency = wall_time_1 - wall_time_0
|
||||
return e2e_latency, timestamps, shape_and_decoded_output, gpu_metrics
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
manager.stop()
|
||||
|
||||
def time_generate(
|
||||
self,
|
||||
max_new_tokens: int,
|
||||
@ -339,12 +390,6 @@ class BenchmarkRunner:
|
||||
|
||||
n_configs = len(benchmark_configs)
|
||||
for i, config in enumerate(benchmark_configs):
|
||||
# Handle SDPA backend if not determined by the config (needs to be done before skipping duplicates)
|
||||
if config.attn_implementation == "sdpa" and config.sdpa_backend is None:
|
||||
default_backend = "flash_attention" # FIXME: torch has a _cur_sdpa_kernel_backends but it fails
|
||||
self.logger.warning(f"No SDPA backend provided, using {default_backend} instead.")
|
||||
config.sdpa_backend = default_backend
|
||||
|
||||
# Skip if already run
|
||||
if config.hash in all_results:
|
||||
self.logger.info(f"Skipping duplicate config {config.name} for model {model_id} ({i + 1}/{n_configs})")
|
||||
@ -368,21 +413,27 @@ class BenchmarkRunner:
|
||||
self.cleanup()
|
||||
self.save_results(model_id, all_results, timestamp=timestamp)
|
||||
|
||||
if len(all_results) < 1:
|
||||
raise RuntimeError("No benchmark was run succesfully")
|
||||
|
||||
if pretty_print_summary:
|
||||
print()
|
||||
print("=" * 100)
|
||||
print(f"Finished benchmarks in {time.perf_counter() - start_time:.2f} seconds")
|
||||
print(f"Total number of benchmarks: {len(all_results)}")
|
||||
if len(all_results) > 0:
|
||||
print("First run metadata:")
|
||||
first_key = list(all_results.keys())[0]
|
||||
first_metadata = all_results[first_key]["metadata"].to_dict()
|
||||
hardware_info = first_metadata.pop("hardware_info")
|
||||
pretty_print_dict(first_metadata | hardware_info, tabs=1)
|
||||
print("First run metadata:")
|
||||
first_key = list(all_results.keys())[0]
|
||||
first_metadata = all_results[first_key]["metadata"].to_dict()
|
||||
hardware_info = first_metadata.pop("hardware_info")
|
||||
pretty_print_dict(first_metadata | hardware_info, tabs=1)
|
||||
for result in all_results.values():
|
||||
print("=" * 100)
|
||||
print(f"Config: {result['config'].infer_name(compact=False)}\n")
|
||||
result["measurements"].pprint(batch_size=result["config"].batch_size, tabs=1)
|
||||
result["measurements"].pprint(
|
||||
batch_size=result["config"].batch_size,
|
||||
num_generated_tokens=result["config"].num_tokens_to_generate,
|
||||
tabs=1,
|
||||
)
|
||||
print("=" * 100)
|
||||
|
||||
return (timestamp, all_results)
|
||||
|
||||
@ -36,16 +36,17 @@ def add_unit_to_duration(stats: dict[str, float]) -> dict[str, str]:
|
||||
return stats
|
||||
|
||||
|
||||
def equalize_lengths_and_collate(stats: list[dict[str, str]]) -> list[str]:
|
||||
def equalize_lengths_and_collate(stats: dict[str, dict[str, str]]) -> dict[str, str]:
|
||||
"""Note: This operation is destructive as it will update values in place before returning a new correctly formatted dict"""
|
||||
keys = ["avg", "std", "min", "med", "max", "p95"]
|
||||
for key in keys:
|
||||
max_length = max(len(stat[key]) for stat in stats)
|
||||
for stat in stats:
|
||||
max_length = max(len(stat[key]) for stat in stats.values())
|
||||
for stat in stats.values():
|
||||
stat[key] = stat[key].ljust(max_length, " ")
|
||||
return [" ".join([f"{key}={stat[key]}" for key in keys]) for stat in stats]
|
||||
return {name: " ".join([f"{key}={stat[key]}" for key in keys]) for name, stat in stats.items()}
|
||||
|
||||
|
||||
def pretty_print_dict(data: dict[str, Any], tabs: int = 0) -> None:
|
||||
def pretty_print_dict(data: dict[str, str], tabs: int = 0) -> None:
|
||||
max_key_length = max([len(key) for key in data.keys()])
|
||||
for key, value in data.items():
|
||||
tabs_str = " " * tabs
|
||||
@ -141,27 +142,19 @@ class BenchmarkResult:
|
||||
def get_measured_itl(self) -> list[float]:
|
||||
return [(dt[-1] - dt[0]) / (len(dt) - 1) for dt in self.token_generation_times if len(dt) > 1]
|
||||
|
||||
def get_throughput(self, batch_size: int) -> float:
|
||||
return [
|
||||
batch_size * len(dt) / e2e_latency
|
||||
for e2e_latency, dt in zip(self.e2e_latency, self.token_generation_times)
|
||||
]
|
||||
def get_throughput(self, total_generated_tokens: int) -> list[float]:
|
||||
return [total_generated_tokens / e2e_latency for e2e_latency in self.e2e_latency]
|
||||
|
||||
def pprint(self, batch_size: int = 0, tabs: int = 0) -> None:
|
||||
stats_to_collate = [
|
||||
add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
|
||||
add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
|
||||
add_unit_to_duration(compute_basic_statistics(self.get_measured_itl())),
|
||||
]
|
||||
if batch_size > 0:
|
||||
throughput_stats = compute_basic_statistics(self.get_throughput(batch_size))
|
||||
stats_to_collate.append({key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()})
|
||||
collated_stats = equalize_lengths_and_collate(stats_to_collate)
|
||||
dict_to_pprint = {
|
||||
"E2E Latency": collated_stats[0],
|
||||
"Time to First Token": collated_stats[1],
|
||||
"Inter-Token Latency": collated_stats[2],
|
||||
def pprint(self, batch_size: int = 0, num_generated_tokens: int = 0, tabs: int = 0) -> None:
|
||||
measurements = {
|
||||
"E2E Latency": add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
|
||||
"Time to First Token": add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
|
||||
}
|
||||
itl_values = self.get_measured_itl()
|
||||
if len(itl_values) > 0:
|
||||
measurements["Inter-Token Latency"] = add_unit_to_duration(compute_basic_statistics(itl_values))
|
||||
if batch_size > 0:
|
||||
dict_to_pprint["Throughput"] = collated_stats[3]
|
||||
throughput_stats = compute_basic_statistics(self.get_throughput(batch_size * num_generated_tokens))
|
||||
measurements["Throughput"] = {key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()}
|
||||
dict_to_pprint = equalize_lengths_and_collate(measurements)
|
||||
pretty_print_dict(dict_to_pprint, tabs=tabs)
|
||||
|
||||
@ -2,6 +2,5 @@ numpy>=1.21.0
|
||||
psutil>=5.8.0
|
||||
gpustat>=1.0.0
|
||||
torch>=2.0.0
|
||||
transformers>=4.30.0
|
||||
datasets>=2.10.0
|
||||
huggingface_hub>=0.16.0
|
||||
|
||||
@ -80,6 +80,10 @@ if __name__ == "__main__":
|
||||
logger.info(f"Benchmark run UUID: {benchmark_run_uuid}")
|
||||
logger.info(f"Output directory: {args.output_dir}")
|
||||
|
||||
# We cannot compute ITL if we don't have at least two measurements
|
||||
if any(n <= 1 for n in args.num_tokens_to_generate):
|
||||
raise ValueError("--num_tokens_to_generate arguments should be larger than 1")
|
||||
|
||||
# Error out if one of the arguments is not provided
|
||||
if len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 0:
|
||||
raise ValueError(
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
FROM rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
|
||||
FROM rocm/pytorch:rocm7.1_ubuntu22.04_py3.10_pytorch_release_2.8.0
|
||||
LABEL maintainer="Hugging Face"
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
@ -508,16 +508,16 @@ BERT `_init_weights` Methode:
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
Sie können weitere benutzerdefinierte Schemata verwenden, wenn Sie eine spezielle Initialisierung für einige Module benötigen. Zum Beispiel in
|
||||
@ -533,9 +533,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
Das Flag `_is_hf_initialized` wird intern verwendet, um sicherzustellen, dass wir ein Submodul nur einmal initialisieren. Wenn Sie es auf
|
||||
|
||||
@ -118,7 +118,7 @@
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: transformers_as_backend
|
||||
title: Inference server backends
|
||||
title: Transformers as modeling backend
|
||||
- local: continuous_batching
|
||||
title: Continuous Batching
|
||||
title: Inference
|
||||
@ -420,8 +420,6 @@
|
||||
title: BLOOM
|
||||
- local: model_doc/blt
|
||||
title: BLT
|
||||
- local: model_doc/bort
|
||||
title: BORT
|
||||
- local: model_doc/byt5
|
||||
title: ByT5
|
||||
- local: model_doc/camembert
|
||||
@ -476,8 +474,6 @@
|
||||
title: Ernie4_5
|
||||
- local: model_doc/ernie4_5_moe
|
||||
title: Ernie4_5_MoE
|
||||
- local: model_doc/ernie_m
|
||||
title: ErnieM
|
||||
- local: model_doc/esm
|
||||
title: ESM
|
||||
- local: model_doc/exaone4
|
||||
@ -532,8 +528,6 @@
|
||||
title: GPTBigCode
|
||||
- local: model_doc/gpt_oss
|
||||
title: GptOss
|
||||
- local: model_doc/gptsan-japanese
|
||||
title: GPTSAN Japanese
|
||||
- local: model_doc/gpt-sw3
|
||||
title: GPTSw3
|
||||
- local: model_doc/granite
|
||||
@ -558,8 +552,6 @@
|
||||
title: Jamba
|
||||
- local: model_doc/jetmoe
|
||||
title: JetMoe
|
||||
- local: model_doc/jukebox
|
||||
title: Jukebox
|
||||
- local: model_doc/led
|
||||
title: LED
|
||||
- local: model_doc/lfm2
|
||||
@ -594,8 +586,6 @@
|
||||
title: MarkupLM
|
||||
- local: model_doc/mbart
|
||||
title: MBart and MBart-50
|
||||
- local: model_doc/mega
|
||||
title: MEGA
|
||||
- local: model_doc/megatron-bert
|
||||
title: MegatronBERT
|
||||
- local: model_doc/megatron_gpt2
|
||||
@ -630,8 +620,6 @@
|
||||
title: myt5
|
||||
- local: model_doc/nemotron
|
||||
title: Nemotron
|
||||
- local: model_doc/nezha
|
||||
title: NEZHA
|
||||
- local: model_doc/nllb
|
||||
title: NLLB
|
||||
- local: model_doc/nllb-moe
|
||||
@ -646,8 +634,6 @@
|
||||
title: Olmo3
|
||||
- local: model_doc/olmoe
|
||||
title: OLMoE
|
||||
- local: model_doc/open-llama
|
||||
title: Open-Llama
|
||||
- local: model_doc/opt
|
||||
title: OPT
|
||||
- local: model_doc/pegasus
|
||||
@ -668,8 +654,6 @@
|
||||
title: PLBart
|
||||
- local: model_doc/prophetnet
|
||||
title: ProphetNet
|
||||
- local: model_doc/qdqbert
|
||||
title: QDQBert
|
||||
- local: model_doc/qwen2
|
||||
title: Qwen2
|
||||
- local: model_doc/qwen2_moe
|
||||
@ -682,16 +666,12 @@
|
||||
title: Qwen3Next
|
||||
- local: model_doc/rag
|
||||
title: RAG
|
||||
- local: model_doc/realm
|
||||
title: REALM
|
||||
- local: model_doc/recurrent_gemma
|
||||
title: RecurrentGemma
|
||||
- local: model_doc/reformer
|
||||
title: Reformer
|
||||
- local: model_doc/rembert
|
||||
title: RemBERT
|
||||
- local: model_doc/retribert
|
||||
title: RetriBERT
|
||||
- local: model_doc/roberta
|
||||
title: RoBERTa
|
||||
- local: model_doc/roberta-prelayernorm
|
||||
@ -720,10 +700,6 @@
|
||||
title: T5Gemma
|
||||
- local: model_doc/t5v1.1
|
||||
title: T5v1.1
|
||||
- local: model_doc/tapex
|
||||
title: TAPEX
|
||||
- local: model_doc/transfo-xl
|
||||
title: Transformer XL
|
||||
- local: model_doc/ul2
|
||||
title: UL2
|
||||
- local: model_doc/umt5
|
||||
@ -736,8 +712,6 @@
|
||||
title: XGLM
|
||||
- local: model_doc/xlm
|
||||
title: XLM
|
||||
- local: model_doc/xlm-prophetnet
|
||||
title: XLM-ProphetNet
|
||||
- local: model_doc/xlm-roberta
|
||||
title: XLM-RoBERTa
|
||||
- local: model_doc/xlm-roberta-xl
|
||||
@ -784,8 +758,6 @@
|
||||
title: Depth Anything V2
|
||||
- local: model_doc/depth_pro
|
||||
title: DepthPro
|
||||
- local: model_doc/deta
|
||||
title: DETA
|
||||
- local: model_doc/detr
|
||||
title: DETR
|
||||
- local: model_doc/dinat
|
||||
@ -800,8 +772,6 @@
|
||||
title: DiT
|
||||
- local: model_doc/dpt
|
||||
title: DPT
|
||||
- local: model_doc/efficientformer
|
||||
title: EfficientFormer
|
||||
- local: model_doc/efficientloftr
|
||||
title: EfficientLoFTR
|
||||
- local: model_doc/efficientnet
|
||||
@ -838,8 +808,6 @@
|
||||
title: MobileViT
|
||||
- local: model_doc/mobilevitv2
|
||||
title: MobileViTV2
|
||||
- local: model_doc/nat
|
||||
title: NAT
|
||||
- local: model_doc/poolformer
|
||||
title: PoolFormer
|
||||
- local: model_doc/prompt_depth_anything
|
||||
@ -886,12 +854,8 @@
|
||||
title: Timm Wrapper
|
||||
- local: model_doc/upernet
|
||||
title: UperNet
|
||||
- local: model_doc/van
|
||||
title: VAN
|
||||
- local: model_doc/vit
|
||||
title: Vision Transformer (ViT)
|
||||
- local: model_doc/vit_hybrid
|
||||
title: ViT Hybrid
|
||||
- local: model_doc/vitdet
|
||||
title: ViTDet
|
||||
- local: model_doc/vit_mae
|
||||
@ -930,8 +894,6 @@
|
||||
title: Hubert
|
||||
- local: model_doc/kyutai_speech_to_text
|
||||
title: Kyutai Speech-To-Text
|
||||
- local: model_doc/mctct
|
||||
title: MCTCT
|
||||
- local: model_doc/mimi
|
||||
title: Mimi
|
||||
- local: model_doc/mms
|
||||
@ -958,8 +920,6 @@
|
||||
title: SEW-D
|
||||
- local: model_doc/speech_to_text
|
||||
title: Speech2Text
|
||||
- local: model_doc/speech_to_text_2
|
||||
title: Speech2Text2
|
||||
- local: model_doc/speecht5
|
||||
title: SpeechT5
|
||||
- local: model_doc/unispeech
|
||||
@ -1008,6 +968,8 @@
|
||||
title: AltCLIP
|
||||
- local: model_doc/aria
|
||||
title: Aria
|
||||
- local: model_doc/audioflamingo3
|
||||
title: AudioFlamingo3
|
||||
- local: model_doc/aya_vision
|
||||
title: AyaVision
|
||||
- local: model_doc/blip
|
||||
@ -1064,6 +1026,8 @@
|
||||
title: Gemma3n
|
||||
- local: model_doc/git
|
||||
title: GIT
|
||||
- local: model_doc/glm46v
|
||||
title: Glm46V
|
||||
- local: model_doc/glm4v
|
||||
title: glm4v
|
||||
- local: model_doc/glm4v_moe
|
||||
@ -1184,8 +1148,6 @@
|
||||
title: TAPAS
|
||||
- local: model_doc/trocr
|
||||
title: TrOCR
|
||||
- local: model_doc/tvlt
|
||||
title: TVLT
|
||||
- local: model_doc/tvp
|
||||
title: TVP
|
||||
- local: model_doc/udop
|
||||
@ -1212,8 +1174,6 @@
|
||||
- sections:
|
||||
- local: model_doc/decision_transformer
|
||||
title: Decision Transformer
|
||||
- local: model_doc/trajectory_transformer
|
||||
title: Trajectory Transformer
|
||||
title: Reinforcement learning models
|
||||
- sections:
|
||||
- local: model_doc/autoformer
|
||||
@ -1229,10 +1189,6 @@
|
||||
- local: model_doc/timesfm
|
||||
title: TimesFM
|
||||
title: Time series models
|
||||
- sections:
|
||||
- local: model_doc/graphormer
|
||||
title: Graphormer
|
||||
title: Graph models
|
||||
title: Models
|
||||
- sections:
|
||||
- local: internal/modeling_utils
|
||||
|
||||
@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.
|
||||
@ -339,9 +339,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
### Convert checkpoints to Transformers
|
||||
|
||||
402
docs/source/en/model_doc/audioflamingo3.md
Normal file
402
docs/source/en/model_doc/audioflamingo3.md
Normal file
@ -0,0 +1,402 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
*This model was released on 2025-07-10 and added to Hugging Face Transformers on 2025-11-11.*
|
||||
|
||||
# Audio Flamingo 3
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
Audio Flamingo 3 (AF3) is a fully open large audio–language model designed for robust understanding and reasoning over speech, environmental sounds, and music. AF3 pairs a Whisper-style audio encoder with a causal language model and performs replace-in-place audio–text fusion: the processor aligns post-pool audio frames to a dedicated placeholder token and the model replaces those token slots with projected audio embeddings during the forward pass.
|
||||
|
||||
The model checkpoint is available at: [nvidia/audio-flamingo-3-hf](https://huggingface.co/nvidia/audio-flamingo-3-hf)
|
||||
|
||||
Highlights:
|
||||
|
||||
- Unified audio encoder across speech, sound, and music.
|
||||
- **Long-audio support via windowing and post-pool alignment (up to 10 minutes maximum).** The model processes audio in 30-second windows with a hard limit of 20 windows (10 minutes total). Audio longer than 10 minutes will be truncated.
|
||||
- Deterministic fusion that preserves sequence length by replacing audio placeholder tokens with audio embeddings.
|
||||
|
||||
This model was contributed by [Lasha Koroshinadze](https://huggingface.co/lashahub) and [Eric Bezzam](https://huggingface.co/bezzam).
|
||||
|
||||
### Paper
|
||||
|
||||
[Audio Flamingo 3](https://huggingface.co/papers/2507.08128): Advancing Audio Intelligence with Fully Open Large Audio Language Models
|
||||
A. Goel, S. Ghosh, J. Kim, S. Kumar, Z. Kong, S. Lee, C.-H. H. Yang, R. Duraiswami, D. Manocha, R. Valle, B. Catanzaro
|
||||
NVIDIA and University of Maryland
|
||||
Project: https://research.nvidia.com/labs/adlr/AF3/
|
||||
|
||||
## Usage
|
||||
|
||||
### Audio Instruct Mode
|
||||
|
||||
The model supports audio-text instructions, including multi-turn interactions, all processed in batches.
|
||||
|
||||
➡️ audio + text instruction
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Transcribe the input speech."},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ multi-turn:
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Instruction: How does the tone of female speech change throughout the audio? Choose the correct option among the options below: (A) Sad to happy (B) Happy to sad (C) Neutral to happy (D) Happy to neutral.",
|
||||
},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/000000786159.31.wav"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "(A) Sad to happy"}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Why do you think so?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ text only:
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is the capital of France?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ audio only:
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ batched inference!
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversations = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Transcribe the input speech."},
|
||||
{
|
||||
"type": "audio",
|
||||
"path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
|
||||
},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
|
||||
],
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversations,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ Training:
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
model.train()
|
||||
|
||||
conversation = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Transcribe the input speech."},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "The transcription of the audio is 'summer follows spring the days grow longer and the nights are warm'."}],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
|
||||
},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "The transcription of the audio is 'some transcription of the audio'."}],
|
||||
}
|
||||
|
||||
]
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
output_labels=True,
|
||||
).to(model.device)
|
||||
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
```
|
||||
|
||||
➡️ transcription shortcut
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
inputs = processor.apply_transcription_request(audio="https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav").to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True, strip_prefix=True)
|
||||
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
The model is trained to emit transcriptions prefixed with assistant framing such as `The spoken content of the audio is "<text>".`. Use `strip_prefix=True` (as shown above) to remove the fixed assistant sentence and surrounding quotes so that only the transcription remains.
|
||||
|
||||
## How the model works
|
||||
|
||||
### Architecture
|
||||
|
||||
* **AudioFlamingo3Encoder**
|
||||
Whisper-style feature extractor + encoder → average-pool over time (stride 2) → LayerNorm.
|
||||
Produces per-frame hidden states at the post-pool rate.
|
||||
|
||||
* **AudioFlamingo3MultiModalProjector**
|
||||
A small MLP that maps encoder features to the language model’s hidden size.
|
||||
|
||||
* **AudioFlamingo3ForConditionalGeneration**
|
||||
A causal language model that accepts text embeddings where each audio placeholder token slot is replaced, in place, by an audio frame embedding. No sequence-length change is introduced by fusion.
|
||||
|
||||
### Processor-level alignment
|
||||
|
||||
1. Each raw waveform is split into fixed-length windows based on the feature extractor’s `chunk_length` (seconds) and `sampling_rate` (Hz).
|
||||
2. For each window, the processor computes the number of post-pool frames `post_pool_len` that the encoder will output (matching the conv/pool schedule).
|
||||
3. The processor expands the audio placeholder token by the total number of post-pool frames across all windows.
|
||||
4. The model later replaces those token positions with the corresponding projected audio embeddings.
|
||||
|
||||
## Usage patterns
|
||||
|
||||
### Transcription shortcut
|
||||
|
||||
For automatic speech recognition you can skip writing the default instruction each time and call
|
||||
[`~transformers.AudioFlamingo3Processor.apply_transcription_request`]:
|
||||
|
||||
```python
|
||||
inputs = processor.apply_transcription_request(audio=audio_array)
|
||||
```
|
||||
|
||||
Pass `prompt="Transcribe the input speech."` (or a list of prompts for batch audio) to customize the instruction while
|
||||
keeping the audio placeholder handling.
|
||||
|
||||
`audio` accepts in-memory arrays, local file paths, or URLs. Any processor kwargs (`text_kwargs`, `audio_kwargs`, etc.)
|
||||
are forwarded, so you can tweak padding or tensor formats just like when calling `processor(...)`.
|
||||
|
||||
## Long audio and windowing
|
||||
|
||||
**Important: Maximum audio length is 10 minutes.** Audio longer than this will be truncated.
|
||||
|
||||
* The default setup processes 30-second windows at 16 kHz mono.
|
||||
* **The processor enforces a hard limit of 20 windows per sample, resulting in a maximum of 10 minutes of audio (20 windows × 30 seconds).**
|
||||
* For each window:
|
||||
|
||||
* `mel_len` is the padded mel length.
|
||||
* A conv stack reduces time as `conv_output_len = (mel_len - 1) // 2 + 1`.
|
||||
* Post-pool frames per window: `post_pool_len = (conv_output_len - 2) // 2 + 1`.
|
||||
* An audio placeholder token is expanded to the sum of `post_pool_len` across all windows.
|
||||
|
||||
## Padding, attention, and caching
|
||||
|
||||
* **Left padding vs right padding**
|
||||
For generation with mixed prompt lengths in a batch, left padding is usually preferable.
|
||||
For training, right padding is common; AF3’s fusion mechanism itself is padding-agnostic because it replaces in place.
|
||||
* **Attention masks**
|
||||
The processor returns `attention_mask` (text) and `input_features_mask` (audio). The model builds an internal 4-D mask on the encoder’s pre-pool axis with negative infinity at pad positions.
|
||||
* **Caching**
|
||||
During generation, `input_features` and `input_features_mask` are only passed on the first step. Subsequent steps use cached keys/values from the language model.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
* Empty or truncated outputs when batching
|
||||
Use left padding for batched generation and decode only the new tokens after the prompt length, as shown in the quickstart.
|
||||
|
||||
## AudioFlamingo3Config
|
||||
|
||||
[[autodoc]] AudioFlamingo3Config
|
||||
|
||||
## AudioFlamingo3EncoderConfig
|
||||
|
||||
[[autodoc]] AudioFlamingo3EncoderConfig
|
||||
|
||||
## AudioFlamingo3Processor
|
||||
|
||||
[[autodoc]] AudioFlamingo3Processor
|
||||
|
||||
## AudioFlamingo3Encoder
|
||||
|
||||
[[autodoc]] AudioFlamingo3Encoder
|
||||
- forward
|
||||
|
||||
## AudioFlamingo3ForConditionalGeneration
|
||||
|
||||
[[autodoc]] AudioFlamingo3ForConditionalGeneration
|
||||
- forward
|
||||
@ -1,60 +0,0 @@
|
||||
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2020-10-20 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# BORT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we do not accept any new PRs changing its code.
|
||||
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.30.0.
|
||||
You can do so by running the following command: `pip install -U transformers==4.30.0`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The BORT model was proposed in [Optimal Subarchitecture Extraction for BERT](https://huggingface.co/papers/2010.10499) by
|
||||
Adrian de Wynter and Daniel J. Perry. It is an optimal subset of architectural parameters for the BERT, which the
|
||||
authors refer to as "Bort".
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We extract an optimal subset of architectural parameters for the BERT architecture from Devlin et al. (2018) by
|
||||
applying recent breakthroughs in algorithms for neural architecture search. This optimal subset, which we refer to as
|
||||
"Bort", is demonstrably smaller, having an effective (that is, not counting the embedding layer) size of 5.5% the
|
||||
original BERT-large architecture, and 16% of the net size. Bort is also able to be pretrained in 288 GPU hours, which
|
||||
is 1.2% of the time required to pretrain the highest-performing BERT parametric architectural variant, RoBERTa-large
|
||||
(Liu et al., 2019), and about 33% of that of the world-record, in GPU hours, required to train BERT-large on the same
|
||||
hardware. It is also 7.9x faster on a CPU, as well as being better performing than other compressed variants of the
|
||||
architecture, and some of the non-compressed variants: it obtains performance improvements of between 0.3% and 31%,
|
||||
absolute, with respect to BERT-large, on multiple public natural language understanding (NLU) benchmarks.*
|
||||
|
||||
This model was contributed by [stefan-it](https://huggingface.co/stefan-it). The original code can be found [here](https://github.com/alexa/bort/).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- BORT's model architecture is based on BERT, refer to [BERT's documentation page](bert) for the
|
||||
model's API reference as well as usage examples.
|
||||
- BORT uses the RoBERTa tokenizer instead of the BERT tokenizer, refer to [RoBERTa's documentation page](roberta) for the tokenizer's API reference as well as usage examples.
|
||||
- BORT requires a specific fine-tuning algorithm, called [Agora](https://adewynter.github.io/notes/bort_algorithms_and_applications.html#fine-tuning-with-algebraic-topology) ,
|
||||
that is sadly not open-sourced yet. It would be very useful for the community, if someone tries to implement the
|
||||
algorithm to make BORT fine-tuning work.
|
||||
@ -1,78 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2022-12-12 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# DETA
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The DETA model was proposed in [NMS Strikes Back](https://huggingface.co/papers/2212.06137) by Jeffrey Ouyang-Zhang, Jang Hyun Cho, Xingyi Zhou, Philipp Krähenbühl.
|
||||
DETA (short for Detection Transformers with Assignment) improves [Deformable DETR](deformable_detr) by replacing the one-to-one bipartite Hungarian matching loss
|
||||
with one-to-many label assignments used in traditional detectors with non-maximum suppression (NMS). This leads to significant gains of up to 2.5 mAP.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Detection Transformer (DETR) directly transforms queries to unique objects by using one-to-one bipartite matching during training and enables end-to-end object detection. Recently, these models have surpassed traditional detectors on COCO with undeniable elegance. However, they differ from traditional detectors in multiple designs, including model architecture and training schedules, and thus the effectiveness of one-to-one matching is not fully understood. In this work, we conduct a strict comparison between the one-to-one Hungarian matching in DETRs and the one-to-many label assignments in traditional detectors with non-maximum supervision (NMS). Surprisingly, we observe one-to-many assignments with NMS consistently outperform standard one-to-one matching under the same setting, with a significant gain of up to 2.5 mAP. Our detector that trains Deformable-DETR with traditional IoU-based label assignment achieved 50.2 COCO mAP within 12 epochs (1x schedule) with ResNet50 backbone, outperforming all existing traditional or transformer-based detectors in this setting. On multiple datasets, schedules, and architectures, we consistently show bipartite matching is unnecessary for performant detection transformers. Furthermore, we attribute the success of detection transformers to their expressive transformer architecture.*
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/deta_architecture.jpg"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> DETA overview. Taken from the <a href="https://huggingface.co/papers/2212.06137">original paper</a>. </small>
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr).
|
||||
The original code can be found [here](https://github.com/jozhang97/DETA).
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DETA.
|
||||
|
||||
- Demo notebooks for DETA can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/DETA).
|
||||
- Scripts for finetuning [`DetaForObjectDetection`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection).
|
||||
- See also: [Object detection task guide](../tasks/object_detection).
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
## DetaConfig
|
||||
|
||||
[[autodoc]] DetaConfig
|
||||
|
||||
## DetaImageProcessor
|
||||
|
||||
[[autodoc]] DetaImageProcessor
|
||||
- preprocess
|
||||
- post_process_object_detection
|
||||
|
||||
## DetaModel
|
||||
|
||||
[[autodoc]] DetaModel
|
||||
- forward
|
||||
|
||||
## DetaForObjectDetection
|
||||
|
||||
[[autodoc]] DetaForObjectDetection
|
||||
- forward
|
||||
@ -169,6 +169,9 @@ print("Pooled output shape:", pooled_output.shape)
|
||||
[[autodoc]] DINOv3ViTModel
|
||||
- forward
|
||||
|
||||
## DINOv3ViTBackbone
|
||||
[[autodoc]] DINOv3ViTBackbone
|
||||
|
||||
## DINOv3ConvNextModel
|
||||
|
||||
[[autodoc]] DINOv3ConvNextModel
|
||||
|
||||
@ -1,85 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2022-06-02 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# EfficientFormer
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The EfficientFormer model was proposed in [EfficientFormer: Vision Transformers at MobileNet Speed](https://huggingface.co/papers/2206.01191)
|
||||
by Yanyu Li, Geng Yuan, Yang Wen, Eric Hu, Georgios Evangelidis, Sergey Tulyakov, Yanzhi Wang, Jian Ren. EfficientFormer proposes a
|
||||
dimension-consistent pure transformer that can be run on mobile devices for dense prediction tasks like image classification, object
|
||||
detection and semantic segmentation.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Vision Transformers (ViT) have shown rapid progress in computer vision tasks, achieving promising results on various benchmarks.
|
||||
However, due to the massive number of parameters and model design, e.g., attention mechanism, ViT-based models are generally
|
||||
times slower than lightweight convolutional networks. Therefore, the deployment of ViT for real-time applications is particularly
|
||||
challenging, especially on resource-constrained hardware such as mobile devices. Recent efforts try to reduce the computation
|
||||
complexity of ViT through network architecture search or hybrid design with MobileNet block, yet the inference speed is still
|
||||
unsatisfactory. This leads to an important question: can transformers run as fast as MobileNet while obtaining high performance?
|
||||
To answer this, we first revisit the network architecture and operators used in ViT-based models and identify inefficient designs.
|
||||
Then we introduce a dimension-consistent pure transformer (without MobileNet blocks) as a design paradigm.
|
||||
Finally, we perform latency-driven slimming to get a series of final models dubbed EfficientFormer.
|
||||
Extensive experiments show the superiority of EfficientFormer in performance and speed on mobile devices.
|
||||
Our fastest model, EfficientFormer-L1, achieves 79.2% top-1 accuracy on ImageNet-1K with only 1.6 ms inference latency on
|
||||
iPhone 12 (compiled with CoreML), which { runs as fast as MobileNetV2×1.4 (1.6 ms, 74.7% top-1),} and our largest model,
|
||||
EfficientFormer-L7, obtains 83.3% accuracy with only 7.0 ms latency. Our work proves that properly designed transformers can
|
||||
reach extremely low latency on mobile devices while maintaining high performance.*
|
||||
|
||||
This model was contributed by [novice03](https://huggingface.co/novice03) and [Bearnardd](https://huggingface.co/Bearnardd).
|
||||
The original code can be found [here](https://github.com/snap-research/EfficientFormer).
|
||||
|
||||
## Documentation resources
|
||||
|
||||
- [Image classification task guide](../tasks/image_classification)
|
||||
|
||||
## EfficientFormerConfig
|
||||
|
||||
[[autodoc]] EfficientFormerConfig
|
||||
|
||||
## EfficientFormerImageProcessor
|
||||
|
||||
[[autodoc]] EfficientFormerImageProcessor
|
||||
- preprocess
|
||||
|
||||
## EfficientFormerModel
|
||||
|
||||
[[autodoc]] EfficientFormerModel
|
||||
- forward
|
||||
|
||||
## EfficientFormerForImageClassification
|
||||
|
||||
[[autodoc]] EfficientFormerForImageClassification
|
||||
- forward
|
||||
|
||||
## EfficientFormerForImageClassificationWithTeacher
|
||||
|
||||
[[autodoc]] EfficientFormerForImageClassificationWithTeacher
|
||||
- forward
|
||||
@ -1,97 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace and Baidu Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2020-12-31 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# ErnieM
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The ErnieM model was proposed in [ERNIE-M: Enhanced Multilingual Representation by Aligning
|
||||
Cross-lingual Semantics with Monolingual Corpora](https://huggingface.co/papers/2012.15674) by Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun,
|
||||
Hao Tian, Hua Wu, Haifeng Wang.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Recent studies have demonstrated that pre-trained cross-lingual models achieve impressive performance in downstream cross-lingual tasks. This improvement benefits from learning a large amount of monolingual and parallel corpora. Although it is generally acknowledged that parallel corpora are critical for improving the model performance, existing methods are often constrained by the size of parallel corpora, especially for lowresource languages. In this paper, we propose ERNIE-M, a new training method that encourages the model to align the representation of multiple languages with monolingual corpora, to overcome the constraint that the parallel corpus size places on the model performance. Our key insight is to integrate back-translation into the pre-training process. We generate pseudo-parallel sentence pairs on a monolingual corpus to enable the learning of semantic alignments between different languages, thereby enhancing the semantic modeling of cross-lingual models. Experimental results show that ERNIE-M outperforms existing cross-lingual models and delivers new state-of-the-art results in various cross-lingual downstream tasks.*
|
||||
This model was contributed by [Susnato Dhar](https://huggingface.co/susnato). The original code can be found [here](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/paddlenlp/transformers/ernie_m).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- Ernie-M is a BERT-like model so it is a stacked Transformer Encoder.
|
||||
- Instead of using MaskedLM for pretraining (like BERT) the authors used two novel techniques: `Cross-attention Masked Language Modeling` and `Back-translation Masked Language Modeling`. For now these two LMHead objectives are not implemented here.
|
||||
- It is a multilingual language model.
|
||||
- Next Sentence Prediction was not used in pretraining process.
|
||||
|
||||
## Resources
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
|
||||
## ErnieMConfig
|
||||
|
||||
[[autodoc]] ErnieMConfig
|
||||
|
||||
## ErnieMTokenizer
|
||||
|
||||
[[autodoc]] ErnieMTokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
- get_special_tokens_mask
|
||||
- create_token_type_ids_from_sequences
|
||||
- save_vocabulary
|
||||
|
||||
## ErnieMModel
|
||||
|
||||
[[autodoc]] ErnieMModel
|
||||
- forward
|
||||
|
||||
## ErnieMForSequenceClassification
|
||||
|
||||
[[autodoc]] ErnieMForSequenceClassification
|
||||
- forward
|
||||
|
||||
## ErnieMForMultipleChoice
|
||||
|
||||
[[autodoc]] ErnieMForMultipleChoice
|
||||
- forward
|
||||
|
||||
## ErnieMForTokenClassification
|
||||
|
||||
[[autodoc]] ErnieMForTokenClassification
|
||||
- forward
|
||||
|
||||
## ErnieMForQuestionAnswering
|
||||
|
||||
[[autodoc]] ErnieMForQuestionAnswering
|
||||
- forward
|
||||
|
||||
## ErnieMForInformationExtraction
|
||||
|
||||
[[autodoc]] ErnieMForInformationExtraction
|
||||
- forward
|
||||
34
docs/source/en/model_doc/glm46v.md
Normal file
34
docs/source/en/model_doc/glm46v.md
Normal file
@ -0,0 +1,34 @@
|
||||
# GLM-4.6V
|
||||
|
||||
## Glm46VConfig
|
||||
|
||||
[[autodoc]] Glm46VConfig
|
||||
|
||||
## Glm46VImageProcessor
|
||||
|
||||
[[autodoc]] Glm46VImageProcessor
|
||||
- preprocess
|
||||
|
||||
## Glm46VVideoProcessor
|
||||
|
||||
[[autodoc]] Glm46VVideoProcessor
|
||||
- preprocess
|
||||
|
||||
## Glm46VImageProcessorFast
|
||||
|
||||
[[autodoc]] Glm46VImageProcessorFast
|
||||
- preprocess
|
||||
|
||||
## Glm46VProcessor
|
||||
|
||||
[[autodoc]] Glm46VProcessor
|
||||
|
||||
## Glm46VModel
|
||||
|
||||
[[autodoc]] Glm46VModel
|
||||
- forward
|
||||
|
||||
## Glm46VForConditionalGeneration
|
||||
|
||||
[[autodoc]] Glm46VForConditionalGeneration
|
||||
- forward
|
||||
@ -170,6 +170,11 @@ print(output_text)
|
||||
|
||||
[[autodoc]] Glm4vConfig
|
||||
|
||||
|
||||
## Glm4vVisionConfig
|
||||
|
||||
[[autodoc]] Glm4vVisionConfig
|
||||
|
||||
## Glm4vTextConfig
|
||||
|
||||
[[autodoc]] Glm4vTextConfig
|
||||
@ -193,6 +198,11 @@ print(output_text)
|
||||
|
||||
[[autodoc]] Glm4vProcessor
|
||||
|
||||
## Glm4vVisionModel
|
||||
|
||||
[[autodoc]] Glm4vVisionModel
|
||||
- forward
|
||||
|
||||
## Glm4vTextModel
|
||||
|
||||
[[autodoc]] Glm4vTextModel
|
||||
|
||||
@ -22,7 +22,7 @@ rendered properly in your Markdown viewer.
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white"> </div>
|
||||
</div>
|
||||
|
||||
# Glm4vMoe
|
||||
# Glm4vMoeMoe
|
||||
|
||||
## Overview
|
||||
|
||||
@ -48,10 +48,20 @@ The model also introduces a **Thinking Mode** switch, allowing users to balance
|
||||
|
||||
[[autodoc]] Glm4vMoeConfig
|
||||
|
||||
|
||||
## Glm4vMoeVisionConfig
|
||||
|
||||
[[autodoc]] Glm4vMoeVisionConfig
|
||||
|
||||
## Glm4vMoeTextConfig
|
||||
|
||||
[[autodoc]] Glm4vMoeTextConfig
|
||||
|
||||
## Glm4vMoeVisionModel
|
||||
|
||||
[[autodoc]] Glm4vMoeVisionModel
|
||||
- forward
|
||||
|
||||
## Glm4vMoeTextModel
|
||||
|
||||
[[autodoc]] Glm4vMoeTextModel
|
||||
@ -65,4 +75,4 @@ The model also introduces a **Thinking Mode** switch, allowing users to balance
|
||||
## Glm4vMoeForConditionalGeneration
|
||||
|
||||
[[autodoc]] Glm4vMoeForConditionalGeneration
|
||||
- forward
|
||||
- forward
|
||||
@ -1,145 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2023-02-07 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# GPTSAN-japanese
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The [GPTSAN-japanese](https://huggingface.co/Tanrei/GPTSAN-japanese) model was released in the repository by Toshiyuki Sakamoto (tanreinama).
|
||||
|
||||
GPTSAN is a Japanese language model using Switch Transformer. It has the same structure as the model introduced as Prefix LM
|
||||
in the T5 paper, and support both Text Generation and Masked Language Modeling tasks. These basic tasks similarly can
|
||||
fine-tune for translation or summarization.
|
||||
|
||||
### Usage example
|
||||
|
||||
The `generate()` method can be used to generate text using GPTSAN-Japanese model.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoModel, AutoTokenizer
|
||||
from accelerate import Accelerator
|
||||
>>> import torch
|
||||
|
||||
>>> device = Accelerator().device
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
|
||||
>>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
|
||||
>>> x_tok = tokenizer("は、", prefix_text="織田信長", return_tensors="pt")
|
||||
>>> torch.manual_seed(0)
|
||||
>>> gen_tok = model.generate(x_tok.input_ids.to(model.device), token_type_ids=x_tok.token_type_ids.to(model.device), max_new_tokens=20)
|
||||
>>> tokenizer.decode(gen_tok[0])
|
||||
'織田信長は、2004年に『戦国BASARA』のために、豊臣秀吉'
|
||||
```
|
||||
|
||||
## GPTSAN Features
|
||||
|
||||
GPTSAN has some unique features. It has a model structure of Prefix-LM. It works as a shifted Masked Language Model for Prefix Input tokens. Un-prefixed inputs behave like normal generative models.
|
||||
The Spout vector is a GPTSAN specific input. Spout is pre-trained with random inputs, but you can specify a class of text or an arbitrary vector during fine-tuning. This allows you to indicate the tendency of the generated text.
|
||||
GPTSAN has a sparse Feed Forward based on Switch-Transformer. You can also add other layers and train them partially. See the original GPTSAN repository for details.
|
||||
|
||||
### Prefix-LM Model
|
||||
|
||||
GPTSAN has the structure of the model named Prefix-LM in the `T5` paper. (The original GPTSAN repository calls it `hybrid`)
|
||||
In GPTSAN, the `Prefix` part of Prefix-LM, that is, the input position that can be referenced by both tokens, can be specified with any length.
|
||||
Arbitrary lengths can also be specified differently for each batch.
|
||||
This length applies to the text entered in `prefix_text` for the tokenizer.
|
||||
The tokenizer returns the mask of the `Prefix` part of Prefix-LM as `token_type_ids`.
|
||||
The model treats the part where `token_type_ids` is 1 as a `Prefix` part, that is, the input can refer to both tokens before and after.
|
||||
|
||||
## Usage tips
|
||||
|
||||
Specifying the Prefix part is done with a mask passed to self-attention.
|
||||
When token_type_ids=None or all zero, it is equivalent to regular causal mask
|
||||
|
||||
for example:
|
||||
|
||||
>>> x_token = tokenizer("アイウエ")
|
||||
|
||||
```text
|
||||
input_ids: | SOT | SEG | ア | イ | ウ | エ |
|
||||
token_type_ids: | 1 | 0 | 0 | 0 | 0 | 0 |
|
||||
prefix_lm_mask:
|
||||
SOT | 1 0 0 0 0 0 |
|
||||
SEG | 1 1 0 0 0 0 |
|
||||
ア | 1 1 1 0 0 0 |
|
||||
イ | 1 1 1 1 0 0 |
|
||||
ウ | 1 1 1 1 1 0 |
|
||||
エ | 1 1 1 1 1 1 |
|
||||
```
|
||||
|
||||
>>> x_token = tokenizer("", prefix_text="アイウエ")
|
||||
|
||||
```text
|
||||
input_ids: | SOT | ア | イ | ウ | エ | SEG |
|
||||
token_type_ids: | 1 | 1 | 1 | 1 | 1 | 0 |
|
||||
prefix_lm_mask:
|
||||
SOT | 1 1 1 1 1 0 |
|
||||
ア | 1 1 1 1 1 0 |
|
||||
イ | 1 1 1 1 1 0 |
|
||||
ウ | 1 1 1 1 1 0 |
|
||||
エ | 1 1 1 1 1 0 |
|
||||
SEG | 1 1 1 1 1 1 |
|
||||
```
|
||||
|
||||
>>> x_token = tokenizer("ウエ", prefix_text="アイ")
|
||||
|
||||
```text
|
||||
input_ids: | SOT | ア | イ | SEG | ウ | エ |
|
||||
token_type_ids: | 1 | 1 | 1 | 0 | 0 | 0 |
|
||||
prefix_lm_mask:
|
||||
SOT | 1 1 1 0 0 0 |
|
||||
ア | 1 1 1 0 0 0 |
|
||||
イ | 1 1 1 0 0 0 |
|
||||
SEG | 1 1 1 1 0 0 |
|
||||
ウ | 1 1 1 1 1 0 |
|
||||
エ | 1 1 1 1 1 1 |
|
||||
```
|
||||
|
||||
### Spout Vector
|
||||
|
||||
A Spout Vector is a special vector for controlling text generation.
|
||||
This vector is treated as the first embedding in self-attention to bring extraneous attention to the generated tokens.
|
||||
In the pre-trained model published from `Tanrei/GPTSAN-japanese`, the Spout Vector is a 128-dimensional vector that passes through 8 fully connected layers in the model and is projected into the space acting as external attention.
|
||||
The Spout Vector projected by the fully connected layer is split to be passed to all self-attentions.
|
||||
|
||||
## GPTSanJapaneseConfig
|
||||
|
||||
[[autodoc]] GPTSanJapaneseConfig
|
||||
|
||||
## GPTSanJapaneseTokenizer
|
||||
|
||||
[[autodoc]] GPTSanJapaneseTokenizer
|
||||
|
||||
## GPTSanJapaneseModel
|
||||
|
||||
[[autodoc]] GPTSanJapaneseModel
|
||||
|
||||
## GPTSanJapaneseForConditionalGeneration
|
||||
|
||||
[[autodoc]] GPTSanJapaneseForConditionalGeneration
|
||||
- forward
|
||||
@ -1,60 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team and Microsoft. All rights reserved.
|
||||
|
||||
Licensed under the MIT License; you may not use this file except in compliance with
|
||||
the License.
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2021-06-09 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Graphormer
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The Graphormer model was proposed in [Do Transformers Really Perform Bad for Graph Representation?](https://huggingface.co/papers/2106.05234) by
|
||||
Chengxuan Ying, Tianle Cai, Shengjie Luo, Shuxin Zheng, Guolin Ke, Di He, Yanming Shen and Tie-Yan Liu. It is a Graph Transformer model, modified to allow computations on graphs instead of text sequences by generating embeddings and features of interest during preprocessing and collation, then using a modified attention.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*The Transformer architecture has become a dominant choice in many domains, such as natural language processing and computer vision. Yet, it has not achieved competitive performance on popular leaderboards of graph-level prediction compared to mainstream GNN variants. Therefore, it remains a mystery how Transformers could perform well for graph representation learning. In this paper, we solve this mystery by presenting Graphormer, which is built upon the standard Transformer architecture, and could attain excellent results on a broad range of graph representation learning tasks, especially on the recent OGB Large-Scale Challenge. Our key insight to utilizing Transformer in the graph is the necessity of effectively encoding the structural information of a graph into the model. To this end, we propose several simple yet effective structural encoding methods to help Graphormer better model graph-structured data. Besides, we mathematically characterize the expressive power of Graphormer and exhibit that with our ways of encoding the structural information of graphs, many popular GNN variants could be covered as the special cases of Graphormer.*
|
||||
|
||||
This model was contributed by [clefourrier](https://huggingface.co/clefourrier). The original code can be found [here](https://github.com/microsoft/Graphormer).
|
||||
|
||||
## Usage tips
|
||||
|
||||
This model will not work well on large graphs (more than 100 nodes/edges), as it will make the memory explode.
|
||||
You can reduce the batch size, increase your RAM, or decrease the `UNREACHABLE_NODE_DISTANCE` parameter in algos_graphormer.pyx, but it will be hard to go above 700 nodes/edges.
|
||||
|
||||
This model does not use a tokenizer, but instead a special collator during training.
|
||||
|
||||
## GraphormerConfig
|
||||
|
||||
[[autodoc]] GraphormerConfig
|
||||
|
||||
## GraphormerModel
|
||||
|
||||
[[autodoc]] GraphormerModel
|
||||
- forward
|
||||
|
||||
## GraphormerForGraphClassification
|
||||
|
||||
[[autodoc]] GraphormerForGraphClassification
|
||||
- forward
|
||||
@ -1,99 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2020-04-30 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Jukebox
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The Jukebox model was proposed in [Jukebox: A generative model for music](https://huggingface.co/papers/2005.00341)
|
||||
by Prafulla Dhariwal, Heewoo Jun, Christine Payne, Jong Wook Kim, Alec Radford,
|
||||
Ilya Sutskever. It introduces a generative music model which can produce minute long samples that can be conditioned on
|
||||
an artist, genres and lyrics.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We introduce Jukebox, a model that generates music with singing in the raw audio domain. We tackle the long context of raw audio using a multiscale VQ-VAE to compress it to discrete codes, and modeling those using autoregressive Transformers. We show that the combined model at scale can generate high-fidelity and diverse songs with coherence up to multiple minutes. We can condition on artist and genre to steer the musical and vocal style, and on unaligned lyrics to make the singing more controllable. We are releasing thousands of non cherry-picked samples, along with model weights and code.*
|
||||
|
||||
As shown on the following figure, Jukebox is made of 3 `priors` which are decoder only models. They follow the architecture described in [Generating Long Sequences with Sparse Transformers](https://huggingface.co/papers/1904.10509), modified to support longer context length.
|
||||
First, a autoencoder is used to encode the text lyrics. Next, the first (also called `top_prior`) prior attends to the last hidden states extracted from the lyrics encoder. The priors are linked to the previous priors respectively via an `AudioConditioner` module. The`AudioConditioner` upsamples the outputs of the previous prior to raw tokens at a certain audio frame per second resolution.
|
||||
The metadata such as *artist, genre and timing* are passed to each prior, in the form of a start token and positional embedding for the timing data. The hidden states are mapped to the closest codebook vector from the VQVAE in order to convert them to raw audio.
|
||||
|
||||

|
||||
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ).
|
||||
The original code can be found [here](https://github.com/openai/jukebox).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- This model only supports inference. This is for a few reasons, mostly because it requires a crazy amount of memory to train. Feel free to open a PR and add what's missing to have a full integration with the hugging face trainer!
|
||||
- This model is very slow, and takes 8h to generate a minute long audio using the 5b top prior on a V100 GPU. In order automaticallay handle the device on which the model should execute, use `accelerate`.
|
||||
- Contrary to the paper, the order of the priors goes from `0` to `1` as it felt more intuitive : we sample starting from `0`.
|
||||
- Primed sampling (conditioning the sampling on raw audio) requires more memory than ancestral sampling and should be used with `fp16` set to `True`.
|
||||
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ).
|
||||
The original code can be found [here](https://github.com/openai/jukebox).
|
||||
|
||||
## JukeboxConfig
|
||||
|
||||
[[autodoc]] JukeboxConfig
|
||||
|
||||
## JukeboxPriorConfig
|
||||
|
||||
[[autodoc]] JukeboxPriorConfig
|
||||
|
||||
## JukeboxVQVAEConfig
|
||||
|
||||
[[autodoc]] JukeboxVQVAEConfig
|
||||
|
||||
## JukeboxTokenizer
|
||||
|
||||
[[autodoc]] JukeboxTokenizer
|
||||
- save_vocabulary
|
||||
|
||||
## JukeboxModel
|
||||
|
||||
[[autodoc]] JukeboxModel
|
||||
- ancestral_sample
|
||||
- primed_sample
|
||||
- continue_sample
|
||||
- upsample
|
||||
- _sample
|
||||
|
||||
## JukeboxPrior
|
||||
|
||||
[[autodoc]] JukeboxPrior
|
||||
- sample
|
||||
- forward
|
||||
|
||||
## JukeboxVQVAE
|
||||
|
||||
[[autodoc]] JukeboxVQVAE
|
||||
- forward
|
||||
- encode
|
||||
- decode
|
||||
@ -1,84 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2021-10-30 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# M-CTC-T
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, so we won't accept any new PRs changing its code.
|
||||
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.30.0.
|
||||
You can do so by running the following command: `pip install -U transformers==4.30.0`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The M-CTC-T model was proposed in [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://huggingface.co/papers/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. The model is a 1B-param transformer encoder, with a CTC head over 8065 character labels and a language identification head over 60 language ID labels. It is trained on Common Voice (version 6.1, December 2020 release) and VoxPopuli. After training on Common Voice and VoxPopuli, the model is trained on Common Voice only. The labels are unnormalized character-level transcripts (punctuation and capitalization are not removed). The model takes as input Mel filterbank features from a 16Khz audio signal.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Semi-supervised learning through pseudo-labeling has become a staple of state-of-the-art monolingual
|
||||
speech recognition systems. In this work, we extend pseudo-labeling to massively multilingual speech
|
||||
recognition with 60 languages. We propose a simple pseudo-labeling recipe that works well even
|
||||
with low-resource languages: train a supervised multilingual model, fine-tune it with semi-supervised
|
||||
learning on a target language, generate pseudo-labels for that language, and train a final model using
|
||||
pseudo-labels for all languages, either from scratch or by fine-tuning. Experiments on the labeled
|
||||
Common Voice and unlabeled VoxPopuli datasets show that our recipe can yield a model with better
|
||||
performance for many languages that also transfers well to LibriSpeech.*
|
||||
|
||||
This model was contributed by [cwkeam](https://huggingface.co/cwkeam). The original code can be found [here](https://github.com/flashlight/wav2letter/tree/main/recipes/mling_pl).
|
||||
|
||||
## Usage tips
|
||||
|
||||
The PyTorch version of this model is only available in torch 1.9 and higher.
|
||||
|
||||
## Resources
|
||||
|
||||
- [Automatic speech recognition task guide](../tasks/asr)
|
||||
|
||||
## MCTCTConfig
|
||||
|
||||
[[autodoc]] MCTCTConfig
|
||||
|
||||
## MCTCTFeatureExtractor
|
||||
|
||||
[[autodoc]] MCTCTFeatureExtractor
|
||||
- __call__
|
||||
|
||||
## MCTCTProcessor
|
||||
|
||||
[[autodoc]] MCTCTProcessor
|
||||
- __call__
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
- batch_decode
|
||||
- decode
|
||||
|
||||
## MCTCTModel
|
||||
|
||||
[[autodoc]] MCTCTModel
|
||||
- forward
|
||||
|
||||
## MCTCTForCTC
|
||||
|
||||
[[autodoc]] MCTCTForCTC
|
||||
- forward
|
||||
@ -1,94 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2022-09-21 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# MEGA
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The MEGA model was proposed in [Mega: Moving Average Equipped Gated Attention](https://huggingface.co/papers/2209.10655) by Xuezhe Ma, Chunting Zhou, Xiang Kong, Junxian He, Liangke Gui, Graham Neubig, Jonathan May, and Luke Zettlemoyer.
|
||||
MEGA proposes a new approach to self-attention with each encoder layer having a multi-headed exponential moving average in addition to a single head of standard dot-product attention, giving the attention mechanism
|
||||
stronger positional biases. This allows MEGA to perform competitively to Transformers on standard benchmarks including LRA
|
||||
while also having significantly fewer parameters. MEGA's compute efficiency allows it to scale to very long sequences, making it an
|
||||
attractive option for long-document NLP tasks.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*The design choices in the Transformer attention mechanism, including weak inductive bias and quadratic computational complexity, have limited its application for modeling long sequences. In this paper, we introduce Mega, a simple, theoretically grounded, single-head gated attention mechanism equipped with (exponential) moving average to incorporate inductive bias of position-aware local dependencies into the position-agnostic attention mechanism. We further propose a variant of Mega that offers linear time and space complexity yet yields only minimal quality loss, by efficiently splitting the whole sequence into multiple chunks with fixed length. Extensive experiments on a wide range of sequence modeling benchmarks, including the Long Range Arena, neural machine translation, auto-regressive language modeling, and image and speech classification, show that Mega achieves significant improvements over other sequence models, including variants of Transformers and recent state space models.*
|
||||
|
||||
This model was contributed by [mnaylor](https://huggingface.co/mnaylor).
|
||||
The original code can be found [here](https://github.com/facebookresearch/mega).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- MEGA can perform quite well with relatively few parameters. See Appendix D in the MEGA paper for examples of architectural specs which perform well in various settings. If using MEGA as a decoder, be sure to set `bidirectional=False` to avoid errors with default bidirectional.
|
||||
- Mega-chunk is a variant of mega that reduces time and spaces complexity from quadratic to linear. Utilize chunking with MegaConfig.use_chunking and control chunk size with MegaConfig.chunk_size
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
- The original implementation of MEGA had an inconsistent expectation of attention masks for padding and causal self-attention between the softmax attention and Laplace/squared ReLU method. This implementation addresses that inconsistency.
|
||||
- The original implementation did not include token type embeddings; this implementation adds support for these, with the option controlled by MegaConfig.add_token_type_embeddings
|
||||
|
||||
## MegaConfig
|
||||
|
||||
[[autodoc]] MegaConfig
|
||||
|
||||
## MegaModel
|
||||
|
||||
[[autodoc]] MegaModel
|
||||
- forward
|
||||
|
||||
## MegaForCausalLM
|
||||
|
||||
[[autodoc]] MegaForCausalLM
|
||||
- forward
|
||||
|
||||
## MegaForMaskedLM
|
||||
|
||||
[[autodoc]] MegaForMaskedLM
|
||||
- forward
|
||||
|
||||
## MegaForSequenceClassification
|
||||
|
||||
[[autodoc]] MegaForSequenceClassification
|
||||
- forward
|
||||
|
||||
## MegaForMultipleChoice
|
||||
|
||||
[[autodoc]] MegaForMultipleChoice
|
||||
- forward
|
||||
|
||||
## MegaForTokenClassification
|
||||
|
||||
[[autodoc]] MegaForTokenClassification
|
||||
- forward
|
||||
|
||||
## MegaForQuestionAnswering
|
||||
|
||||
[[autodoc]] MegaForQuestionAnswering
|
||||
- forward
|
||||
@ -1,101 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2022-04-14 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Neighborhood Attention Transformer
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
NAT was proposed in [Neighborhood Attention Transformer](https://huggingface.co/papers/2204.07143)
|
||||
by Ali Hassani, Steven Walton, Jiachen Li, Shen Li, and Humphrey Shi.
|
||||
|
||||
It is a hierarchical vision transformer based on Neighborhood Attention, a sliding-window self attention pattern.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We present Neighborhood Attention (NA), the first efficient and scalable sliding-window attention mechanism for vision.
|
||||
NA is a pixel-wise operation, localizing self attention (SA) to the nearest neighboring pixels, and therefore enjoys a
|
||||
linear time and space complexity compared to the quadratic complexity of SA. The sliding-window pattern allows NA's
|
||||
receptive field to grow without needing extra pixel shifts, and preserves translational equivariance, unlike
|
||||
Swin Transformer's Window Self Attention (WSA). We develop NATTEN (Neighborhood Attention Extension), a Python package
|
||||
with efficient C++ and CUDA kernels, which allows NA to run up to 40% faster than Swin's WSA while using up to 25% less
|
||||
memory. We further present Neighborhood Attention Transformer (NAT), a new hierarchical transformer design based on NA
|
||||
that boosts image classification and downstream vision performance. Experimental results on NAT are competitive;
|
||||
NAT-Tiny reaches 83.2% top-1 accuracy on ImageNet, 51.4% mAP on MS-COCO and 48.4% mIoU on ADE20K, which is 1.9%
|
||||
ImageNet accuracy, 1.0% COCO mAP, and 2.6% ADE20K mIoU improvement over a Swin model with similar size.*
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/neighborhood-attention-pattern.jpg"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Neighborhood Attention compared to other attention patterns.
|
||||
Taken from the <a href="https://huggingface.co/papers/2204.07143">original paper</a>.</small>
|
||||
|
||||
This model was contributed by [Ali Hassani](https://huggingface.co/alihassanijr).
|
||||
The original code can be found [here](https://github.com/SHI-Labs/Neighborhood-Attention-Transformer).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- One can use the [`AutoImageProcessor`] API to prepare images for the model.
|
||||
- NAT can be used as a *backbone*. When `output_hidden_states = True`,
|
||||
it will output both `hidden_states` and `reshaped_hidden_states`.
|
||||
The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than
|
||||
`(batch_size, height, width, num_channels)`.
|
||||
|
||||
Notes:
|
||||
|
||||
- NAT depends on [NATTEN](https://github.com/SHI-Labs/NATTEN/)'s implementation of Neighborhood Attention.
|
||||
You can install it with pre-built wheels for Linux by referring to [shi-labs.com/natten](https://shi-labs.com/natten),
|
||||
or build on your system by running `pip install natten`.
|
||||
Note that the latter will likely take time to compile. NATTEN does not support Windows devices yet.
|
||||
- Patch size of 4 is only supported at the moment.
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with NAT.
|
||||
|
||||
<PipelineTag pipeline="image-classification"/>
|
||||
|
||||
- [`NatForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
|
||||
- See also: [Image classification task guide](../tasks/image_classification)
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
## NatConfig
|
||||
|
||||
[[autodoc]] NatConfig
|
||||
|
||||
## NatModel
|
||||
|
||||
[[autodoc]] NatModel
|
||||
- forward
|
||||
|
||||
## NatForImageClassification
|
||||
|
||||
[[autodoc]] NatForImageClassification
|
||||
- forward
|
||||
@ -1,101 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2019-08-31 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Nezha
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The Nezha model was proposed in [NEZHA: Neural Contextualized Representation for Chinese Language Understanding](https://huggingface.co/papers/1909.00204) by Junqiu Wei et al.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*The pre-trained language models have achieved great successes in various natural language understanding (NLU) tasks
|
||||
due to its capacity to capture the deep contextualized information in text by pre-training on large-scale corpora.
|
||||
In this technical report, we present our practice of pre-training language models named NEZHA (NEural contextualiZed
|
||||
representation for CHinese lAnguage understanding) on Chinese corpora and finetuning for the Chinese NLU tasks.
|
||||
The current version of NEZHA is based on BERT with a collection of proven improvements, which include Functional
|
||||
Relative Positional Encoding as an effective positional encoding scheme, Whole Word Masking strategy,
|
||||
Mixed Precision Training and the LAMB Optimizer in training the models. The experimental results show that NEZHA
|
||||
achieves the state-of-the-art performances when finetuned on several representative Chinese tasks, including
|
||||
named entity recognition (People's Daily NER), sentence matching (LCQMC), Chinese sentiment classification (ChnSenti)
|
||||
and natural language inference (XNLI).*
|
||||
|
||||
This model was contributed by [sijunhe](https://huggingface.co/sijunhe). The original code can be found [here](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-PyTorch).
|
||||
|
||||
## Resources
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
|
||||
## NezhaConfig
|
||||
|
||||
[[autodoc]] NezhaConfig
|
||||
|
||||
## NezhaModel
|
||||
|
||||
[[autodoc]] NezhaModel
|
||||
- forward
|
||||
|
||||
## NezhaForPreTraining
|
||||
|
||||
[[autodoc]] NezhaForPreTraining
|
||||
- forward
|
||||
|
||||
## NezhaForMaskedLM
|
||||
|
||||
[[autodoc]] NezhaForMaskedLM
|
||||
- forward
|
||||
|
||||
## NezhaForNextSentencePrediction
|
||||
|
||||
[[autodoc]] NezhaForNextSentencePrediction
|
||||
- forward
|
||||
|
||||
## NezhaForSequenceClassification
|
||||
|
||||
[[autodoc]] NezhaForSequenceClassification
|
||||
- forward
|
||||
|
||||
## NezhaForMultipleChoice
|
||||
|
||||
[[autodoc]] NezhaForMultipleChoice
|
||||
- forward
|
||||
|
||||
## NezhaForTokenClassification
|
||||
|
||||
[[autodoc]] NezhaForTokenClassification
|
||||
- forward
|
||||
|
||||
## NezhaForQuestionAnswering
|
||||
|
||||
[[autodoc]] NezhaForQuestionAnswering
|
||||
- forward
|
||||
@ -1,66 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2023-04-16 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Open-Llama
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.31.0.
|
||||
You can do so by running the following command: `pip install -U transformers==4.31.0`.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model differs from the [OpenLLaMA models](https://huggingface.co/models?search=openllama) on the Hugging Face Hub, which primarily use the [LLaMA](llama) architecture.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The Open-Llama model was proposed in the open source Open-Llama project by community developer s-JoL.
|
||||
|
||||
The model is mainly based on LLaMA with some modifications, incorporating memory-efficient attention from Xformers, stable embedding from Bloom, and shared input-output embedding from PaLM.
|
||||
And the model is pre-trained on both Chinese and English, which gives it better performance on Chinese language tasks.
|
||||
|
||||
This model was contributed by [s-JoL](https://huggingface.co/s-JoL).
|
||||
The original code was released on GitHub by [s-JoL](https://github.com/s-JoL), but is now removed.
|
||||
|
||||
## OpenLlamaConfig
|
||||
|
||||
[[autodoc]] OpenLlamaConfig
|
||||
|
||||
## OpenLlamaModel
|
||||
|
||||
[[autodoc]] OpenLlamaModel
|
||||
- forward
|
||||
|
||||
## OpenLlamaForCausalLM
|
||||
|
||||
[[autodoc]] OpenLlamaForCausalLM
|
||||
- forward
|
||||
|
||||
## OpenLlamaForSequenceClassification
|
||||
|
||||
[[autodoc]] OpenLlamaForSequenceClassification
|
||||
- forward
|
||||
@ -1,183 +0,0 @@
|
||||
<!--Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2020-04-20 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# QDQBERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The QDQBERT model can be referenced in [Integer Quantization for Deep Learning Inference: Principles and Empirical
|
||||
Evaluation](https://huggingface.co/papers/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius
|
||||
Micikevicius.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Quantization techniques can reduce the size of Deep Neural Networks and improve inference latency and throughput by
|
||||
taking advantage of high throughput integer instructions. In this paper we review the mathematical aspects of
|
||||
quantization parameters and evaluate their choices on a wide range of neural network models for different application
|
||||
domains, including vision, speech, and language. We focus on quantization techniques that are amenable to acceleration
|
||||
by processors with high-throughput integer math pipelines. We also present a workflow for 8-bit quantization that is
|
||||
able to maintain accuracy within 1% of the floating-point baseline on all networks studied, including models that are
|
||||
more difficult to quantize, such as MobileNets and BERT-large.*
|
||||
|
||||
This model was contributed by [shangz](https://huggingface.co/shangz).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- QDQBERT model adds fake quantization operations (pair of QuantizeLinear/DequantizeLinear ops) to (i) linear layer
|
||||
inputs and weights, (ii) matmul inputs, (iii) residual add inputs, in BERT model.
|
||||
- QDQBERT requires the dependency of [Pytorch Quantization Toolkit](https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization). To install `pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`
|
||||
- QDQBERT model can be loaded from any checkpoint of HuggingFace BERT model (for example *google-bert/bert-base-uncased*), and
|
||||
perform Quantization Aware Training/Post Training Quantization.
|
||||
- A complete example of using QDQBERT model to perform Quatization Aware Training and Post Training Quantization for
|
||||
SQUAD task can be found at https://github.com/huggingface/transformers-research-projects/tree/main/quantization-qdqbert.
|
||||
|
||||
### Set default quantizers
|
||||
|
||||
QDQBERT model adds fake quantization operations (pair of QuantizeLinear/DequantizeLinear ops) to BERT by
|
||||
`TensorQuantizer` in [Pytorch Quantization Toolkit](https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization). `TensorQuantizer` is the module
|
||||
for quantizing tensors, with `QuantDescriptor` defining how the tensor should be quantized. Refer to [Pytorch
|
||||
Quantization Toolkit userguide](https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/userguide.html) for more details.
|
||||
|
||||
Before creating QDQBERT model, one has to set the default `QuantDescriptor` defining default tensor quantizers.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import pytorch_quantization.nn as quant_nn
|
||||
>>> from pytorch_quantization.tensor_quant import QuantDescriptor
|
||||
|
||||
>>> # The default tensor quantizer is set to use Max calibration method
|
||||
>>> input_desc = QuantDescriptor(num_bits=8, calib_method="max")
|
||||
>>> # The default tensor quantizer is set to be per-channel quantization for weights
|
||||
>>> weight_desc = QuantDescriptor(num_bits=8, axis=((0,)))
|
||||
>>> quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)
|
||||
>>> quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
Calibration is the terminology of passing data samples to the quantizer and deciding the best scaling factors for
|
||||
tensors. After setting up the tensor quantizers, one can use the following example to calibrate the model:
|
||||
|
||||
```python
|
||||
>>> # Find the TensorQuantizer and enable calibration
|
||||
>>> for name, module in model.named_modules():
|
||||
... if name.endswith("_input_quantizer"):
|
||||
... module.enable_calib()
|
||||
... module.disable_quant() # Use full precision data to calibrate
|
||||
|
||||
>>> # Feeding data samples
|
||||
>>> model(x)
|
||||
>>> # ...
|
||||
|
||||
>>> # Finalize calibration
|
||||
>>> for name, module in model.named_modules():
|
||||
... if name.endswith("_input_quantizer"):
|
||||
... module.load_calib_amax()
|
||||
... module.enable_quant()
|
||||
|
||||
>>> # If running on accelerator, it needs to call `.to(xx)` again because new tensors will be created by calibration process
|
||||
>>> from accelerate import Accelerator
|
||||
>>> device = Accelerator().device
|
||||
>>> model.to(device)
|
||||
|
||||
>>> # Keep running the quantized model
|
||||
>>> # ...
|
||||
```
|
||||
|
||||
### Export to ONNX
|
||||
|
||||
The goal of exporting to ONNX is to deploy inference by [TensorRT](https://developer.nvidia.com/tensorrt). Fake
|
||||
quantization will be broken into a pair of QuantizeLinear/DequantizeLinear ONNX ops. After setting static member of
|
||||
TensorQuantizer to use Pytorch's own fake quantization functions, fake quantized model can be exported to ONNX, follow
|
||||
the instructions in [torch.onnx](https://pytorch.org/docs/stable/onnx.html). Example:
|
||||
|
||||
```python
|
||||
>>> from pytorch_quantization.nn import TensorQuantizer
|
||||
|
||||
>>> TensorQuantizer.use_fb_fake_quant = True
|
||||
|
||||
>>> # Load the calibrated model
|
||||
>>> ...
|
||||
>>> # ONNX export
|
||||
>>> torch.onnx.export(...)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Token classification task guide](../tasks/token_classification)
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
- [Masked language modeling task guide](../tasks/masked_language_modeling)
|
||||
- [Multiple choice task guide](../tasks/multiple_choice)
|
||||
|
||||
## QDQBertConfig
|
||||
|
||||
[[autodoc]] QDQBertConfig
|
||||
|
||||
## QDQBertModel
|
||||
|
||||
[[autodoc]] QDQBertModel
|
||||
- forward
|
||||
|
||||
## QDQBertLMHeadModel
|
||||
|
||||
[[autodoc]] QDQBertLMHeadModel
|
||||
- forward
|
||||
|
||||
## QDQBertForMaskedLM
|
||||
|
||||
[[autodoc]] QDQBertForMaskedLM
|
||||
- forward
|
||||
|
||||
## QDQBertForSequenceClassification
|
||||
|
||||
[[autodoc]] QDQBertForSequenceClassification
|
||||
- forward
|
||||
|
||||
## QDQBertForNextSentencePrediction
|
||||
|
||||
[[autodoc]] QDQBertForNextSentencePrediction
|
||||
- forward
|
||||
|
||||
## QDQBertForMultipleChoice
|
||||
|
||||
[[autodoc]] QDQBertForMultipleChoice
|
||||
- forward
|
||||
|
||||
## QDQBertForTokenClassification
|
||||
|
||||
[[autodoc]] QDQBertForTokenClassification
|
||||
- forward
|
||||
|
||||
## QDQBertForQuestionAnswering
|
||||
|
||||
[[autodoc]] QDQBertForQuestionAnswering
|
||||
- forward
|
||||
@ -136,7 +136,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
video_fps=1,
|
||||
fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen2-5-OmniProcessor`
|
||||
padding=True,
|
||||
@ -245,7 +245,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
video_fps=1,
|
||||
fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen2-5-OmniProcessor`
|
||||
padding=True,
|
||||
|
||||
@ -54,7 +54,7 @@ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B", trust_remote_co
|
||||
prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:"
|
||||
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/glass-breaking-151256.mp3"
|
||||
audio, sr = librosa.load(BytesIO(urlopen(url).read()), sr=processor.feature_extractor.sampling_rate)
|
||||
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
|
||||
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(model.device)
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||||
@ -63,7 +63,7 @@ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_
|
||||
|
||||
# We can also omit the audio_bos and audio_eos tokens
|
||||
prompt = "<|AUDIO|>Generate the caption in English:"
|
||||
inputs = processor(text=prompt, audios=audio, return_tensors="pt").to(model.device)
|
||||
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(model.device)
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||||
@ -106,7 +106,7 @@ for message in conversation:
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
|
||||
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
|
||||
inputs.input_ids = inputs.input_ids.to(model.device)
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
@ -156,7 +156,7 @@ for message in conversation:
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
|
||||
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
|
||||
inputs.input_ids = inputs.input_ids.to(model.device)
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
@ -213,7 +213,7 @@ for conversation in conversations:
|
||||
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
|
||||
inputs = processor(text=text, audio=audios, return_tensors="pt", padding=True)
|
||||
inputs['input_ids'] = inputs['input_ids'].to(model.device)
|
||||
inputs.input_ids = inputs.input_ids.to(model.device)
|
||||
|
||||
|
||||
@ -80,7 +80,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
video_fps=1,
|
||||
fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen3OmniMoeProcessor`
|
||||
padding=True,
|
||||
@ -136,7 +136,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
video_fps=1,
|
||||
fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen3OmniMoeProcessor`
|
||||
padding=True,
|
||||
@ -245,7 +245,7 @@ inputs = processor.apply_chat_template(
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
video_fps=1,
|
||||
fps=1,
|
||||
|
||||
# kwargs to be passed to `Qwen3OmniMoeProcessor`
|
||||
padding=True,
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2020-02-10 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# REALM
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The REALM model was proposed in [REALM: Retrieval-Augmented Language Model Pre-Training](https://huggingface.co/papers/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. It's a
|
||||
retrieval-augmented language model that firstly retrieves documents from a textual knowledge corpus and then
|
||||
utilizes retrieved documents to process question answering tasks.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Language model pre-training has been shown to capture a surprising amount of world knowledge, crucial for NLP tasks
|
||||
such as question answering. However, this knowledge is stored implicitly in the parameters of a neural network,
|
||||
requiring ever-larger networks to cover more facts. To capture knowledge in a more modular and interpretable way, we
|
||||
augment language model pre-training with a latent knowledge retriever, which allows the model to retrieve and attend
|
||||
over documents from a large corpus such as Wikipedia, used during pre-training, fine-tuning and inference. For the
|
||||
first time, we show how to pre-train such a knowledge retriever in an unsupervised manner, using masked language
|
||||
modeling as the learning signal and backpropagating through a retrieval step that considers millions of documents. We
|
||||
demonstrate the effectiveness of Retrieval-Augmented Language Model pre-training (REALM) by fine-tuning on the
|
||||
challenging task of Open-domain Question Answering (Open-QA). We compare against state-of-the-art models for both
|
||||
explicit and implicit knowledge storage on three popular Open-QA benchmarks, and find that we outperform all previous
|
||||
methods by a significant margin (4-16% absolute accuracy), while also providing qualitative benefits such as
|
||||
interpretability and modularity.*
|
||||
|
||||
This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The original code can be found
|
||||
[here](https://github.com/google-research/language/tree/master/language/realm).
|
||||
|
||||
## RealmConfig
|
||||
|
||||
[[autodoc]] RealmConfig
|
||||
|
||||
## RealmTokenizer
|
||||
|
||||
[[autodoc]] RealmTokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
- get_special_tokens_mask
|
||||
- create_token_type_ids_from_sequences
|
||||
- save_vocabulary
|
||||
- batch_encode_candidates
|
||||
|
||||
## RealmTokenizerFast
|
||||
|
||||
[[autodoc]] RealmTokenizerFast
|
||||
- batch_encode_candidates
|
||||
|
||||
## RealmRetriever
|
||||
|
||||
[[autodoc]] RealmRetriever
|
||||
|
||||
## RealmEmbedder
|
||||
|
||||
[[autodoc]] RealmEmbedder
|
||||
- forward
|
||||
|
||||
## RealmScorer
|
||||
|
||||
[[autodoc]] RealmScorer
|
||||
- forward
|
||||
|
||||
## RealmKnowledgeAugEncoder
|
||||
|
||||
[[autodoc]] RealmKnowledgeAugEncoder
|
||||
- forward
|
||||
|
||||
## RealmReader
|
||||
|
||||
[[autodoc]] RealmReader
|
||||
- forward
|
||||
|
||||
## RealmForOpenQA
|
||||
|
||||
[[autodoc]] RealmForOpenQA
|
||||
- block_embedding_to
|
||||
- forward
|
||||
@ -1,57 +0,0 @@
|
||||
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2020-06-12 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# RetriBERT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, so we won't accept any new PRs changing its code.
|
||||
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.30.0.
|
||||
You can do so by running the following command: `pip install -U transformers==4.30.0`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The [RetriBERT](https://huggingface.co/yjernite/retribert-base-uncased/tree/main) model was proposed in the blog post [Explain Anything Like I'm Five: A Model for Open Domain Long Form
|
||||
Question Answering](https://yjernite.github.io/lfqa.html). RetriBERT is a small model that uses either a single or
|
||||
pair of BERT encoders with lower-dimension projection for dense semantic indexing of text.
|
||||
|
||||
This model was contributed by [yjernite](https://huggingface.co/yjernite). Code to train and use the model can be
|
||||
found [here](https://github.com/huggingface/transformers/tree/main/examples/research-projects/distillation).
|
||||
|
||||
## RetriBertConfig
|
||||
|
||||
[[autodoc]] RetriBertConfig
|
||||
|
||||
## RetriBertTokenizer
|
||||
|
||||
[[autodoc]] RetriBertTokenizer
|
||||
|
||||
## RetriBertTokenizerFast
|
||||
|
||||
[[autodoc]] RetriBertTokenizerFast
|
||||
|
||||
## RetriBertModel
|
||||
|
||||
[[autodoc]] RetriBertModel
|
||||
- forward
|
||||
@ -61,7 +61,7 @@ Here is how to use the processor to process text and audio:
|
||||
>>> audio_sample = next(iter(dataset))["audio"]
|
||||
|
||||
>>> # now, process it
|
||||
>>> audio_inputs = processor(audios=audio_sample["array"], return_tensors="pt")
|
||||
>>> audio_inputs = processor(audio=audio_sample["array"], return_tensors="pt")
|
||||
|
||||
>>> # now, process some English test as well
|
||||
>>> text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")
|
||||
|
||||
@ -61,7 +61,7 @@ Here is how to use the processor to process text and audio:
|
||||
>>> audio_sample = next(iter(dataset))["audio"]
|
||||
|
||||
>>> # now, process it
|
||||
>>> audio_inputs = processor(audios=audio_sample["array"], return_tensors="pt")
|
||||
>>> audio_inputs = processor(audio=audio_sample["array"], return_tensors="pt")
|
||||
|
||||
>>> # now, process some English text as well
|
||||
>>> text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")
|
||||
|
||||
@ -159,7 +159,7 @@ conversation3 = [
|
||||
|
||||
conversations = [conversation1, conversation2, conversation3]
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
conversations,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
|
||||
@ -1,133 +0,0 @@
|
||||
<!--Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2021-04-14 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Speech2Text2
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The Speech2Text2 model is used together with [Wav2Vec2](wav2vec2) for Speech Translation models proposed in
|
||||
[Large-Scale Self- and Semi-Supervised Learning for Speech Translation](https://huggingface.co/papers/2104.06678) by
|
||||
Changhan Wang, Anne Wu, Juan Pino, Alexei Baevski, Michael Auli, Alexis Conneau.
|
||||
|
||||
Speech2Text2 is a *decoder-only* transformer model that can be used with any speech *encoder-only*, such as
|
||||
[Wav2Vec2](wav2vec2) or [HuBERT](hubert) for Speech-to-Text tasks. Please refer to the
|
||||
[SpeechEncoderDecoder](speech-encoder-decoder) class on how to combine Speech2Text2 with any speech *encoder-only*
|
||||
model.
|
||||
|
||||
This model was contributed by [Patrick von Platen](https://huggingface.co/patrickvonplaten).
|
||||
|
||||
The original code can be found [here](https://github.com/pytorch/fairseq/blob/1f7ef9ed1e1061f8c7f88f8b94c7186834398690/fairseq/models/wav2vec/wav2vec2_asr.py#L266).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- Speech2Text2 achieves state-of-the-art results on the CoVoST Speech Translation dataset. For more information, see
|
||||
the [official models](https://huggingface.co/models?other=speech2text2) .
|
||||
- Speech2Text2 is always used within the [SpeechEncoderDecoder](speech-encoder-decoder) framework.
|
||||
- Speech2Text2's tokenizer is based on [fastBPE](https://github.com/glample/fastBPE).
|
||||
|
||||
## Inference
|
||||
|
||||
Speech2Text2's [`SpeechEncoderDecoderModel`] model accepts raw waveform input values from speech and
|
||||
makes use of [`~generation.GenerationMixin.generate`] to translate the input speech
|
||||
autoregressively to the target language.
|
||||
|
||||
The [`Wav2Vec2FeatureExtractor`] class is responsible for preprocessing the input speech and
|
||||
[`Speech2Text2Tokenizer`] decodes the generated target tokens to the target string. The
|
||||
[`Speech2Text2Processor`] wraps [`Wav2Vec2FeatureExtractor`] and
|
||||
[`Speech2Text2Tokenizer`] into a single instance to both extract the input features and decode the
|
||||
predicted token ids.
|
||||
|
||||
- Step-by-step Speech Translation
|
||||
|
||||
```python
|
||||
>>> from transformers import Speech2Text2Processor, SpeechEncoderDecoderModel
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
|
||||
>>> processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
|
||||
|
||||
|
||||
>>> def map_to_array(example):
|
||||
... example["speech"] = example["audio"]["array"]
|
||||
... return example
|
||||
|
||||
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> ds = ds.map(map_to_array)
|
||||
|
||||
>>> inputs = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt")
|
||||
>>> generated_ids = model.generate(inputs=inputs["input_values"], attention_mask=inputs["attention_mask"])
|
||||
|
||||
>>> transcription = processor.batch_decode(generated_ids)
|
||||
```
|
||||
|
||||
- Speech Translation via Pipelines
|
||||
|
||||
The automatic speech recognition pipeline can also be used to translate speech in just a couple lines of code
|
||||
|
||||
```python
|
||||
>>> from datasets import load_dataset
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> librispeech_en = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> asr = pipeline(
|
||||
... "automatic-speech-recognition",
|
||||
... model="facebook/s2t-wav2vec2-large-en-de",
|
||||
... feature_extractor="facebook/s2t-wav2vec2-large-en-de",
|
||||
... )
|
||||
|
||||
>>> translation_de = asr(librispeech_en[0]["file"])
|
||||
```
|
||||
|
||||
See [model hub](https://huggingface.co/models?filter=speech2text2) to look for Speech2Text2 checkpoints.
|
||||
|
||||
## Resources
|
||||
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
|
||||
## Speech2Text2Config
|
||||
|
||||
[[autodoc]] Speech2Text2Config
|
||||
|
||||
## Speech2TextTokenizer
|
||||
|
||||
[[autodoc]] Speech2Text2Tokenizer
|
||||
- batch_decode
|
||||
- decode
|
||||
- save_vocabulary
|
||||
|
||||
## Speech2Text2Processor
|
||||
|
||||
[[autodoc]] Speech2Text2Processor
|
||||
- __call__
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
- batch_decode
|
||||
- decode
|
||||
|
||||
## Speech2Text2ForCausalLM
|
||||
|
||||
[[autodoc]] Speech2Text2ForCausalLM
|
||||
- forward
|
||||
@ -1,155 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2021-07-16 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# TAPEX
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.30.0.
|
||||
You can do so by running the following command: `pip install -U transformers==4.30.0`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The TAPEX model was proposed in [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://huggingface.co/papers/2107.07653) by Qian Liu,
|
||||
Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou. TAPEX pre-trains a BART model to solve synthetic SQL queries, after
|
||||
which it can be fine-tuned to answer natural language questions related to tabular data, as well as performing table fact checking.
|
||||
|
||||
TAPEX has been fine-tuned on several datasets:
|
||||
|
||||
- [SQA](https://www.microsoft.com/en-us/download/details.aspx?id=54253) (Sequential Question Answering by Microsoft)
|
||||
- [WTQ](https://github.com/ppasupat/WikiTableQuestions) (Wiki Table Questions by Stanford University)
|
||||
- [WikiSQL](https://github.com/salesforce/WikiSQL) (by Salesforce)
|
||||
- [TabFact](https://tabfact.github.io/) (by USCB NLP Lab).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Recent progress in language model pre-training has achieved a great success via leveraging large-scale unstructured textual data. However, it is
|
||||
still a challenge to apply pre-training on structured tabular data due to the absence of large-scale high-quality tabular data. In this paper, we
|
||||
propose TAPEX to show that table pre-training can be achieved by learning a neural SQL executor over a synthetic corpus, which is obtained by automatically
|
||||
synthesizing executable SQL queries and their execution outputs. TAPEX addresses the data scarcity challenge via guiding the language model to mimic a SQL
|
||||
executor on the diverse, large-scale and high-quality synthetic corpus. We evaluate TAPEX on four benchmark datasets. Experimental results demonstrate that
|
||||
TAPEX outperforms previous table pre-training approaches by a large margin and achieves new state-of-the-art results on all of them. This includes improvements
|
||||
on the weakly-supervised WikiSQL denotation accuracy to 89.5% (+2.3%), the WikiTableQuestions denotation accuracy to 57.5% (+4.8%), the SQA denotation accuracy
|
||||
to 74.5% (+3.5%), and the TabFact accuracy to 84.2% (+3.2%). To our knowledge, this is the first work to exploit table pre-training via synthetic executable programs
|
||||
and to achieve new state-of-the-art results on various downstream tasks.*
|
||||
|
||||
## Usage tips
|
||||
|
||||
- TAPEX is a generative (seq2seq) model. One can directly plug in the weights of TAPEX into a BART model.
|
||||
- TAPEX has checkpoints on the hub that are either pre-trained only, or fine-tuned on WTQ, SQA, WikiSQL and TabFact.
|
||||
- Sentences + tables are presented to the model as `sentence + " " + linearized table`. The linearized table has the following format:
|
||||
`col: col1 | col2 | col 3 row 1 : val1 | val2 | val3 row 2 : ...`.
|
||||
- TAPEX has its own tokenizer, that allows to prepare all data for the model easily. One can pass Pandas DataFrames and strings to the tokenizer,
|
||||
and it will automatically create the `input_ids` and `attention_mask` (as shown in the usage examples below).
|
||||
|
||||
### Usage: inference
|
||||
|
||||
Below, we illustrate how to use TAPEX for table question answering. As one can see, one can directly plug in the weights of TAPEX into a BART model.
|
||||
We use the [Auto API](auto), which will automatically instantiate the appropriate tokenizer ([`TapexTokenizer`]) and model ([`BartForConditionalGeneration`]) for us,
|
||||
based on the configuration file of the checkpoint on the hub.
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
>>> import pandas as pd
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
|
||||
>>> model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/tapex-large-finetuned-wtq")
|
||||
|
||||
>>> # prepare table + question
|
||||
>>> data = {"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], "Number of movies": ["87", "53", "69"]}
|
||||
>>> table = pd.DataFrame.from_dict(data)
|
||||
>>> question = "how many movies does Leonardo Di Caprio have?"
|
||||
|
||||
>>> encoding = tokenizer(table, question, return_tensors="pt")
|
||||
|
||||
>>> # let the model generate an answer autoregressively
|
||||
>>> outputs = model.generate(**encoding)
|
||||
|
||||
>>> # decode back to text
|
||||
>>> predicted_answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
||||
>>> print(predicted_answer)
|
||||
53
|
||||
```
|
||||
|
||||
Note that [`TapexTokenizer`] also supports batched inference. Hence, one can provide a batch of different tables/questions, or a batch of a single table
|
||||
and multiple questions, or a batch of a single query and multiple tables. Let's illustrate this:
|
||||
|
||||
```python
|
||||
>>> # prepare table + question
|
||||
>>> data = {"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], "Number of movies": ["87", "53", "69"]}
|
||||
>>> table = pd.DataFrame.from_dict(data)
|
||||
>>> questions = [
|
||||
... "how many movies does Leonardo Di Caprio have?",
|
||||
... "which actor has 69 movies?",
|
||||
... "what's the first name of the actor who has 87 movies?",
|
||||
... ]
|
||||
>>> encoding = tokenizer(table, questions, padding=True, return_tensors="pt")
|
||||
|
||||
>>> # let the model generate an answer autoregressively
|
||||
>>> outputs = model.generate(**encoding)
|
||||
|
||||
>>> # decode back to text
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
[' 53', ' george clooney', ' brad pitt']
|
||||
```
|
||||
|
||||
In case one wants to do table verification (i.e. the task of determining whether a given sentence is supported or refuted by the contents
|
||||
of a table), one can instantiate a [`BartForSequenceClassification`] model. TAPEX has checkpoints on the hub fine-tuned on TabFact, an important
|
||||
benchmark for table fact checking (it achieves 84% accuracy). The code example below again leverages the [Auto API](auto).
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/tapex-large-finetuned-tabfact")
|
||||
>>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/tapex-large-finetuned-tabfact")
|
||||
|
||||
>>> # prepare table + sentence
|
||||
>>> data = {"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], "Number of movies": ["87", "53", "69"]}
|
||||
>>> table = pd.DataFrame.from_dict(data)
|
||||
>>> sentence = "George Clooney has 30 movies"
|
||||
|
||||
>>> encoding = tokenizer(table, sentence, return_tensors="pt")
|
||||
|
||||
>>> # forward pass
|
||||
>>> outputs = model(**encoding)
|
||||
|
||||
>>> # print prediction
|
||||
>>> predicted_class_idx = outputs.logits[0].argmax(dim=0).item()
|
||||
>>> print(model.config.id2label[predicted_class_idx])
|
||||
Refused
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
TAPEX architecture is the same as BART, except for tokenization. Refer to [BART documentation](bart) for information on
|
||||
configuration classes and their parameters. TAPEX-specific tokenizer is documented below.
|
||||
|
||||
</Tip>
|
||||
|
||||
## TapexTokenizer
|
||||
|
||||
[[autodoc]] TapexTokenizer
|
||||
- __call__
|
||||
- save_vocabulary
|
||||
@ -1,66 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2021-06-03 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Trajectory Transformer
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, so we won't accept any new PRs changing its code.
|
||||
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.30.0.
|
||||
You can do so by running the following command: `pip install -U transformers==4.30.0`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The Trajectory Transformer model was proposed in [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://huggingface.co/papers/2106.02039) by Michael Janner, Qiyang Li, Sergey Levine.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Reinforcement learning (RL) is typically concerned with estimating stationary policies or single-step models,
|
||||
leveraging the Markov property to factorize problems in time. However, we can also view RL as a generic sequence
|
||||
modeling problem, with the goal being to produce a sequence of actions that leads to a sequence of high rewards.
|
||||
Viewed in this way, it is tempting to consider whether high-capacity sequence prediction models that work well
|
||||
in other domains, such as natural-language processing, can also provide effective solutions to the RL problem.
|
||||
To this end, we explore how RL can be tackled with the tools of sequence modeling, using a Transformer architecture
|
||||
to model distributions over trajectories and repurposing beam search as a planning algorithm. Framing RL as sequence
|
||||
modeling problem simplifies a range of design decisions, allowing us to dispense with many of the components common
|
||||
in offline RL algorithms. We demonstrate the flexibility of this approach across long-horizon dynamics prediction,
|
||||
imitation learning, goal-conditioned RL, and offline RL. Further, we show that this approach can be combined with
|
||||
existing model-free algorithms to yield a state-of-the-art planner in sparse-reward, long-horizon tasks.*
|
||||
|
||||
This model was contributed by [CarlCochet](https://huggingface.co/CarlCochet). The original code can be found [here](https://github.com/jannerm/trajectory-transformer).
|
||||
|
||||
## Usage tips
|
||||
|
||||
This Transformer is used for deep reinforcement learning. To use it, you need to create sequences from
|
||||
actions, states and rewards from all previous timesteps. This model will treat all these elements together
|
||||
as one big sequence (a trajectory).
|
||||
|
||||
## TrajectoryTransformerConfig
|
||||
|
||||
[[autodoc]] TrajectoryTransformerConfig
|
||||
|
||||
## TrajectoryTransformerModel
|
||||
|
||||
[[autodoc]] TrajectoryTransformerModel
|
||||
- forward
|
||||
@ -1,136 +0,0 @@
|
||||
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2019-01-09 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Transformer XL
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, so we won't accept any new PRs changing its code. This model was deprecated due to security issues linked to `pickle.load`.
|
||||
|
||||
We recommend switching to more recent models for improved security.
|
||||
|
||||
In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub.
|
||||
|
||||
You will need to set the environment variable `TRUST_REMOTE_CODE` to `True` in order to allow the
|
||||
usage of `pickle.load()`:
|
||||
|
||||
```python
|
||||
import os
|
||||
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel
|
||||
|
||||
os.environ["TRUST_REMOTE_CODE"] = "True"
|
||||
|
||||
checkpoint = 'transfo-xl/transfo-xl-wt103'
|
||||
revision = '40a186da79458c9f9de846edfaea79c412137f97'
|
||||
|
||||
tokenizer = TransfoXLTokenizer.from_pretrained(checkpoint, revision=revision)
|
||||
model = TransfoXLLMHeadModel.from_pretrained(checkpoint, revision=revision)
|
||||
```
|
||||
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.35.0.
|
||||
You can do so by running the following command: `pip install -U transformers==4.35.0`.
|
||||
|
||||
</Tip>
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/models?filter=transfo-xl">
|
||||
<img alt="Models" src="https://img.shields.io/badge/All_model_pages-transfo--xl-blueviolet">
|
||||
</a>
|
||||
<a href="https://huggingface.co/spaces/docs-demos/transfo-xl-wt103">
|
||||
<img alt="Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue">
|
||||
</a>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
The Transformer-XL model was proposed in [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://huggingface.co/papers/1901.02860) by Zihang Dai, Zhilin Yang, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan
|
||||
Salakhutdinov. It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can
|
||||
reuse previously computed hidden-states to attend to longer context (memory). This model also uses adaptive softmax
|
||||
inputs and outputs (tied).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Transformers have a potential of learning longer-term dependency, but are limited by a fixed-length context in the
|
||||
setting of language modeling. We propose a novel neural architecture Transformer-XL that enables learning dependency
|
||||
beyond a fixed length without disrupting temporal coherence. It consists of a segment-level recurrence mechanism and a
|
||||
novel positional encoding scheme. Our method not only enables capturing longer-term dependency, but also resolves the
|
||||
context fragmentation problem. As a result, Transformer-XL learns dependency that is 80% longer than RNNs and 450%
|
||||
longer than vanilla Transformers, achieves better performance on both short and long sequences, and is up to 1,800+
|
||||
times faster than vanilla Transformers during evaluation. Notably, we improve the state-of-the-art results of
|
||||
bpc/perplexity to 0.99 on enwiki8, 1.08 on text8, 18.3 on WikiText-103, 21.8 on One Billion Word, and 54.5 on Penn
|
||||
Treebank (without finetuning). When trained only on WikiText-103, Transformer-XL manages to generate reasonably
|
||||
coherent, novel text articles with thousands of tokens.*
|
||||
|
||||
This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The original code can be found [here](https://github.com/kimiyoung/transformer-xl).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- Transformer-XL uses relative sinusoidal positional embeddings. Padding can be done on the left or on the right. The
|
||||
original implementation trains on SQuAD with padding on the left, therefore the padding defaults are set to left.
|
||||
- Transformer-XL is one of the few models that has no sequence length limit.
|
||||
- Same as a regular GPT model, but introduces a recurrence mechanism for two consecutive segments (similar to a regular RNNs with two consecutive inputs). In this context, a segment is a number of consecutive tokens (for instance 512) that may span across multiple documents, and segments are fed in order to the model.
|
||||
- Basically, the hidden states of the previous segment are concatenated to the current input to compute the attention scores. This allows the model to pay attention to information that was in the previous segment as well as the current one. By stacking multiple attention layers, the receptive field can be increased to multiple previous segments.
|
||||
- This changes the positional embeddings to positional relative embeddings (as the regular positional embeddings would give the same results in the current input and the current hidden state at a given position) and needs to make some adjustments in the way attention scores are computed.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
TransformerXL does **not** work with *torch.nn.DataParallel* due to a bug in PyTorch, see [issue #36035](https://github.com/pytorch/pytorch/issues/36035)
|
||||
|
||||
</Tip>
|
||||
|
||||
## Resources
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
|
||||
## TransfoXLConfig
|
||||
|
||||
[[autodoc]] TransfoXLConfig
|
||||
|
||||
## TransfoXLTokenizer
|
||||
|
||||
[[autodoc]] TransfoXLTokenizer
|
||||
- save_vocabulary
|
||||
|
||||
## TransfoXL specific outputs
|
||||
|
||||
[[autodoc]] models.deprecated.transfo_xl.modeling_transfo_xl.TransfoXLModelOutput
|
||||
|
||||
[[autodoc]] models.deprecated.transfo_xl.modeling_transfo_xl.TransfoXLLMHeadModelOutput
|
||||
|
||||
## TransfoXLModel
|
||||
|
||||
[[autodoc]] TransfoXLModel
|
||||
- forward
|
||||
|
||||
## TransfoXLLMHeadModel
|
||||
|
||||
[[autodoc]] TransfoXLLMHeadModel
|
||||
- forward
|
||||
|
||||
## TransfoXLForSequenceClassification
|
||||
|
||||
[[autodoc]] TransfoXLForSequenceClassification
|
||||
- forward
|
||||
|
||||
## Internal Layers
|
||||
|
||||
[[autodoc]] AdaptiveEmbedding
|
||||
@ -1,90 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2022-09-28 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# TVLT
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The TVLT model was proposed in [TVLT: Textless Vision-Language Transformer](https://huggingface.co/papers/2209.14156)
|
||||
by Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal (the first three authors contributed equally). The Textless Vision-Language Transformer (TVLT) is a model that uses raw visual and audio inputs for vision-and-language representation learning, without using text-specific modules such as tokenization or automatic speech recognition (ASR). It can perform various audiovisual and vision-language tasks like retrieval, question answering, etc.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*In this work, we present the Textless Vision-Language Transformer (TVLT), where homogeneous transformer blocks take raw visual and audio inputs for vision-and-language representation learning with minimal modality-specific design, and do not use text-specific modules such as tokenization or automatic speech recognition (ASR). TVLT is trained by reconstructing masked patches of continuous video frames and audio spectrograms (masked autoencoding) and contrastive modeling to align video and audio. TVLT attains performance comparable to its text-based counterpart on various multimodal tasks, such as visual question answering, image retrieval, video retrieval, and multimodal sentiment analysis, with 28x faster inference speed and only 1/3 of the parameters. Our findings suggest the possibility of learning compact and efficient visual-linguistic representations from low-level visual and audio signals without assuming the prior existence of text.*
|
||||
|
||||
<p align="center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/tvlt_architecture.png"
|
||||
alt="drawing" width="600"/>
|
||||
</p>
|
||||
|
||||
<small> TVLT architecture. Taken from the <a href="[https://huggingface.co/papers/2102.03334](https://huggingface.co/papers/2209.14156)">original paper</a>. </small>
|
||||
|
||||
The original code can be found [here](https://github.com/zinengtang/TVLT). This model was contributed by [Zineng Tang](https://huggingface.co/ZinengTang).
|
||||
|
||||
## Usage tips
|
||||
|
||||
- TVLT is a model that takes both `pixel_values` and `audio_values` as input. One can use [`TvltProcessor`] to prepare data for the model.
|
||||
This processor wraps an image processor (for the image/video modality) and an audio feature extractor (for the audio modality) into one.
|
||||
- TVLT is trained with images/videos and audios of various sizes: the authors resize and crop the input images/videos to 224 and limit the length of audio spectrogram to 2048. To make batching of videos and audios possible, the authors use a `pixel_mask` that indicates which pixels are real/padding and `audio_mask` that indicates which audio values are real/padding.
|
||||
- The design of TVLT is very similar to that of a standard Vision Transformer (ViT) and masked autoencoder (MAE) as in [ViTMAE](vitmae). The difference is that the model includes embedding layers for the audio modality.
|
||||
- The PyTorch version of this model is only available in torch 1.10 and higher.
|
||||
|
||||
## TvltConfig
|
||||
|
||||
[[autodoc]] TvltConfig
|
||||
|
||||
## TvltProcessor
|
||||
|
||||
[[autodoc]] TvltProcessor
|
||||
- __call__
|
||||
|
||||
## TvltFeatureExtractor
|
||||
|
||||
[[autodoc]] TvltFeatureExtractor
|
||||
- __call__
|
||||
|
||||
## TvltImageProcessor
|
||||
|
||||
[[autodoc]] TvltImageProcessor
|
||||
- preprocess
|
||||
|
||||
## TvltModel
|
||||
|
||||
[[autodoc]] TvltModel
|
||||
- forward
|
||||
|
||||
## TvltForPreTraining
|
||||
|
||||
[[autodoc]] TvltForPreTraining
|
||||
- forward
|
||||
|
||||
## TvltForAudioVisualClassification
|
||||
|
||||
[[autodoc]] TvltForAudioVisualClassification
|
||||
- forward
|
||||
@ -1,76 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2022-02-20 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# VAN
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.30.0.
|
||||
You can do so by running the following command: `pip install -U transformers==4.30.0`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The VAN model was proposed in [Visual Attention Network](https://huggingface.co/papers/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
|
||||
|
||||
This paper introduces a new attention layer based on convolution operations able to capture both local and distant relationships. This is done by combining normal and large kernel convolution layers. The latter uses a dilated convolution to capture distant correlations.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*While originally designed for natural language processing tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, we propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. We further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple, VAN outperforms the state-of-the-art vision transformers and convolutional neural networks with a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc. Code is available at [this https URL](https://github.com/Visual-Attention-Network/VAN-Classification).*
|
||||
|
||||
Tips:
|
||||
|
||||
- VAN does not have an embedding layer, thus the `hidden_states` will have a length equal to the number of stages.
|
||||
|
||||
The figure below illustrates the architecture of a Visual Attention Layer. Taken from the [original paper](https://huggingface.co/papers/2202.09741).
|
||||
|
||||
<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/van_architecture.png"/>
|
||||
|
||||
This model was contributed by [Francesco](https://huggingface.co/Francesco). The original code can be found [here](https://github.com/Visual-Attention-Network/VAN-Classification).
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with VAN.
|
||||
|
||||
<PipelineTag pipeline="image-classification"/>
|
||||
|
||||
- [`VanForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
|
||||
- See also: [Image classification task guide](../tasks/image_classification)
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
## VanConfig
|
||||
|
||||
[[autodoc]] VanConfig
|
||||
|
||||
## VanModel
|
||||
|
||||
[[autodoc]] VanModel
|
||||
- forward
|
||||
|
||||
## VanForImageClassification
|
||||
|
||||
[[autodoc]] VanForImageClassification
|
||||
- forward
|
||||
@ -1,112 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2020-10-22 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# Hybrid Vision Transformer (ViT Hybrid)
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Overview
|
||||
|
||||
The hybrid Vision Transformer (ViT) model was proposed in [An Image is Worth 16x16 Words: Transformers for Image Recognition
|
||||
at Scale](https://huggingface.co/papers/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk
|
||||
Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob
|
||||
Uszkoreit, Neil Houlsby. It's the first paper that successfully trains a Transformer encoder on ImageNet, attaining
|
||||
very good results compared to familiar convolutional architectures. ViT hybrid is a slight variant of the [plain Vision Transformer](vit),
|
||||
by leveraging a convolutional backbone (specifically, [BiT](bit)) whose features are used as initial "tokens" for the Transformer.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*While the Transformer architecture has become the de-facto standard for natural language processing tasks, its
|
||||
applications to computer vision remain limited. In vision, attention is either applied in conjunction with
|
||||
convolutional networks, or used to replace certain components of convolutional networks while keeping their overall
|
||||
structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to
|
||||
sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of
|
||||
data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.),
|
||||
Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring
|
||||
substantially fewer computational resources to train.*
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
|
||||
found [here](https://github.com/google-research/vision_transformer).
|
||||
|
||||
## Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
|
||||
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
|
||||
|
||||
```py
|
||||
from transformers import ViTHybridForImageClassification
|
||||
model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384", attn_implementation="sdpa", dtype=torch.float16)
|
||||
...
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-hybrid-base-bit-384` model, we saw the following speedups during inference.
|
||||
|
||||
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|
||||
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
|
||||
| 1 | 29 | 18 | 1.61 |
|
||||
| 2 | 26 | 18 | 1.44 |
|
||||
| 4 | 25 | 18 | 1.39 |
|
||||
| 8 | 34 | 24 | 1.42 |
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT Hybrid.
|
||||
|
||||
<PipelineTag pipeline="image-classification"/>
|
||||
|
||||
- [`ViTHybridForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
|
||||
- See also: [Image classification task guide](../tasks/image_classification)
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
## ViTHybridConfig
|
||||
|
||||
[[autodoc]] ViTHybridConfig
|
||||
|
||||
## ViTHybridImageProcessor
|
||||
|
||||
[[autodoc]] ViTHybridImageProcessor
|
||||
- preprocess
|
||||
|
||||
## ViTHybridModel
|
||||
|
||||
[[autodoc]] ViTHybridModel
|
||||
- forward
|
||||
|
||||
## ViTHybridForImageClassification
|
||||
|
||||
[[autodoc]] ViTHybridForImageClassification
|
||||
- forward
|
||||
@ -1,99 +0,0 @@
|
||||
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
*This model was released on 2020-01-13 and added to Hugging Face Transformers on 2023-06-20.*
|
||||
|
||||
# XLM-ProphetNet
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
If you run into any issues running this model, please reinstall the last version that supported this model: v4.40.2.
|
||||
You can do so by running the following command: `pip install -U transformers==4.40.2`.
|
||||
|
||||
</Tip>
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/models?filter=xprophetnet">
|
||||
<img alt="Models" src="https://img.shields.io/badge/All_model_pages-xprophetnet-blueviolet">
|
||||
</a>
|
||||
<a href="https://huggingface.co/spaces/docs-demos/xprophetnet-large-wiki100-cased-xglue-ntg">
|
||||
<img alt="Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue">
|
||||
</a>
|
||||
</div>
|
||||
|
||||
**DISCLAIMER:** If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title) and assign
|
||||
@patrickvonplaten
|
||||
|
||||
## Overview
|
||||
|
||||
The XLM-ProphetNet model was proposed in [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training,](https://huggingface.co/papers/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei
|
||||
Zhang, Ming Zhou on 13 Jan, 2020.
|
||||
|
||||
XLM-ProphetNet is an encoder-decoder model and can predict n-future tokens for "ngram" language modeling instead of
|
||||
just the next token. Its architecture is identical to ProhpetNet, but the model was trained on the multi-lingual
|
||||
"wiki100" Wikipedia dump. XLM-ProphetNet's model architecture and pretraining objective is same as ProphetNet, but XLM-ProphetNet was pre-trained on the cross-lingual dataset XGLUE.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*In this paper, we present a new sequence-to-sequence pretraining model called ProphetNet, which introduces a novel
|
||||
self-supervised objective named future n-gram prediction and the proposed n-stream self-attention mechanism. Instead of
|
||||
the optimization of one-step ahead prediction in traditional sequence-to-sequence model, the ProphetNet is optimized by
|
||||
n-step ahead prediction which predicts the next n tokens simultaneously based on previous context tokens at each time
|
||||
step. The future n-gram prediction explicitly encourages the model to plan for the future tokens and prevent
|
||||
overfitting on strong local correlations. We pre-train ProphetNet using a base scale dataset (16GB) and a large scale
|
||||
dataset (160GB) respectively. Then we conduct experiments on CNN/DailyMail, Gigaword, and SQuAD 1.1 benchmarks for
|
||||
abstractive summarization and question generation tasks. Experimental results show that ProphetNet achieves new
|
||||
state-of-the-art results on all these datasets compared to the models using the same scale pretraining corpus.*
|
||||
|
||||
The Authors' code can be found [here](https://github.com/microsoft/ProphetNet).
|
||||
|
||||
## Resources
|
||||
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
- [Translation task guide](../tasks/translation)
|
||||
- [Summarization task guide](../tasks/summarization)
|
||||
|
||||
## XLMProphetNetConfig
|
||||
|
||||
[[autodoc]] XLMProphetNetConfig
|
||||
|
||||
## XLMProphetNetTokenizer
|
||||
|
||||
[[autodoc]] XLMProphetNetTokenizer
|
||||
|
||||
## XLMProphetNetModel
|
||||
|
||||
[[autodoc]] XLMProphetNetModel
|
||||
|
||||
## XLMProphetNetEncoder
|
||||
|
||||
[[autodoc]] XLMProphetNetEncoder
|
||||
|
||||
## XLMProphetNetDecoder
|
||||
|
||||
[[autodoc]] XLMProphetNetDecoder
|
||||
|
||||
## XLMProphetNetForConditionalGeneration
|
||||
|
||||
[[autodoc]] XLMProphetNetForConditionalGeneration
|
||||
|
||||
## XLMProphetNetForCausalLM
|
||||
|
||||
[[autodoc]] XLMProphetNetForCausalLM
|
||||
@ -1,6 +1,6 @@
|
||||
# Contributing a new model to Transformers
|
||||
|
||||
Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance.
|
||||
Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance. We recommend to go through [general contribution guidelines for new models](./contributing#do-you-want-to-implement-a-new-model) before diving into the details here.
|
||||
|
||||
One of Transformers' core design feature is the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy) policy. Model components - such as attention layers - are repeated across many files and any independent implementations tend to diverge as fixes and changes are applied to specific parts of the code.
|
||||
|
||||
|
||||
@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
|
||||
|
||||
@ -329,7 +329,7 @@ from torchao.dtypes import Int4XPULayout
|
||||
from torchao.quantization.quant_primitives import ZeroPointDomain
|
||||
|
||||
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT)
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT, int4_packing_format="plain_int32")
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
@ -342,7 +342,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
@ -395,7 +395,7 @@ from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
from torchao.dtypes import Int4CPULayout
|
||||
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout(), int4_packing_format="opaque")
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
@ -422,7 +422,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
#### 1. Skip quantization for certain layers
|
||||
|
||||
With `ModuleFqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
|
||||
With `FqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@ -430,11 +430,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||
|
||||
model_id = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig
|
||||
from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig
|
||||
config = Int4WeightOnlyConfig(group_size=128)
|
||||
|
||||
# set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
|
||||
quant_config = ModuleFqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
|
||||
quant_config = FqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16, quantization_config=quantization_config)
|
||||
# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
|
||||
@ -459,7 +459,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||
|
||||
model_id = "facebook/opt-125m"
|
||||
|
||||
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
|
||||
from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
|
||||
|
||||
weight_dtype = torch.int8
|
||||
granularity = PerAxis(0)
|
||||
@ -470,7 +470,7 @@ embedding_config = IntxWeightOnlyConfig(
|
||||
mapping_type=mapping_type,
|
||||
)
|
||||
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
|
||||
quant_config = ModuleFqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
|
||||
quant_config = FqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
|
||||
# set `include_embedding` to True in order to include embedding in quantization
|
||||
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
|
||||
@ -521,7 +521,7 @@ from torchao.quantization import (
|
||||
IntxWeightOnlyConfig,
|
||||
PerRow,
|
||||
PerAxis,
|
||||
ModuleFqnToConfig,
|
||||
FqnToConfig,
|
||||
Float8Tensor,
|
||||
Int4TilePackedTo4dTensor,
|
||||
IntxUnpackedToInt8Tensor,
|
||||
@ -550,7 +550,7 @@ qconfig_dict = {
|
||||
|
||||
"_default": intxwo,
|
||||
}
|
||||
quant_config = ModuleFqnToConfig(qconfig_dict)
|
||||
quant_config = FqnToConfig(qconfig_dict)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
|
||||
@ -14,9 +14,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Inference server backends
|
||||
# Transformers as modeling backend
|
||||
|
||||
Transformers' models are compatible with different inference servers like vLLM and SGLang. Instead of implementing a model for each inference server, you only need one model, which can be plugged into any inference server. It simplifies maintenance and makes it easy for users to use different inference servers for different use cases.
|
||||
Transformers' models are compatible with different inference servers like vLLM and SGLang. Instead of implementing a new model architecture from scratch for each inference server, you only need a model definition in `transformers`, which can be plugged into any inference server. It simplifies maintenance and makes it easy for users to use different inference servers for different use cases.
|
||||
|
||||
With Transformers as a backend, you can also serve any model - including custom and Hub-hosted models - without waiting for native support.
|
||||
|
||||
@ -157,57 +157,13 @@ class MyConfig(PreTrainedConfig):
|
||||
|
||||
### Multimodal models
|
||||
|
||||
For multimodal models, you need to include a few more changes on top of the general recommendations. These rules ensure that your model integrates properly with multimodal data.
|
||||
For multimodal models, you need to include a few more changes on top of the general recommendations outlined in ["contribuiting a model"](./contributing#vision-language-model-contribution-checklist). These rules ensure that your model integrates properly and enables processing multimodal data.
|
||||
|
||||
1. A multimodal model requires a base `MyMultiModalModel` class to handle multimodal fusion without a language modeling head and a separate generative class that adds a head.
|
||||
1. A multimodal model's processing class must have the `self.image_token` and `self.image_token_ids` attributes. These are placeholder tokens used to indicate image positions in the input. This placeholder token is the same token used in the input prompt to denote images and used in model code to scatter image features.
|
||||
|
||||
The base model needs to implement the `get_image_features()` method to accept image pixel values and return encoded outputs. These are later merged with the language embeddings and don't require any postprocessing. The shape of the returned features must match the number of input images. If a vision encoder returns variable-length outputs (patch-based), return a list of 2D tensors of size `(image_seq_len, image_dim)` for each image.
|
||||
2. The processing class needs `self._get_num_multimodal_tokens` method to compute the number of placeholder tokens needed for multimodal inputs with given sizes and to return a [`MultiModalData`] object. The placeholders between `<image>` tokens such as row or column tokens don't count as image placeholders. Only tokens that are actually replaced by image features later in modeling should be counted!
|
||||
|
||||
Expand the code below for an example.
|
||||
|
||||
<details>
|
||||
<summary>modeling_my_multimodal_model.py</summary>
|
||||
|
||||
```python
|
||||
from transformers.generation import GenerationMixin
|
||||
|
||||
class MyMultimodalModel(MyMultimodalPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.language_model = AutoModel.from_config(config.text_config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
self.multimodal_projection = nn.Linear(vision_dim, text_dim)
|
||||
|
||||
def get_image_features(self, pixel_values):
|
||||
return self.vision_tower(pixel_values).last_hidden_states
|
||||
|
||||
def forward(self, input_ids, pixel_values, **kwargs):
|
||||
# process your inputs
|
||||
return MyModelOutputWithPast(
|
||||
last_hidden_state=last_hidden_state,
|
||||
image_hidden_states=image_features,
|
||||
[...]
|
||||
)
|
||||
|
||||
class MyMultimodalModelForConditionalGeneration(MyMultimodalPreTrainedModel, GenerationMixin):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = MyMultimodalModel(config)
|
||||
self.lm_head = nn.Linear(hidden_dim, vocab_size)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
2. A multimodal model config must be nested with the following fields.
|
||||
* text_config: decoder language model config
|
||||
* vision_config: vision encoder config
|
||||
* image_token_id: ID of the image placeholder token used in the input to indicate image position
|
||||
|
||||
3. A multimodal model's processing class must have the `self.image_token` and `self.image_token_ids` attributes. These are placeholder tokens used to indicate image positions in the input. The placeholder token is the same token used in the input prompt and to mask scatter image features.
|
||||
|
||||
The processing class also needs `self._get_num_multimodal_tokens` method to compute the number of placeholder tokens needed for multimodal inputs with given sizes and to return a [`MultiModalData`] object. The placeholder for row and column tokens don't count as image placeholders. Only the tokens that are actually replaced by image features are computed.
|
||||
|
||||
Finally, when `return_mm_token_type_ids=True`, the class has to return `mm_token_type_ids` to indicate whether each position is a text token (`0`) or image placeholder token (`1`). Each image's token type IDs must be contiguous with no breaks between consecutive ones.
|
||||
3. The processor needs to check the value of `return_mm_token_type_ids` and return `mm_token_type_ids` to indicate whether each position is a text token (`0`), image placeholder token (`1`) or video placeholder token (`2`). Each multimodal token type ID sequence must be contiguous without breaks between consecutive tokens, therefore special tokens for begin/end/row/column must be treated as placeholders.
|
||||
|
||||
Expand the code below for an example.
|
||||
|
||||
@ -246,5 +202,5 @@ class MyMultimodalProcessor(ProcessorMixin):
|
||||
|
||||
## Resources
|
||||
|
||||
* Read the [Transformers backend integration in vLLM](https://blog.vllm.ai/2025/04/11/transformers-backend.html) blog post for more details about the Transformers backend in vLLM.
|
||||
* Read the [Transformers backend integration in SGLang](https://huggingface.co/blog/transformers-backend-sglang) blog post for more details about the Transformers backend in SGLang.
|
||||
* Read the [Transformers modeling backend integration in vLLM](https://blog.vllm.ai/2025/04/11/transformers-backend.html) blog post for more details about the Transformers modeling backend in vLLM.
|
||||
* Read the [Transformers modeling backend integration in SGLang](https://huggingface.co/blog/transformers-backend-sglang) blog post for more details about the Transformers modeling backend in SGLang.
|
||||
|
||||
@ -170,7 +170,7 @@ Per quanto riguarda la classe `TrainingArguments`:
|
||||
- L'argomento `evaluate_during_training` di `TrainingArguments` è deprecato a favore di `eval_strategy`.
|
||||
|
||||
Per quanto riguarda il modello Transfo-XL:
|
||||
- L'attributo di configurazione `tie_weight` di Transfo-XL diventa `tie_words_embeddings`.
|
||||
- L'attributo di configurazione `tie_weight` di Transfo-XL diventa `tie_word_embeddings`.
|
||||
- Il metodo di modellazione `reset_length` di Transfo-XL diventa `reset_memory_length`.
|
||||
|
||||
Per quanto riguarda le pipeline:
|
||||
|
||||
@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
特定のモジュールに特別な初期化が必要な場合、カスタムスキームをさらに持つことができます。たとえば、
|
||||
@ -431,9 +431,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
`_is_hf_initialized`フラグは、サブモジュールを一度だけ初期化することを確実にするために内部で使用されます。
|
||||
|
||||
@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
몇 가지 모듈에 대해 특별한 초기화가 필요한 경우 사용자 정의 방식을 사용할 수도 있습니다. 예를 들어, `Wav2Vec2ForPreTraining`에서 마지막 두 개의 선형 레이어는 일반적인 PyTorch `nn.Linear`의 초기화를 가져야 하지만, 다른 모든 레이어는 위와 같은 초기화를 사용해야 합니다. 이는 다음과 같이 코드화됩니다:
|
||||
@ -371,9 +371,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
`_is_hf_initialized` 플래그는 서브모듈을 한 번만 초기화하도록 내부적으로 사용됩니다. `module.project_q` 및 `module.project_hid`에 대해 `True`로 설정함으로써, 우리가 수행한 사용자 정의 초기화가 이후에 덮어쓰이지 않도록 합니다. 즉, `_init_weights` 함수가 이들에게 적용되지 않습니다.
|
||||
|
||||
@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
배치 행렬 곱셈을 `forward` 패스에서 사용하여 `gate_up_proj` 모듈의 출력을 계산할 수 있습니다.
|
||||
|
||||
@ -502,16 +502,10 @@ class DummyBertLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -536,18 +530,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, DummyBertLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -265,7 +265,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
|
||||
if "RMSNorm" in module.__class__.__name__:
|
||||
module.weight.data.zero_()
|
||||
module.weight.zero_()
|
||||
|
||||
|
||||
class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):
|
||||
|
||||
@ -104,9 +104,9 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
def token_type_ids_mask_function(
|
||||
@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
@ -440,7 +440,15 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
prefix = "model.language_model."
|
||||
prefixed_mapping = {
|
||||
f"{prefix}{target}": f"{prefix}{source}"
|
||||
for target, source in self.language_model._tied_weights_keys.items()
|
||||
}
|
||||
if isinstance(self._tied_weights_keys, dict):
|
||||
self._tied_weights_keys.update(prefixed_mapping)
|
||||
else:
|
||||
self._tied_weights_keys = prefixed_mapping
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
||||
@ -505,16 +505,10 @@ class RobertaLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -539,18 +533,18 @@ class RobertaPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, RobertaLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -846,11 +846,11 @@ class TestDetrPreTrainedModel(PreTrainedModel):
|
||||
nn.init.xavier_uniform_(module.output_proj.weight.data)
|
||||
nn.init.constant_(module.output_proj.bias.data, 0.0)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if hasattr(module, "reference_points") and not self.config.two_stage:
|
||||
|
||||
@ -19,7 +19,15 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
prefix = "model.language_model."
|
||||
prefixed_mapping = {
|
||||
f"{prefix}{target}": f"{prefix}{source}"
|
||||
for target, source in self.language_model._tied_weights_keys.items()
|
||||
}
|
||||
if isinstance(self._tied_weights_keys, dict):
|
||||
self._tied_weights_keys.update(prefixed_mapping)
|
||||
else:
|
||||
self._tied_weights_keys = prefixed_mapping
|
||||
|
||||
self.post_init()
|
||||
|
||||
|
||||
@ -27,7 +27,6 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
@ -180,29 +179,11 @@ class ModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
freeze_feature_extractor: Optional[bool] = field(
|
||||
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
|
||||
)
|
||||
ignore_mismatched_sizes: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
|
||||
warnings.warn(
|
||||
"The argument `--freeze_feature_extractor` is deprecated and "
|
||||
"will be removed in a future version. Use `--freeze_feature_encoder` "
|
||||
"instead. Setting `freeze_feature_encoder==True`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if self.freeze_feature_extractor and not self.freeze_feature_encoder:
|
||||
raise ValueError(
|
||||
"The argument `--freeze_feature_extractor` is deprecated and "
|
||||
"should not be used in combination with `--freeze_feature_encoder`. "
|
||||
"Only make use of `--freeze_feature_encoder`."
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
|
||||
@ -17,6 +17,7 @@ import contextlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from itertools import cycle
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
@ -29,42 +30,32 @@ from transformers.generation import GenerationConfig
|
||||
from transformers.generation.continuous_batching.requests import logger
|
||||
|
||||
|
||||
# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
|
||||
SLIDING_WINDOW = 0
|
||||
MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
|
||||
FORCE_MAX_LENGTH = False # should be False unless you are debugging sliding window features
|
||||
SKIP_SPECIAL_TOKENS = False
|
||||
|
||||
|
||||
def generate_simple(
|
||||
attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
|
||||
def generate_without_cb(
|
||||
model_id: str, sliding_window: int, attn_impl: str, batched_inputs: list[int], generation_config: GenerationConfig
|
||||
) -> dict[str, str]:
|
||||
attn_impl = {
|
||||
"sdpa": "sdpa",
|
||||
"eager": "eager",
|
||||
"paged_attention": "eager", # TODO: this does not work on AMD docker
|
||||
"flash_paged": "flash_attention_2", # TODO: this does not work on AMD docker
|
||||
"kernels-community/flash-attn": "eager",
|
||||
}[attn_impl]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
|
||||
# Setup model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, attn_implementation=attn_impl)
|
||||
model = model.cuda().eval()
|
||||
if getattr(model.config, "sliding_window", None) is not None:
|
||||
model.config.sliding_window = SLIDING_WINDOW
|
||||
|
||||
if sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
|
||||
model.config.sliding_window = sliding_window
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
# Generate one by one
|
||||
decoded_outputs = {}
|
||||
for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
|
||||
for input_ids in tqdm(batched_inputs, desc="Generating outputs without CB"):
|
||||
key = " ".join(map(str, input_ids)) # This will be used to identify the output after batched generation
|
||||
input_ids = torch.tensor([input_ids]).to("cuda")
|
||||
# attention_mask = torch.ones_like(input_ids)
|
||||
outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
outputs = model.generate(
|
||||
input_ids, attention_mask=attention_mask, generation_config=generation_config, use_model_defaults=False
|
||||
)
|
||||
generated_tokens = outputs[0][input_ids.shape[1] :]
|
||||
decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
|
||||
decoded_outputs[key] = decoded_output
|
||||
decoded_outputs[key] = tokenizer.decode(generated_tokens, skip_special_tokens=False)
|
||||
return decoded_outputs
|
||||
|
||||
|
||||
def setup_metrics():
|
||||
def maybe_setup_metrics(use_metrics: bool) -> None:
|
||||
if not use_metrics:
|
||||
return
|
||||
try:
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
@ -119,16 +110,14 @@ def batch_generate(
|
||||
token_count = 0
|
||||
data = []
|
||||
for i, request in enumerate(batch_outputs):
|
||||
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
|
||||
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
|
||||
# The key is used to tie back to the output of unbatched generation
|
||||
key = " ".join(map(str, batch_outputs[request].prompt_ids))
|
||||
data.append({"input": input_text, "key": key})
|
||||
|
||||
# Try to decode the output
|
||||
try:
|
||||
output_text = tokenizer.decode(
|
||||
batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
|
||||
)
|
||||
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
|
||||
token_count += len(batch_outputs[request].generated_tokens[1:])
|
||||
data[-1]["cb_outputs"] = output_text
|
||||
except Exception as e:
|
||||
@ -138,14 +127,7 @@ def batch_generate(
|
||||
|
||||
# Display sample if asked
|
||||
if i < displayed_samples:
|
||||
if len(output_text) > 0:
|
||||
print("-" * 20)
|
||||
print(f"{request} Input: {input_text}")
|
||||
print(f"{request} Output: {output_text}")
|
||||
else:
|
||||
print(f"{request} Input: {input_text}")
|
||||
print("[WARN]")
|
||||
print(f"{request} Output was empty!")
|
||||
print("-" * 20, f"{request} Input: {input_text}", f"{request} Output: {output_text}", sep="\n")
|
||||
|
||||
# Compare with classic generate if asked
|
||||
if expected_outputs is not None:
|
||||
@ -182,75 +164,102 @@ def batch_generate(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse args
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Continuous batching parameters
|
||||
parser.add_argument("--num-blocks", "-n", type=int, default=None)
|
||||
parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
|
||||
|
||||
# Model parameters
|
||||
parser.add_argument("--sliding-window", type=int, default=0)
|
||||
parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
|
||||
|
||||
# Performance parameters
|
||||
parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable
|
||||
parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
|
||||
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
|
||||
parser.add_argument("--do-sample", action="store_true", help="Activate sampling")
|
||||
|
||||
# Benchmark parameters
|
||||
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
|
||||
parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
|
||||
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
|
||||
parser.add_argument("--profile", type=str, default=None)
|
||||
parser.add_argument("--metrics", action="store_true")
|
||||
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")
|
||||
|
||||
# Display parameters
|
||||
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
|
||||
parser.add_argument("--log-level", type=str, default="INFO")
|
||||
parser.add_argument("--output-file", type=str, default=None)
|
||||
parser.add_argument("--compare", action="store_true")
|
||||
parser.add_argument("--metrics", action="store_true")
|
||||
parser.add_argument("--profile", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set log level
|
||||
# Create model
|
||||
model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
|
||||
has_system_role = args.sliding_window == 0
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn, dtype=torch.bfloat16)
|
||||
model = model.cuda().eval()
|
||||
|
||||
if args.sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
|
||||
print(f"Setting sliding window from {model.config.sliding_window} to {args.sliding_window}")
|
||||
model.config.sliding_window = args.sliding_window
|
||||
|
||||
# Set up diagnostics
|
||||
logger.setLevel(args.log_level.upper())
|
||||
maybe_setup_metrics(args.metrics)
|
||||
|
||||
# If turned on, we setup metrics
|
||||
if args.metrics:
|
||||
setup_metrics()
|
||||
|
||||
# Set matmul precision if not none
|
||||
# Set up performance
|
||||
if args.matmul_precision != "none":
|
||||
torch.set_float32_matmul_precision(args.matmul_precision)
|
||||
# Parse cuda graph argument
|
||||
if args.cuda_graph is not None:
|
||||
use_cuda_graph = {
|
||||
"none": None,
|
||||
"yes": True, "y": True, "true": True, "t": True, "1": True,
|
||||
"no": False, "n": False, "false": False, "f": False, "0": False,
|
||||
}[args.cuda_graph.lower()] # fmt: skip
|
||||
else:
|
||||
use_cuda_graph = None
|
||||
|
||||
# Prepare model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
attn_implementation=args.attn,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
model = model.cuda().eval()
|
||||
if getattr(model.config, "sliding_window", None) is not None:
|
||||
print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
|
||||
model.config.sliding_window = SLIDING_WINDOW
|
||||
cuda_graph_arg = args.cuda_graph.lower() if args.cuda_graph is not None else None
|
||||
use_cuda_graph = {
|
||||
"none": None, None: None,
|
||||
"yes": True, "y": True, "true": True, "t": True, "1": True,
|
||||
"no": False, "n": False, "false": False, "f": False, "0": False,
|
||||
}[cuda_graph_arg] # fmt: skip
|
||||
|
||||
# If turned on, we compile the model
|
||||
if args.compile:
|
||||
model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
|
||||
|
||||
# Prepare tokenizer and dataset
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
||||
|
||||
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
|
||||
dataset = dataset.select(range(args.samples))
|
||||
|
||||
simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
|
||||
if args.add_prefix:
|
||||
possible_prefixes = [
|
||||
None,
|
||||
"You are a bot that solves math problems.",
|
||||
"You are a bot who solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning.",
|
||||
"You are a bot with the aim to solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning. No loud words or emojis, all responses must be readable by a child. Here is now the problem:",
|
||||
] # fmt: skip
|
||||
else:
|
||||
possible_prefixes = [None]
|
||||
|
||||
batched_inputs = []
|
||||
for item, prefix in zip(dataset, cycle(possible_prefixes)):
|
||||
messages = []
|
||||
question = item["question"]
|
||||
if prefix is not None:
|
||||
if has_system_role:
|
||||
messages.append({"role": "system", "content": prefix})
|
||||
else:
|
||||
question = prefix + "\n\n" + question
|
||||
messages.append({"role": "user", "content": question})
|
||||
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||
batched_inputs.append(inputs["input_ids"])
|
||||
|
||||
# Prepare generation config
|
||||
generation_config = GenerationConfig(
|
||||
generation_cfg = GenerationConfig(
|
||||
max_new_tokens=512,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
|
||||
eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=not args.compare,
|
||||
do_sample=args.do_sample,
|
||||
temperature=0.8,
|
||||
top_p=0.9,
|
||||
num_blocks=args.num_blocks,
|
||||
@ -258,7 +267,12 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# If we need to compare, we need to generate the reference outputs
|
||||
expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None
|
||||
if args.compare:
|
||||
expected_outputs = generate_without_cb(
|
||||
model_id, args.sliding_window, args.attn, batched_inputs, generation_cfg
|
||||
)
|
||||
else:
|
||||
expected_outputs = None
|
||||
|
||||
# If no output file is provided, we pick a name based on the args
|
||||
if args.output_file is None:
|
||||
@ -271,8 +285,8 @@ if __name__ == "__main__":
|
||||
# Run warmup batch generation # TODO: understand why warmup incurs a large overhead during cache creation
|
||||
batch_generate(
|
||||
model,
|
||||
simple_batch_inputs[: min(5, args.samples)],
|
||||
generation_config,
|
||||
batched_inputs[: min(5, args.samples)],
|
||||
generation_cfg,
|
||||
tokenizer,
|
||||
displayed_samples=-1,
|
||||
)
|
||||
@ -285,8 +299,8 @@ if __name__ == "__main__":
|
||||
# Run batch generation
|
||||
gen_time, tok_per_sec = batch_generate(
|
||||
model,
|
||||
simple_batch_inputs,
|
||||
generation_config,
|
||||
batched_inputs,
|
||||
generation_cfg,
|
||||
tokenizer,
|
||||
displayed_samples=args.displayed,
|
||||
output_file=args.output_file,
|
||||
@ -297,5 +311,5 @@ if __name__ == "__main__":
|
||||
prof.export_chrome_trace(filename)
|
||||
|
||||
# Example usage:
|
||||
# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
|
||||
# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
|
||||
# python examples/pytorch/continuous_batching.py --attn sdpa --add-prefix --samples 10 --compare
|
||||
# python examples/pytorch/continuous_batching.py --attn flash_attention_2 -mp none --add-prefix --samples 500
|
||||
|
||||
@ -127,7 +127,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
|
||||
@ -132,7 +132,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
|
||||
@ -130,7 +130,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
|
||||
@ -128,7 +128,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
help="If passed, will use a slow tokenizer (not backed by the HuggingFace Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
|
||||
@ -151,7 +151,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
|
||||
@ -223,7 +223,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
|
||||
@ -74,6 +74,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||
def tearDownClass(cls):
|
||||
shutil.rmtree(cls.tmpdir)
|
||||
|
||||
@slow
|
||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
|
||||
def test_run_glue_no_trainer(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
@ -147,6 +148,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))
|
||||
|
||||
@slow
|
||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
|
||||
def test_run_ner_no_trainer(self):
|
||||
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
|
||||
@ -175,6 +177,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
|
||||
|
||||
@slow
|
||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
|
||||
def test_run_squad_no_trainer(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
@ -203,6 +206,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
||||
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))
|
||||
|
||||
@slow
|
||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
|
||||
def test_run_swag_no_trainer(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
@ -305,6 +309,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
|
||||
|
||||
@slow
|
||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
|
||||
def test_run_image_classification_no_trainer(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
@ -374,6 +374,7 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_bleu"], 30)
|
||||
|
||||
@slow
|
||||
def test_run_image_classification(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
@ -403,6 +404,7 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_accuracy"], 0.8)
|
||||
|
||||
@slow
|
||||
def test_run_speech_recognition_ctc(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
@ -573,6 +575,7 @@ class ExamplesTests(TestCasePlus):
|
||||
model = ViTMAEForPreTraining.from_pretrained(tmp_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_run_semantic_segmentation(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
@ -597,6 +600,7 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.1)
|
||||
|
||||
@slow
|
||||
@patch.dict(os.environ, {"WANDB_DISABLED": "true"})
|
||||
def test_run_object_detection(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
@ -624,6 +628,7 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["test_map"], 0.1)
|
||||
|
||||
@slow
|
||||
@patch.dict(os.environ, {"WANDB_DISABLED": "true"})
|
||||
def test_run_instance_segmentation(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
@ -120,7 +120,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
|
||||
@ -212,7 +212,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
|
||||
@ -50,6 +50,7 @@ checkpoint: 检查点
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.co/models"><img alt="Checkpoints on Hub" src="https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen"></a>
|
||||
<a href="https://circleci.com/gh/huggingface/transformers"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/main"></a>
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue"></a>
|
||||
<a href="https://huggingface.co/docs/transformers/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/transformers/index.svg?down_color=red&down_message=offline&up_message=online"></a>
|
||||
@ -60,7 +61,7 @@ checkpoint: 检查点
|
||||
|
||||
<h4 align="center">
|
||||
<p>
|
||||
<a href="https://github.com/huggingface/transformers/">English</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/README.md">English</a> |
|
||||
<b>简体中文</b> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_zh-hant.md">繁體中文</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ko.md">한국어</a> |
|
||||
@ -68,7 +69,7 @@ checkpoint: 检查点
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_hd.md">हिन्दी</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ru.md">Русский</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Рortuguês</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Português</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_te.md">తెలుగు</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_de.md">Deutsch</a> |
|
||||
@ -81,182 +82,258 @@ checkpoint: 检查点
|
||||
</h4>
|
||||
|
||||
<h3 align="center">
|
||||
<p>为 Jax、PyTorch 和 TensorFlow 打造的先进的自然语言处理函数库</p>
|
||||
<p>为文本、视觉、音频、视频与多模态提供推理与训练的先进预训练模型</p>
|
||||
</h3>
|
||||
|
||||
<h3 align="center">
|
||||
<a href="https://hf.co/course"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/course_banner.png"></a>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_as_a_model_definition.png"/>
|
||||
</h3>
|
||||
|
||||
🤗 Transformers 提供了数以千计的预训练模型,支持 100 多种语言的文本分类、信息抽取、问答、摘要、翻译、文本生成。它的宗旨是让最先进的 NLP 技术人人易用。
|
||||
Transformers 充当跨文本、计算机视觉、音频、视频与多模态的最先进机器学习模型的「模型定义框架」,同时覆盖推理与训练。
|
||||
|
||||
🤗 Transformers 提供了便于快速下载和使用的API,让你可以把预训练模型用在给定文本、在你的数据集上微调然后通过 [model hub](https://huggingface.co/models) 与社区共享。同时,每个定义的 Python 模块都是完全独立的,便于修改和快速进行研究实验。
|
||||
它将模型的定义集中化,使整个生态系统对该定义达成一致。`transformers` 是跨框架的枢纽:一旦某模型定义被支持,它通常就能兼容多数训练框架(如 Axolotl、Unsloth、DeepSpeed、FSDP、PyTorch‑Lightning 等)、推理引擎(如 vLLM、SGLang、TGI 等),以及依赖 `transformers` 模型定义的相关库(如 llama.cpp、mlx 等)。
|
||||
|
||||
🤗 Transformers 支持三个最热门的深度学习库: [Jax](https://jax.readthedocs.io/en/latest/), [PyTorch](https://pytorch.org/) 以及 [TensorFlow](https://www.tensorflow.org/) — 并与之无缝整合。你可以直接使用一个框架训练你的模型然后用另一个加载和推理。
|
||||
我们的目标是持续支持新的最先进模型,并通过让模型定义保持简单、可定制且高效来普及其使用。
|
||||
|
||||
## 在线演示
|
||||
|
||||
你可以直接在模型页面上测试大多数 [model hub](https://huggingface.co/models) 上的模型。 我们也提供了 [私有模型托管、模型版本管理以及推理API](https://huggingface.co/pricing)。
|
||||
|
||||
这里是一些例子:
|
||||
- [用 BERT 做掩码填词](https://huggingface.co/google-bert/bert-base-uncased?text=Paris+is+the+%5BMASK%5D+of+France)
|
||||
- [用 Electra 做命名实体识别](https://huggingface.co/dbmdz/electra-large-discriminator-finetuned-conll03-english?text=My+name+is+Sarah+and+I+live+in+London+city)
|
||||
- [用 GPT-2 做文本生成](https://huggingface.co/openai-community/gpt2?text=A+long+time+ago%2C+)
|
||||
- [用 RoBERTa 做自然语言推理](https://huggingface.co/FacebookAI/roberta-large-mnli?text=The+dog+was+lost.+Nobody+lost+any+animal)
|
||||
- [用 BART 做文本摘要](https://huggingface.co/facebook/bart-large-cnn?text=The+tower+is+324+metres+%281%2C063+ft%29+tall%2C+about+the+same+height+as+an+81-storey+building%2C+and+the+tallest+structure+in+Paris.+Its+base+is+square%2C+measuring+125+metres+%28410+ft%29+on+each+side.+During+its+construction%2C+the+Eiffel+Tower+surpassed+the+Washington+Monument+to+become+the+tallest+man-made+structure+in+the+world%2C+a+title+it+held+for+41+years+until+the+Chrysler+Building+in+New+York+City+was+finished+in+1930.+It+was+the+first+structure+to+reach+a+height+of+300+metres.+Due+to+the+addition+of+a+broadcasting+aerial+at+the+top+of+the+tower+in+1957%2C+it+is+now+taller+than+the+Chrysler+Building+by+5.2+metres+%2817+ft%29.+Excluding+transmitters%2C+the+Eiffel+Tower+is+the+second+tallest+free-standing+structure+in+France+after+the+Millau+Viaduct)
|
||||
- [用 DistilBERT 做问答](https://huggingface.co/distilbert/distilbert-base-uncased-distilled-squad?text=Which+name+is+also+used+to+describe+the+Amazon+rainforest+in+English%3F&context=The+Amazon+rainforest+%28Portuguese%3A+Floresta+Amaz%C3%B4nica+or+Amaz%C3%B4nia%3B+Spanish%3A+Selva+Amaz%C3%B3nica%2C+Amazon%C3%ADa+or+usually+Amazonia%3B+French%3A+For%C3%AAt+amazonienne%3B+Dutch%3A+Amazoneregenwoud%29%2C+also+known+in+English+as+Amazonia+or+the+Amazon+Jungle%2C+is+a+moist+broadleaf+forest+that+covers+most+of+the+Amazon+basin+of+South+America.+This+basin+encompasses+7%2C000%2C000+square+kilometres+%282%2C700%2C000+sq+mi%29%2C+of+which+5%2C500%2C000+square+kilometres+%282%2C100%2C000+sq+mi%29+are+covered+by+the+rainforest.+This+region+includes+territory+belonging+to+nine+nations.+The+majority+of+the+forest+is+contained+within+Brazil%2C+with+60%25+of+the+rainforest%2C+followed+by+Peru+with+13%25%2C+Colombia+with+10%25%2C+and+with+minor+amounts+in+Venezuela%2C+Ecuador%2C+Bolivia%2C+Guyana%2C+Suriname+and+French+Guiana.+States+or+departments+in+four+nations+contain+%22Amazonas%22+in+their+names.+The+Amazon+represents+over+half+of+the+planet%27s+remaining+rainforests%2C+and+comprises+the+largest+and+most+biodiverse+tract+of+tropical+rainforest+in+the+world%2C+with+an+estimated+390+billion+individual+trees+divided+into+16%2C000+species)
|
||||
- [用 T5 做翻译](https://huggingface.co/google-t5/t5-base?text=My+name+is+Wolfgang+and+I+live+in+Berlin)
|
||||
|
||||
**[Write With Transformer](https://transformer.huggingface.co)**,由 Hugging Face 团队打造,是一个文本生成的官方 demo。
|
||||
|
||||
## 如果你在寻找由 Hugging Face 团队提供的定制化支持服务
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/support">
|
||||
<img alt="HuggingFace Expert Acceleration Program" src="https://huggingface.co/front/thumbnails/support.png" style="max-width: 600px; border: 1px solid #eee; border-radius: 4px; box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);">
|
||||
</a><br>
|
||||
|
||||
## 快速上手
|
||||
|
||||
我们为快速使用模型提供了 `pipeline` API。Pipeline 聚合了预训练模型和对应的文本预处理。下面是一个快速使用 pipeline 去判断正负面情绪的例子:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# 使用情绪分析 pipeline
|
||||
>>> classifier = pipeline('sentiment-analysis')
|
||||
>>> classifier('We are very happy to introduce pipeline to the transformers repository.')
|
||||
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]
|
||||
```
|
||||
|
||||
第二行代码下载并缓存了 pipeline 使用的预训练模型,而第三行代码则在给定的文本上进行了评估。这里的答案"正面" (positive) 具有 99 的置信度。
|
||||
|
||||
许多的 NLP 任务都有开箱即用的预训练 `pipeline`。比如说,我们可以轻松的从给定文本中抽取问题答案:
|
||||
|
||||
``` python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# 使用问答 pipeline
|
||||
>>> question_answerer = pipeline('question-answering')
|
||||
>>> question_answerer({
|
||||
... 'question': 'What is the name of the repository ?',
|
||||
... 'context': 'Pipeline has been included in the huggingface/transformers repository'
|
||||
... })
|
||||
{'score': 0.30970096588134766, 'start': 34, 'end': 58, 'answer': 'huggingface/transformers'}
|
||||
|
||||
```
|
||||
|
||||
除了给出答案,预训练模型还给出了对应的置信度分数、答案在词符化 (tokenized) 后的文本中开始和结束的位置。你可以从[这个教程](https://huggingface.co/docs/transformers/task_summary)了解更多 `pipeline` API 支持的任务。
|
||||
|
||||
要在你的任务上下载和使用任意预训练模型也很简单,只需三行代码。这里是 PyTorch 版的示例:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
```
|
||||
这里是等效的 TensorFlow 代码:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, TFAutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="tf")
|
||||
>>> outputs = model(**inputs)
|
||||
```
|
||||
|
||||
词符化器 (tokenizer) 为所有的预训练模型提供了预处理,并可以直接对单个字符串进行调用(比如上面的例子)或对列表 (list) 调用。它会输出一个你可以在下游代码里使用或直接通过 `**` 解包表达式传给模型的词典 (dict)。
|
||||
|
||||
模型本身是一个常规的 [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) 或 [TensorFlow `tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model)(取决于你的后端),可以常规方式使用。 [这个教程](https://huggingface.co/transformers/training.html)解释了如何将这样的模型整合到经典的 PyTorch 或 TensorFlow 训练循环中,或是如何使用我们的 `Trainer` (训练器)API 来在一个新的数据集上快速微调。
|
||||
|
||||
## 为什么要用 transformers?
|
||||
|
||||
1. 便于使用的先进模型:
|
||||
- NLU 和 NLG 上表现优越
|
||||
- 对教学和实践友好且低门槛
|
||||
- 高级抽象,只需了解三个类
|
||||
- 对所有模型统一的API
|
||||
|
||||
1. 更低计算开销,更少的碳排放:
|
||||
- 研究人员可以分享已训练的模型而非每次从头开始训练
|
||||
- 工程师可以减少计算用时和生产环境开销
|
||||
- 数十种模型架构、两千多个预训练模型、100多种语言支持
|
||||
|
||||
1. 对于模型生命周期的每一个部分都面面俱到:
|
||||
- 训练先进的模型,只需 3 行代码
|
||||
- 模型在不同深度学习框架间任意转移,随你心意
|
||||
- 为训练、评估和生产选择最适合的框架,衔接无缝
|
||||
|
||||
1. 为你的需求轻松定制专属模型和用例:
|
||||
- 我们为每种模型架构提供了多个用例来复现原论文结果
|
||||
- 模型内部结构保持透明一致
|
||||
- 模型文件可单独使用,方便修改和快速实验
|
||||
|
||||
## 什么情况下我不该用 transformers?
|
||||
|
||||
- 本库并不是模块化的神经网络工具箱。模型文件中的代码特意呈若璞玉,未经额外抽象封装,以便研究人员快速迭代修改而不致溺于抽象和文件跳转之中。
|
||||
- `Trainer` API 并非兼容任何模型,只为本库之模型优化。若是在寻找适用于通用机器学习的训练循环实现,请另觅他库。
|
||||
- 尽管我们已尽力而为,[examples 目录](https://github.com/huggingface/transformers/tree/main/examples)中的脚本也仅为用例而已。对于你的特定问题,它们并不一定开箱即用,可能需要改几行代码以适之。
|
||||
目前在 [Hugging Face Hub](https://huggingface.com/models) 上有超过 1M+ 使用 `transformers` 的[模型检查点](https://huggingface.co/models?library=transformers&sort=trending),可随取随用。
|
||||
|
||||
今天就去探索 Hub,找到一个模型,并用 Transformers 立刻开始吧。
|
||||
|
||||
## 安装
|
||||
|
||||
### 使用 pip
|
||||
Transformers 支持 Python 3.9+,以及 [PyTorch](https://pytorch.org/get-started/locally/) 2.1+。
|
||||
|
||||
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下经过测试。
|
||||
使用 [venv](https://docs.python.org/3/library/venv.html) 或 [uv](https://docs.astral.sh/uv/)(一个基于 Rust 的快速 Python 包与项目管理器)创建并激活虚拟环境:
|
||||
|
||||
你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
||||
|
||||
首先,用你打算使用的版本的 Python 创建一个虚拟环境并激活。
|
||||
|
||||
然后,你需要安装 Flax、PyTorch 或 TensorFlow 其中之一。关于在你使用的平台上安装这些框架,请参阅 [TensorFlow 安装页](https://www.tensorflow.org/install/), [PyTorch 安装页](https://pytorch.org/get-started/locally/#start-locally) 或 [Flax 安装页](https://github.com/google/flax#quick-install)。
|
||||
|
||||
当这些后端之一安装成功后, 🤗 Transformers 可依此安装:
|
||||
|
||||
```bash
|
||||
pip install transformers
|
||||
```py
|
||||
# venv
|
||||
python -m venv .my-env
|
||||
source .my-env/bin/activate
|
||||
# uv
|
||||
uv venv .my-env
|
||||
source .my-env/bin/activate
|
||||
```
|
||||
|
||||
如果你想要试试用例或者想在正式发布前使用最新的开发中代码,你得[从源代码安装](https://huggingface.co/docs/transformers/installation#installing-from-source)。
|
||||
在虚拟环境中安装 Transformers:
|
||||
|
||||
### 使用 conda
|
||||
```py
|
||||
# pip
|
||||
pip install "transformers[torch]"
|
||||
|
||||
🤗 Transformers 可以通过 conda 依此安装:
|
||||
|
||||
```shell script
|
||||
conda install conda-forge::transformers
|
||||
# uv
|
||||
uv pip install "transformers[torch]"
|
||||
```
|
||||
|
||||
> **_笔记:_** 从 `huggingface` 渠道安装 `transformers` 已被废弃。
|
||||
如果你需要库中的最新改动或计划参与贡献,可从源码安装(注意:最新版可能不稳定;如遇错误,欢迎在 [issues](https://github.com/huggingface/transformers/issues) 中反馈):
|
||||
|
||||
要通过 conda 安装 Flax、PyTorch 或 TensorFlow 其中之一,请参阅它们各自安装页的说明。
|
||||
```shell
|
||||
git clone https://github.com/huggingface/transformers.git
|
||||
cd transformers
|
||||
|
||||
## 模型架构
|
||||
# pip
|
||||
pip install '.[torch]'
|
||||
|
||||
🤗 Transformers 支持的[**所有的模型检查点**](https://huggingface.co/models)由[用户](https://huggingface.co/users)和[组织](https://huggingface.co/organizations)上传,均与 huggingface.co [model hub](https://huggingface.co) 无缝整合。
|
||||
# uv
|
||||
uv pip install '.[torch]'
|
||||
```
|
||||
|
||||
目前的检查点数量: 
|
||||
## 快速上手
|
||||
|
||||
🤗 Transformers 目前支持如下的架构: 模型概述请阅[这里](https://huggingface.co/docs/transformers/model_summary).
|
||||
使用 [Pipeline](https://huggingface.co/docs/transformers/pipeline_tutorial) API 一步上手。`Pipeline` 是一个高级推理类,支持文本、音频、视觉与多模态任务,负责输入预处理并返回适配的输出。
|
||||
|
||||
要检查某个模型是否已有 Flax、PyTorch 或 TensorFlow 的实现,或其是否在 🤗 Tokenizers 库中有对应词符化器(tokenizer),敬请参阅[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
|
||||
实例化一个用于文本生成的 pipeline,指定使用的模型。模型会被下载并缓存,方便复用。最后传入文本作为提示:
|
||||
|
||||
这些实现均已于多个数据集测试(请参看用例脚本)并应于原版实现表现相当。你可以在用例文档的[此节](https://huggingface.co/docs/transformers/examples)中了解表现的细节。
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="Qwen/Qwen2.5-1.5B")
|
||||
pipeline("the secret to baking a really good cake is ")
|
||||
[{'generated_text': 'the secret to baking a really good cake is 1) to use the right ingredients and 2) to follow the recipe exactly. the recipe for the cake is as follows: 1 cup of sugar, 1 cup of flour, 1 cup of milk, 1 cup of butter, 1 cup of eggs, 1 cup of chocolate chips. if you want to make 2 cakes, how much sugar do you need? To make 2 cakes, you will need 2 cups of sugar.'}]
|
||||
```
|
||||
|
||||
要与模型进行「聊天」,用法也一致。唯一不同是需要构造一段「聊天历史」(即 `Pipeline` 的输入):
|
||||
|
||||
> [!TIP]
|
||||
> 你也可以直接在命令行与模型聊天:
|
||||
> ```shell
|
||||
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
> ```
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."},
|
||||
{"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"}
|
||||
]
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", dtype=torch.bfloat16, device_map="auto")
|
||||
response = pipeline(chat, max_new_tokens=512)
|
||||
print(response[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
展开下方示例,查看 `Pipeline` 在不同模态与任务中的用法。
|
||||
|
||||
<details>
|
||||
<summary>自动语音识别</summary>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="automatic-speech-recognition", model="openai/whisper-large-v3")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
|
||||
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>图像分类</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="image-classification", model="facebook/dinov2-small-imagenet1k-1-layer")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
|
||||
[{"label": "macaw", "score": 0.997848391532898},
|
||||
{"label": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
|
||||
"score": 0.0016551691805943847},
|
||||
{"label": "lorikeet", "score": 0.00018523589824326336},
|
||||
{"label": "African grey, African gray, Psittacus erithacus",
|
||||
"score": 7.85409429227002e-05},
|
||||
{"label": "quail", "score": 5.502637941390276e-05}]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>视觉问答</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="visual-question-answering", model="Salesforce/blip-vqa-base")
|
||||
pipeline(
|
||||
image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg",
|
||||
question="What is in the image?",
|
||||
)
|
||||
[{"answer": "statue of liberty"}]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 为什么要用 Transformers?
|
||||
|
||||
1. 易于使用的最先进模型:
|
||||
- 在自然语言理解与生成、计算机视觉、音频、视频与多模态任务上表现优越。
|
||||
- 对研究者、工程师与开发者友好且低门槛。
|
||||
- 少量用户侧抽象,仅需学习三个类。
|
||||
- 统一的 API,使用所有预训练模型体验一致。
|
||||
|
||||
1. 更低计算开销与更小碳足迹:
|
||||
- 共享已训练的模型,而非每次从零开始训练。
|
||||
- 减少计算时间与生产环境成本。
|
||||
- 覆盖数十种模型架构,跨所有模态提供 1M+ 预训练检查点。
|
||||
|
||||
1. 在模型生命周期的每个阶段都可以选用合适的框架:
|
||||
- 3 行代码即可训练最先进模型。
|
||||
- 在 PyTorch/JAX/TF2.0 间自由迁移同一个模型。
|
||||
- 为训练、评估与生产挑选最合适的框架。
|
||||
|
||||
1. 轻松定制模型或用例:
|
||||
- 为每个架构提供示例以复现原论文结果。
|
||||
- 尽可能一致地暴露模型内部。
|
||||
- 模型文件可独立于库使用,便于快速实验。
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/enterprise">
|
||||
<img alt="Hugging Face Enterprise Hub" src="https://github.com/user-attachments/assets/247fb16d-d251-4583-96c4-d3d76dda4925">
|
||||
</a><br>
|
||||
|
||||
## 为什么我不该用 Transformers?
|
||||
|
||||
- 该库不是一个可自由拼搭的神经网络模块化工具箱。模型文件中的代码刻意减少额外抽象,以便研究者能快速在各个模型上迭代,而无需深入更多抽象或文件跳转。
|
||||
- 训练 API 优化用于 Transformers 提供的 PyTorch 模型。若需要通用的机器学习训练循环,请使用其它库,如 [Accelerate](https://huggingface.co/docs/accelerate)。
|
||||
- [示例脚本](https://github.com/huggingface/transformers/tree/main/examples)只是「示例」。它们不一定能直接适配你的具体用例,需要你进行必要的改动。
|
||||
|
||||
|
||||
## 了解更多
|
||||
## 100 个使用 Transformers 的项目
|
||||
|
||||
| 章节 | 描述 |
|
||||
|-|-|
|
||||
| [文档](https://huggingface.co/docs/transformers/) | 完整的 API 文档和教程 |
|
||||
| [任务总结](https://huggingface.co/docs/transformers/task_summary) | 🤗 Transformers 支持的任务 |
|
||||
| [预处理教程](https://huggingface.co/docs/transformers/preprocessing) | 使用 `Tokenizer` 来为模型准备数据 |
|
||||
| [训练和微调](https://huggingface.co/docs/transformers/training) | 在 PyTorch/TensorFlow 的训练循环或 `Trainer` API 中使用 🤗 Transformers 提供的模型 |
|
||||
| [快速上手:微调和用例脚本](https://github.com/huggingface/transformers/tree/main/examples) | 为各种任务提供的用例脚本 |
|
||||
| [模型分享和上传](https://huggingface.co/docs/transformers/model_sharing) | 和社区上传和分享你微调的模型 |
|
||||
| [迁移](https://huggingface.co/docs/transformers/migration) | 从 `pytorch-transformers` 或 `pytorch-pretrained-bert` 迁移到 🤗 Transformers |
|
||||
Transformers 不止是一个使用预训练模型的工具包,它还是围绕 Hugging Face Hub 构建的项目社区。我们希望 Transformers 能助力开发者、研究人员、学生、老师、工程师与任何人构建理想项目。
|
||||
|
||||
为庆祝 Transformers 获得 100,000 颗星,我们制作了 [awesome-transformers](./awesome-transformers.md) 页面,展示了 100 个由社区构建的优秀项目。
|
||||
|
||||
如果你拥有或使用某个项目,认为它应该在列表中出现,欢迎提交 PR 添加它!
|
||||
|
||||
## 示例模型
|
||||
|
||||
你可以直接在它们的 [Hub 模型页](https://huggingface.co/models) 上测试我们的多数模型。
|
||||
|
||||
展开每个模态以查看不同用例中的部分示例模型。
|
||||
|
||||
<details>
|
||||
<summary>音频</summary>
|
||||
|
||||
- 使用 [Whisper](https://huggingface.co/openai/whisper-large-v3-turbo) 进行音频分类
|
||||
- 使用 [Moonshine](https://huggingface.co/UsefulSensors/moonshine) 进行自动语音识别
|
||||
- 使用 [Wav2Vec2](https://huggingface.co/superb/wav2vec2-base-superb-ks) 进行关键词检索
|
||||
- 使用 [Moshi](https://huggingface.co/kyutai/moshiko-pytorch-bf16) 进行语音到语音生成
|
||||
- 使用 [MusicGen](https://huggingface.co/facebook/musicgen-large) 文本到音频生成
|
||||
- 使用 [Bark](https://huggingface.co/suno/bark) 文本到语音生成
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>计算机视觉</summary>
|
||||
|
||||
- 使用 [SAM](https://huggingface.co/facebook/sam-vit-base) 自动生成掩码
|
||||
- 使用 [DepthPro](https://huggingface.co/apple/DepthPro-hf) 进行深度估计
|
||||
- 使用 [DINO v2](https://huggingface.co/facebook/dinov2-base) 进行图像分类
|
||||
- 使用 [SuperPoint](https://huggingface.co/magic-leap-community/superpoint) 进行关键点检测
|
||||
- 使用 [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor) 进行关键点匹配
|
||||
- 使用 [RT-DETRv2](https://huggingface.co/PekingU/rtdetr_v2_r50vd) 进行目标检测
|
||||
- 使用 [VitPose](https://huggingface.co/usyd-community/vitpose-base-simple) 进行姿态估计
|
||||
- 使用 [OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_swin_large) 进行通用分割
|
||||
- 使用 [VideoMAE](https://huggingface.co/MCG-NJU/videomae-large) 进行视频分类
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>多模态</summary>
|
||||
|
||||
- 使用 [Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B) 实现音频或文本到文本
|
||||
- 使用 [LayoutLMv3](https://huggingface.co/microsoft/layoutlmv3-base) 进行文档问答
|
||||
- 使用 [Qwen-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) 实现图像或文本到文本
|
||||
- 使用 [BLIP-2](https://huggingface.co/Salesforce/blip2-opt-2.7b) 进行图文描述
|
||||
- 使用 [GOT-OCR2](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf) 进行基于 OCR 的文档理解
|
||||
- 使用 [TAPAS](https://huggingface.co/google/tapas-base) 进行表格问答
|
||||
- 使用 [Emu3](https://huggingface.co/BAAI/Emu3-Gen) 进行统一的多模态理解与生成
|
||||
- 使用 [Llava-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) 视觉到文本
|
||||
- 使用 [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) 进行视觉问答
|
||||
- 使用 [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224) 进行视觉指代表达分割
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>NLP</summary>
|
||||
|
||||
- 使用 [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base) 进行掩码词填充
|
||||
- 使用 [Gemma](https://huggingface.co/google/gemma-2-2b) 进行命名实体识别(NER)
|
||||
- 使用 [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) 进行问答
|
||||
- 使用 [BART](https://huggingface.co/facebook/bart-large-cnn) 进行摘要
|
||||
- 使用 [T5](https://huggingface.co/google-t5/t5-base) 进行翻译
|
||||
- 使用 [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B) 进行文本生成
|
||||
- 使用 [Qwen](https://huggingface.co/Qwen/Qwen2.5-0.5B) 进行文本分类
|
||||
|
||||
</details>
|
||||
|
||||
## 引用
|
||||
|
||||
|
||||
@ -14,43 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
<!---
|
||||
A useful guide for English-Traditional Chinese translation of Hugging Face documentation
|
||||
- Add space around English words and numbers when they appear between Chinese characters. E.g., 共 100 多種語言; 使用 transformers 函式庫。
|
||||
- Use square quotes, e.g.,「引用」
|
||||
- Some of terms in the file can be found at National Academy for Educational Research (https://terms.naer.edu.tw/), an official website providing bilingual translations between English and Traditional Chinese.
|
||||
|
||||
Dictionary
|
||||
|
||||
API: API (不翻譯)
|
||||
add: 加入
|
||||
checkpoint: 檢查點
|
||||
code: 程式碼
|
||||
community: 社群
|
||||
confidence: 信賴度
|
||||
dataset: 資料集
|
||||
documentation: 文件
|
||||
example: 基本翻譯為「範例」,或依語意翻為「例子」
|
||||
finetune: 微調
|
||||
Hugging Face: Hugging Face(不翻譯)
|
||||
implementation: 實作
|
||||
inference: 推論
|
||||
library: 函式庫
|
||||
module: 模組
|
||||
NLP/Natural Language Processing: 以 NLP 出現時不翻譯,以 Natural Language Processing 出現時翻譯為自然語言處理
|
||||
online demos: 線上Demo
|
||||
pipeline: pipeline(不翻譯)
|
||||
pretrained/pretrain: 預訓練
|
||||
Python data structures (e.g., list, set, dict): 翻譯為串列,集合,字典,並用括號標註原英文
|
||||
repository: repository(不翻譯)
|
||||
summary: 概覽
|
||||
token-: token-(不翻譯)
|
||||
Trainer: Trainer(不翻譯)
|
||||
transformer: transformer(不翻譯)
|
||||
tutorial: 教學
|
||||
user: 使用者
|
||||
-->
|
||||
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/transformers-logo-dark.svg">
|
||||
@ -62,6 +25,7 @@ user: 使用者
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.com/models"><img alt="Checkpoints on Hub" src="https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen"></a>
|
||||
<a href="https://circleci.com/gh/huggingface/transformers"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/main"></a>
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue"></a>
|
||||
<a href="https://huggingface.co/docs/transformers/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/transformers/index.svg?down_color=red&down_message=offline&up_message=online"></a>
|
||||
@ -72,7 +36,7 @@ user: 使用者
|
||||
|
||||
<h4 align="center">
|
||||
<p>
|
||||
<a href="https://github.com/huggingface/transformers/">English</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/README.md">English</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_zh-hans.md">简体中文</a> |
|
||||
<b>繁體中文</b> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ko.md">한국어</a> |
|
||||
@ -80,7 +44,7 @@ user: 使用者
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_hd.md">हिन्दी</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ru.md">Русский</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Рortuguês</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Português</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_te.md">తెలుగు</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_de.md">Deutsch</a> |
|
||||
@ -93,186 +57,261 @@ user: 使用者
|
||||
</h4>
|
||||
|
||||
<h3 align="center">
|
||||
<p>為 Jax、PyTorch 以及 TensorFlow 打造的先進自然語言處理函式庫</p>
|
||||
<p>最先進的預訓練模型,專為推理與訓練而生</p>
|
||||
</h3>
|
||||
|
||||
<h3 align="center">
|
||||
<a href="https://hf.co/course"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/course_banner.png"></a>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_as_a_model_definition.png"/>
|
||||
</h3>
|
||||
|
||||
🤗 Transformers 提供了數以千計的預訓練模型,支援 100 多種語言的文本分類、資訊擷取、問答、摘要、翻譯、文本生成。它的宗旨是讓最先進的 NLP 技術人人易用。
|
||||
Transformers 是一個為最先進的機器學習模型(涵蓋文字、電腦視覺、音訊、影片及多模態)提供推理和訓練支援的模型定義框架。
|
||||
|
||||
🤗 Transformers 提供了便於快速下載和使用的API,讓你可以將預訓練模型用在給定文本、在你的資料集上微調然後經由 [model hub](https://huggingface.co/models) 與社群共享。同時,每個定義的 Python 模組架構均完全獨立,方便修改和快速研究實驗。
|
||||
它將模型定義集中化,使得該定義在整個生態系中能夠達成共識。`transformers` 是貫穿各個框架的樞紐:如果一個模型定義受到支援,它將與大多數訓練框架(如 Axolotl、Unsloth、DeepSpeed、FSDP、PyTorch-Lightning 等)、推理引擎(如 vLLM、SGLang、TGI 等)以及利用 `transformers` 模型定義的周邊建模函式庫(如 llama.cpp、mlx 等)相容。
|
||||
|
||||
🤗 Transformers 支援三個最熱門的深度學習函式庫: [Jax](https://jax.readthedocs.io/en/latest/), [PyTorch](https://pytorch.org/) 以及 [TensorFlow](https://www.tensorflow.org/) — 並與之完美整合。你可以直接使用其中一個框架訓練你的模型,然後用另一個載入和推論。
|
||||
我們致力於支援最新的頂尖模型,並透過使其模型定義變得簡單、可客製化且高效,來普及它們的應用。
|
||||
|
||||
## 線上Demo
|
||||
在 [Hugging Face Hub](https://huggingface.com/models) 上,有超過 100 萬個 Transformers [模型檢查點](https://huggingface.co/models?library=transformers&sort=trending) 供您使用。
|
||||
|
||||
你可以直接在 [model hub](https://huggingface.co/models) 上測試大多數的模型。我們也提供了 [私有模型託管、模型版本管理以及推論API](https://huggingface.co/pricing)。
|
||||
|
||||
這裡是一些範例:
|
||||
- [用 BERT 做遮蓋填詞](https://huggingface.co/google-bert/bert-base-uncased?text=Paris+is+the+%5BMASK%5D+of+France)
|
||||
- [用 Electra 做專有名詞辨識](https://huggingface.co/dbmdz/electra-large-discriminator-finetuned-conll03-english?text=My+name+is+Sarah+and+I+live+in+London+city)
|
||||
- [用 GPT-2 做文本生成](https://huggingface.co/openai-community/gpt2?text=A+long+time+ago%2C+)
|
||||
- [用 RoBERTa 做自然語言推論](https://huggingface.co/FacebookAI/roberta-large-mnli?text=The+dog+was+lost.+Nobody+lost+any+animal)
|
||||
- [用 BART 做文本摘要](https://huggingface.co/facebook/bart-large-cnn?text=The+tower+is+324+metres+%281%2C063+ft%29+tall%2C+about+the+same+height+as+an+81-storey+building%2C+and+the+tallest+structure+in+Paris.+Its+base+is+square%2C+measuring+125+metres+%28410+ft%29+on+each+side.+During+its+construction%2C+the+Eiffel+Tower+surpassed+the+Washington+Monument+to+become+the+tallest+man-made+structure+in+the+world%2C+a+title+it+held+for+41+years+until+the+Chrysler+Building+in+New+York+City+was+finished+in+1930.+It+was+the+first+structure+to+reach+a+height+of+300+metres.+Due+to+the+addition+of+a+broadcasting+aerial+at+the+top+of+the+tower+in+1957%2C+it+is+now+taller+than+the+Chrysler+Building+by+5.2+metres+%2817+ft%29.+Excluding+transmitters%2C+the+Eiffel+Tower+is+the+second+tallest+free-standing+structure+in+France+after+the+Millau+Viaduct)
|
||||
- [用 DistilBERT 做問答](https://huggingface.co/distilbert/distilbert-base-uncased-distilled-squad?text=Which+name+is+also+used+to+describe+the+Amazon+rainforest+in+English%3F&context=The+Amazon+rainforest+%28Portuguese%3A+Floresta+Amaz%C3%B4nica+or+Amaz%C3%B4nia%3B+Spanish%3A+Selva+Amaz%C3%B3nica%2C+Amazon%C3%ADa+or+usually+Amazonia%3B+French%3A+For%C3%AAt+amazonienne%3B+Dutch%3A+Amazoneregenwoud%29%2C+also+known+in+English+as+Amazonia+or+the+Amazon+Jungle%2C+is+a+moist+broadleaf+forest+that+covers+most+of+the+Amazon+basin+of+South+America.+This+basin+encompasses+7%2C000%2C000+square+kilometres+%282%2C700%2C000+sq+mi%29%2C+of+which+5%2C500%2C000+square+kilometres+%282%2C100%2C000+sq+mi%29+are+covered+by+the+rainforest.+This+region+includes+territory+belonging+to+nine+nations.+The+majority+of+the+forest+is+contained+within+Brazil%2C+with+60%25+of+the+rainforest%2C+followed+by+Peru+with+13%25%2C+Colombia+with+10%25%2C+and+with+minor+amounts+in+Venezuela%2C+Ecuador%2C+Bolivia%2C+Guyana%2C+Suriname+and+French+Guiana.+States+or+departments+in+four+nations+contain+%22Amazonas%22+in+their+names.+The+Amazon+represents+over+half+of+the+planet%27s+remaining+rainforests%2C+and+comprises+the+largest+and+most+biodiverse+tract+of+tropical+rainforest+in+the+world%2C+with+an+estimated+390+billion+individual+trees+divided+into+16%2C000+species)
|
||||
- [用 T5 做翻譯](https://huggingface.co/google-t5/t5-base?text=My+name+is+Wolfgang+and+I+live+in+Berlin)
|
||||
|
||||
**[Write With Transformer](https://transformer.huggingface.co)**,由 Hugging Face 團隊所打造,是一個文本生成的官方 demo。
|
||||
|
||||
## 如果你在尋找由 Hugging Face 團隊所提供的客製化支援服務
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/support">
|
||||
<img alt="HuggingFace Expert Acceleration Program" src="https://huggingface.co/front/thumbnails/support.png" style="max-width: 600px; border: 1px solid #eee; border-radius: 4px; box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);">
|
||||
</a><br>
|
||||
|
||||
## 快速上手
|
||||
|
||||
我們為快速使用模型提供了 `pipeline` API。 Pipeline 包含了預訓練模型和對應的文本預處理。下面是一個快速使用 pipeline 去判斷正負面情緒的例子:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# 使用情緒分析 pipeline
|
||||
>>> classifier = pipeline('sentiment-analysis')
|
||||
>>> classifier('We are very happy to introduce pipeline to the transformers repository.')
|
||||
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]
|
||||
```
|
||||
|
||||
第二行程式碼下載並快取 pipeline 使用的預訓練模型,而第三行程式碼則在給定的文本上進行了評估。這裡的答案“正面” (positive) 具有 99.97% 的信賴度。
|
||||
|
||||
許多的 NLP 任務都有隨選即用的預訓練 `pipeline`。例如,我們可以輕鬆地從給定文本中擷取問題答案:
|
||||
|
||||
``` python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# 使用問答 pipeline
|
||||
>>> question_answerer = pipeline('question-answering')
|
||||
>>> question_answerer({
|
||||
... 'question': 'What is the name of the repository ?',
|
||||
... 'context': 'Pipeline has been included in the huggingface/transformers repository'
|
||||
... })
|
||||
{'score': 0.30970096588134766, 'start': 34, 'end': 58, 'answer': 'huggingface/transformers'}
|
||||
|
||||
```
|
||||
|
||||
除了提供問題解答,預訓練模型還提供了對應的信賴度分數以及解答在 tokenized 後的文本中開始和結束的位置。你可以從[這個教學](https://huggingface.co/docs/transformers/task_summary)了解更多 `pipeline` API支援的任務。
|
||||
|
||||
要在你的任務中下載和使用任何預訓練模型很簡單,只需三行程式碼。這裡是 PyTorch 版的範例:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
```
|
||||
這裡是對應的 TensorFlow 程式碼:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, TFAutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="tf")
|
||||
>>> outputs = model(**inputs)
|
||||
```
|
||||
|
||||
Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換單一字串(比如上面的例子)或串列 (list)。它會輸出一個的字典 (dict) 讓你可以在下游程式碼裡使用或直接藉由 `**` 運算式傳給模型。
|
||||
|
||||
模型本身是一個常規的 [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) 或 [TensorFlow `tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model)(取決於你的後端),可依常規方式使用。 [這個教學](https://huggingface.co/transformers/training.html)解釋了如何將這樣的模型整合到一般的 PyTorch 或 TensorFlow 訓練迴圈中,或是如何使用我們的 `Trainer` API 在一個新的資料集上快速進行微調。
|
||||
|
||||
## 為什麼要用 transformers?
|
||||
|
||||
1. 便於使用的先進模型:
|
||||
- NLU 和 NLG 上性能卓越
|
||||
- 對教學和實作友好且低門檻
|
||||
- 高度抽象,使用者只須學習 3 個類別
|
||||
- 對所有模型使用的制式化API
|
||||
|
||||
1. 更低的運算成本,更少的碳排放:
|
||||
- 研究人員可以分享已訓練的模型而非每次從頭開始訓練
|
||||
- 工程師可以減少計算時間以及生產成本
|
||||
- 數十種模型架構、兩千多個預訓練模型、100多種語言支援
|
||||
|
||||
1. 對於模型生命週期的每一個部分都面面俱到:
|
||||
- 訓練先進的模型,只需 3 行程式碼
|
||||
- 模型可以在不同深度學習框架之間任意轉換
|
||||
- 為訓練、評估和生產選擇最適合的框架,並完美銜接
|
||||
|
||||
1. 為你的需求輕鬆客製化專屬模型和範例:
|
||||
- 我們為每種模型架構提供了多個範例來重現原論文結果
|
||||
- 一致的模型內部架構
|
||||
- 模型檔案可單獨使用,便於修改和快速實驗
|
||||
|
||||
## 什麼情況下我不該用 transformers?
|
||||
|
||||
- 本函式庫並不是模組化的神經網絡工具箱。模型文件中的程式碼並未做額外的抽象封裝,以便研究人員快速地翻閱及修改程式碼,而不會深陷複雜的類別包裝之中。
|
||||
- `Trainer` API 並非相容任何模型,它只為本函式庫中的模型最佳化。對於一般的機器學習用途,請使用其他函式庫。
|
||||
- 儘管我們已盡力而為,[examples 目錄](https://github.com/huggingface/transformers/tree/main/examples)中的腳本也僅為範例而已。對於特定問題,它們並不一定隨選即用,可能需要修改幾行程式碼以符合需求。
|
||||
立即探索 [Hub](https://huggingface.com/),尋找合適的模型,並使用 Transformers 幫助您快速上手。
|
||||
|
||||
## 安裝
|
||||
|
||||
### 使用 pip
|
||||
Transformers 支援 Python 3.9+ 和 [PyTorch](https://pytorch.org/get-started/locally/) 2.1+。
|
||||
|
||||
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下經過測試。
|
||||
使用 [venv](https://docs.python.org/3/library/venv.html) 或基於 Rust 的高速 Python 套件及專案管理器 [uv](https://docs.astral.sh/uv/) 來建立並啟用虛擬環境。
|
||||
|
||||
你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
||||
|
||||
首先,用你打算使用的版本的 Python 創建一個虛擬環境並進入。
|
||||
|
||||
然後,你需要安裝 Flax、PyTorch 或 TensorFlow 其中之一。對於該如何在你使用的平台上安裝這些框架,請參閱 [TensorFlow 安裝頁面](https://www.tensorflow.org/install/), [PyTorch 安裝頁面](https://pytorch.org/get-started/locally/#start-locally) 或 [Flax 安裝頁面](https://github.com/google/flax#quick-install)。
|
||||
|
||||
當其中一個後端安裝成功後,🤗 Transformers 可依此安裝:
|
||||
|
||||
```bash
|
||||
pip install transformers
|
||||
```py
|
||||
# venv
|
||||
python -m venv .my-env
|
||||
source .my-env/bin/activate
|
||||
# uv
|
||||
uv venv .my-env
|
||||
source .my-env/bin/activate
|
||||
```
|
||||
|
||||
如果你想要試試範例或者想在正式發布前使用最新開發中的程式碼,你必須[從原始碼安裝](https://huggingface.co/docs/transformers/installation#installing-from-source)。
|
||||
在您的虛擬環境中安裝 Transformers。
|
||||
|
||||
### 使用 conda
|
||||
```py
|
||||
# pip
|
||||
pip install "transformers[torch]"
|
||||
|
||||
🤗 Transformers 可以藉由 conda 依此安裝:
|
||||
|
||||
```shell script
|
||||
conda install conda-forge::transformers
|
||||
# uv
|
||||
uv pip install "transformers[torch]"
|
||||
```
|
||||
|
||||
> **_筆記:_** 從 `huggingface` 頻道安裝 `transformers` 已被淘汰。
|
||||
如果您想使用函式庫的最新變更或有興趣參與貢獻,可以從原始碼安裝 Transformers。然而,*最新*版本可能不穩定。如果您遇到任何錯誤,歡迎隨時提交一個 [issue](https://github.com/huggingface/transformers/issues)。
|
||||
|
||||
要藉由 conda 安裝 Flax、PyTorch 或 TensorFlow 其中之一,請參閱它們各自安裝頁面的說明。
|
||||
```shell
|
||||
git clone https://github.com/huggingface/transformers.git
|
||||
cd transformers
|
||||
|
||||
## 模型架構
|
||||
# pip
|
||||
pip install '.[torch]'
|
||||
|
||||
**🤗 Transformers 支援的[所有的模型檢查點](https://huggingface.co/models)**,由[使用者](https://huggingface.co/users)和[組織](https://huggingface.co/organizations)上傳,均與 huggingface.co [model hub](https://huggingface.co) 完美結合。
|
||||
# uv
|
||||
uv pip install '.[torch]'
|
||||
```
|
||||
|
||||
目前的檢查點數量: 
|
||||
## 快速入門
|
||||
|
||||
🤗 Transformers 目前支援以下的架構: 模型概覽請參閱[這裡](https://huggingface.co/docs/transformers/model_summary).
|
||||
透過 [Pipeline](https://huggingface.co/docs/transformers/pipeline_tutorial) API 快速開始使用 Transformers。`Pipeline` 是一個高階的推理類別,支援文字、音訊、視覺和多模態任務。它負責處理輸入資料的預處理,並回傳適當的輸出。
|
||||
|
||||
要檢查某個模型是否已有 Flax、PyTorch 或 TensorFlow 的實作,或其是否在🤗 Tokenizers 函式庫中有對應的 tokenizer,敬請參閱[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
|
||||
實例化一個 pipeline 並指定用於文字生成的模型。該模型會被下載並快取,方便您之後輕鬆複用。最後,傳入一些文字來提示模型。
|
||||
|
||||
這些實作均已於多個資料集測試(請參閱範例腳本)並應與原版實作表現相當。你可以在範例文件的[此節](https://huggingface.co/docs/transformers/examples)中了解實作的細節。
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="Qwen/Qwen2.5-1.5B")
|
||||
pipeline("the secret to baking a really good cake is ")
|
||||
[{'generated_text': 'the secret to baking a really good cake is 1) to use the right ingredients and 2) to follow the recipe exactly. the recipe for the cake is as follows: 1 cup of sugar, 1 cup of flour, 1 cup of milk, 1 cup of butter, 1 cup of eggs, 1 cup of chocolate chips. if you want to make 2 cakes, how much sugar do you need? To make 2 cakes, you will need 2 cups of sugar.'}]
|
||||
```
|
||||
|
||||
## 了解更多
|
||||
與模型進行聊天,使用模式是相同的。唯一的區別是您需要建構一個您與系統之間的聊天歷史(作為 `Pipeline` 的輸入)。
|
||||
|
||||
| 章節 | 描述 |
|
||||
|-|-|
|
||||
| [文件](https://huggingface.co/transformers/) | 完整的 API 文件和教學 |
|
||||
| [任務概覽](https://huggingface.co/docs/transformers/task_summary) | 🤗 Transformers 支援的任務 |
|
||||
| [預處理教學](https://huggingface.co/docs/transformers/preprocessing) | 使用 `Tokenizer` 來為模型準備資料 |
|
||||
| [訓練和微調](https://huggingface.co/docs/transformers/training) | 使用 PyTorch/TensorFlow 的內建的訓練方式或於 `Trainer` API 中使用 🤗 Transformers 提供的模型 |
|
||||
| [快速上手:微調和範例腳本](https://github.com/huggingface/transformers/tree/main/examples) | 為各種任務提供的範例腳本 |
|
||||
| [模型分享和上傳](https://huggingface.co/docs/transformers/model_sharing) | 上傳並與社群分享你微調的模型 |
|
||||
| [遷移](https://huggingface.co/docs/transformers/migration) | 從 `pytorch-transformers` 或 `pytorch-pretrained-bert` 遷移到 🤗 Transformers |
|
||||
> [!TIP]
|
||||
> 你也可以直接在命令列中與模型聊天。
|
||||
> ```shell
|
||||
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
> ```
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."},
|
||||
{"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"}
|
||||
]
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", dtype=torch.bfloat16, device_map="auto")
|
||||
response = pipeline(chat, max_new_tokens=512)
|
||||
print(response[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
展開下面的範例,查看 `Pipeline` 如何在不同模態和任務上運作。
|
||||
|
||||
<details>
|
||||
<summary>自動語音辨識</summary>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="automatic-speech-recognition", model="openai/whisper-large-v3")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
|
||||
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>影像分類</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="image-classification", model="facebook/dinov2-small-imagenet1k-1-layer")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
|
||||
[{'label': 'macaw', 'score': 0.997848391532898},
|
||||
{'label': 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
|
||||
'score': 0.0016551691805943847},
|
||||
{'label': 'lorikeet', 'score': 0.00018523589824326336},
|
||||
{'label': 'African grey, African gray, Psittacus erithacus',
|
||||
'score': 7.85409429227002e-05},
|
||||
{'label': 'quail', 'score': 5.502637941390276e-05}]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>視覺問答</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="visual-question-answering", model="Salesforce/blip-vqa-base")
|
||||
pipeline(
|
||||
image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg",
|
||||
question="What is in the image?",
|
||||
)
|
||||
[{'answer': 'statue of liberty'}]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 為什麼我應該使用 Transformers?
|
||||
|
||||
1. 易於使用的最先進模型:
|
||||
* 在自然語言理解與生成、電腦視覺、音訊、影片和多模態任務上表現卓越。
|
||||
* 為研究人員、工程師與開發者提供了低門檻的入門途徑。
|
||||
* 面向使用者的抽象層級少,只需學習三個核心類別。
|
||||
* 為所有預訓練模型提供了統一的 API 介面。
|
||||
|
||||
2. 更低的運算成本,更小的碳足跡:
|
||||
* 分享訓練好的模型,而不是從零開始訓練。
|
||||
* 減少運算時間和生產成本。
|
||||
* 擁有數十種模型架構和超過100萬個橫跨所有模態的預訓練檢查點。
|
||||
|
||||
3. 為模型的每個生命週期階段選擇合適的框架:
|
||||
* 僅用3行程式碼即可訓練最先進的模型。
|
||||
* 在PyTorch/JAX/TF2.0框架之間輕鬆切換單一模型。
|
||||
* 為訓練、評估和生產選擇最合適的框架。
|
||||
|
||||
4. 輕鬆根據您的需求客製化模型或範例:
|
||||
* 我們為每個架構提供了範例,以重現其原作者發表的結果。
|
||||
* 模型內部結構盡可能保持一致地暴露給使用者。
|
||||
* 模型檔案可以獨立於函式庫使用,便於快速實驗。
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/enterprise">
|
||||
<img alt="Hugging Face Enterprise Hub" src="https://github.com/user-attachments/assets/247fb16d-d251-4583-96c4-d3d76dda4925">
|
||||
</a><br>
|
||||
|
||||
## 為什麼我不應該使用 Transformers?
|
||||
|
||||
- 本函式庫並非一個用於建構神經網路的模組化工具箱。模型檔案中的程式碼為了讓研究人員能快速在模型上迭代,而沒有進行過度的抽象重構,避免了深入額外的抽象層/檔案。
|
||||
- 訓練 API 針對 Transformers 提供的 PyTorch 模型進行了最佳化。對於通用的機器學習迴圈,您應該使用像 [Accelerate](https://huggingface.co/docs/accelerate) 這樣的其他函式庫。
|
||||
- [範例指令稿](https://github.com/huggingface/transformers/tree/main/examples)僅僅是*範例*。它們不一定能在您的特定用例上開箱即用,您可能需要修改程式碼才能使其正常運作。
|
||||
|
||||
## 100個使用 Transformers 的專案
|
||||
|
||||
Transformers 不僅僅是一個使用預訓練模型的工具包,它還是一個圍繞它和 Hugging Face Hub 建構的專案社群。我們希望 Transformers 能夠賦能開發者、研究人員、學生、教授、工程師以及其他任何人,去建構他們夢想中的專案。
|
||||
|
||||
為了慶祝 Transformers 獲得 10 萬顆星標,我們希望透過 [awesome-transformers](./awesome-transformers.md) 頁面來聚焦社群,該頁面列出了100個基於 Transformers 建構的精彩專案。
|
||||
|
||||
如果您擁有或使用一個您認為應該被列入其中的專案,請隨時提交 PR 將其加入!
|
||||
|
||||
## 範例模型
|
||||
|
||||
您可以在我們大多數模型的 [Hub 模型頁面](https://huggingface.co/models) 上直接進行測試。
|
||||
|
||||
展開下面的每個模態,查看一些用於不同用例的範例模型。
|
||||
|
||||
<details>
|
||||
<summary>音訊</summary>
|
||||
|
||||
- Audio classification with [Whisper](https://huggingface.co/openai/whisper-large-v3-turbo)
|
||||
- Automatic speech recognition with [Moonshine](https://huggingface.co/UsefulSensors/moonshine)
|
||||
- Keyword spotting with [Wav2Vec2](https://huggingface.co/superb/wav2vec2-base-superb-ks)
|
||||
- Speech to speech generation with [Moshi](https://huggingface.co/kyutai/moshiko-pytorch-bf16)
|
||||
- Text to audio with [MusicGen](https://huggingface.co/facebook/musicgen-large)
|
||||
- Text to speech with [Bark](https://huggingface.co/suno/bark)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>電腦視覺</summary>
|
||||
|
||||
- Automatic mask generation with [SAM](https://huggingface.co/facebook/sam-vit-base)
|
||||
- Depth estimation with [DepthPro](https://huggingface.co/apple/DepthPro-hf)
|
||||
- Image classification with [DINO v2](https://huggingface.co/facebook/dinov2-base)
|
||||
- Keypoint detection with [SuperPoint](https://huggingface.co/magic-leap-community/superpoint)
|
||||
- Keypoint matching with [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor)
|
||||
- Object detection with [RT-DETRv2](https://huggingface.co/PekingU/rtdetr_v2_r50vd)
|
||||
- Pose Estimation with [VitPose](https://huggingface.co/usyd-community/vitpose-base-simple)
|
||||
- Universal segmentation with [OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_swin_large)
|
||||
- Video classification with [VideoMAE](https://huggingface.co/MCG-NJU/videomae-large)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>多模態</summary>
|
||||
|
||||
- Audio or text to text with [Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B)
|
||||
- Document question answering with [LayoutLMv3](https://huggingface.co/microsoft/layoutlmv3-base)
|
||||
- Image or text to text with [Qwen-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
- Image captioning [BLIP-2](https://huggingface.co/Salesforce/blip2-opt-2.7b)
|
||||
- OCR-based document understanding with [GOT-OCR2](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)
|
||||
- Table question answering with [TAPAS](https://huggingface.co/google/tapas-base)
|
||||
- Unified multimodal understanding and generation with [Emu3](https://huggingface.co/BAAI/Emu3-Gen)
|
||||
- Vision to text with [Llava-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf)
|
||||
- Visual question answering with [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
|
||||
- Visual referring expression segmentation with [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>自然語言處理 (NLP)</summary>
|
||||
|
||||
- Masked word completion with [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base)
|
||||
- Named entity recognition with [Gemma](https://huggingface.co/google/gemma-2-2b)
|
||||
- Question answering with [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
|
||||
- Summarization with [BART](https://huggingface.co/facebook/bart-large-cnn)
|
||||
- Translation with [T5](https://huggingface.co/google-t5/t5-base)
|
||||
- Text generation with [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B)
|
||||
- Text classification with [Qwen](https://huggingface.co/Qwen/Qwen2.5-0.5B)
|
||||
|
||||
</details>
|
||||
|
||||
## 引用
|
||||
|
||||
我們已將此函式庫的[論文](https://www.aclweb.org/anthology/2020.emnlp-demos.6/)正式發表。如果你使用了 🤗 Transformers 函式庫,可以引用:
|
||||
現在我們有一篇可供您引用的關於 🤗 Transformers 函式庫的 [論文](https://www.aclweb.org/anthology/2020.emnlp-demos.6/):
|
||||
```bibtex
|
||||
@inproceedings{wolf-etal-2020-transformers,
|
||||
title = "Transformers: State-of-the-Art Natural Language Processing",
|
||||
@ -285,4 +324,4 @@ conda install conda-forge::transformers
|
||||
url = "https://www.aclweb.org/anthology/2020.emnlp-demos.6",
|
||||
pages = "38--45"
|
||||
}
|
||||
```
|
||||
```
|
||||
4
setup.py
4
setup.py
@ -137,8 +137,8 @@ _deps = [
|
||||
"psutil",
|
||||
"pyyaml>=5.1",
|
||||
"pydantic>=2",
|
||||
"pytest>=7.2.0",
|
||||
"pytest-asyncio",
|
||||
"pytest>=7.2.0,<9.0.0",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-rerunfailures<16.0",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
|
||||
@ -302,10 +302,9 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
self.sep_token_id = sep_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
|
||||
# Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
|
||||
# parameters, saving them will be deprecated. In a distant future, we won't need to load them.
|
||||
for parameter_name, default_value in self._get_global_generation_defaults().items():
|
||||
setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))
|
||||
# Parameters for sequence generation saved in the config are popped instead of loading them.
|
||||
for parameter_name in self._get_global_generation_defaults().keys():
|
||||
kwargs.pop(parameter_name, None)
|
||||
|
||||
# Name or path to the pretrained checkpoint
|
||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||
@ -445,14 +444,11 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
|
||||
non_default_generation_parameters = self._get_non_default_generation_parameters()
|
||||
if len(non_default_generation_parameters) > 0:
|
||||
# TODO (joao): this should be an exception if the user has modified the loaded config. See #33886
|
||||
warnings.warn(
|
||||
raise ValueError(
|
||||
"Some non-default generation parameters are set in the model config. These should go into either a) "
|
||||
"`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
|
||||
"(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
|
||||
"This warning will become an exception in the future."
|
||||
f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
@ -876,7 +872,7 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
if hasattr(self, "quantization_config"):
|
||||
serializable_config_dict["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
||||
else self.quantization_config
|
||||
)
|
||||
self.dict_dtype_to_str(serializable_config_dict)
|
||||
@ -910,7 +906,7 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
if hasattr(self, "quantization_config"):
|
||||
output["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
||||
else self.quantization_config
|
||||
)
|
||||
self.dict_dtype_to_str(output)
|
||||
@ -1101,40 +1097,18 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
non_default_generation_parameters = {}
|
||||
decoder_attribute_name = None
|
||||
|
||||
# Some composite models don't have a default config, use their decoder config as a fallback for default values
|
||||
# If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
|
||||
if not self.has_no_defaults_at_init:
|
||||
default_config = self.__class__()
|
||||
else:
|
||||
decoder_config = self.get_text_config(decoder=True)
|
||||
if decoder_config is not self:
|
||||
default_config = decoder_config.__class__()
|
||||
else:
|
||||
default_config = None
|
||||
|
||||
# If it is a composite model, we want to check the subconfig that will be used for generation
|
||||
self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
|
||||
|
||||
for parameter_name, default_global_value in self._get_global_generation_defaults().items():
|
||||
if hasattr(self_decoder_config, parameter_name):
|
||||
is_default_in_config = is_default_generation_value = None
|
||||
parameter_value = getattr(self_decoder_config, parameter_name)
|
||||
# Three cases in which is okay for the model config to hold generation config parameters:
|
||||
parameter_value = getattr(self_decoder_config, parameter_name, None)
|
||||
# Two cases in which is okay for the model config to hold generation config parameters:
|
||||
# 1. The parameter is set to `None`, effectively delegating its value to the generation config
|
||||
if parameter_value is None:
|
||||
# 2. The parameter is set the global generation defaults
|
||||
if parameter_value is None or parameter_value == default_global_value:
|
||||
continue
|
||||
# 2. If we have a default config, then the instance should hold the same generation defaults
|
||||
if default_config is not None:
|
||||
is_default_in_config = parameter_value == getattr(default_config, parameter_name)
|
||||
# 3. if we don't have a default config, then the instance should hold the global generation defaults
|
||||
else:
|
||||
is_default_generation_value = parameter_value == default_global_value
|
||||
|
||||
is_non_default = (is_default_in_config is False) or (
|
||||
is_default_in_config is None and is_default_generation_value is False
|
||||
)
|
||||
if is_non_default:
|
||||
non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)
|
||||
non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)
|
||||
|
||||
return non_default_generation_parameters
|
||||
|
||||
|
||||
136
src/transformers/conversion_mapping.py
Normal file
136
src/transformers/conversion_mapping.py
Normal file
@ -0,0 +1,136 @@
|
||||
# coding=utf-8
|
||||
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
|
||||
from .utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def _build_checkpoint_conversion_mapping():
|
||||
mapping = {
|
||||
"mixtral": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w1.weight",
|
||||
"block_sparse_moe.experts.*.w3.weight",
|
||||
], # you give me a list of 2 keys, I collect a list of a list of tensors
|
||||
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w2.weight",
|
||||
],
|
||||
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
# WeightConverter(
|
||||
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
|
||||
# "self_attn.qkv_proj",
|
||||
# operations=[Concatenate(dim=0)], # more like stack?
|
||||
# ),
|
||||
WeightConverter("*.block_sparse_moe.", "*.mlp."),
|
||||
],
|
||||
"qwen2_moe": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"mlp.experts.*.gate_proj.weight",
|
||||
"mlp.experts.*.up_proj.weight",
|
||||
],
|
||||
target_keys="mlp.experts.gate_up_proj",
|
||||
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=["mlp.experts.*.down_proj.weight"],
|
||||
target_keys="mlp.experts.down_proj",
|
||||
operations=[MergeModulelist(dim=0)],
|
||||
),
|
||||
],
|
||||
"legacy": [
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.gamma",
|
||||
target_keys="LayerNorm.weight",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.beta",
|
||||
target_keys="LayerNorm.bias",
|
||||
),
|
||||
],
|
||||
}
|
||||
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="weight_g",
|
||||
target_keys="parametrizations.weight.original0",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="weight_v",
|
||||
target_keys="parametrizations.weight.original1",
|
||||
),
|
||||
]
|
||||
else:
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original0",
|
||||
target_keys="weight_g",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original1",
|
||||
target_keys="weight_v",
|
||||
),
|
||||
]
|
||||
|
||||
mapping["phimoe"] = mapping["mixtral"].copy()
|
||||
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
|
||||
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
|
||||
mapping["dot1"] = mapping["qwen2_moe"].copy()
|
||||
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["jamba"] = mapping["qwen2_moe"].copy()
|
||||
mapping["lfm2_moe"] = mapping["mixtral"].copy()
|
||||
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["minimax"] = mapping["mixtral"].copy()
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
_checkpoint_conversion_mapping_cache = None
|
||||
|
||||
|
||||
def get_checkpoint_conversion_mapping(model_type):
|
||||
global _checkpoint_conversion_mapping_cache
|
||||
_checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
|
||||
globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
|
||||
return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type, None))
|
||||
@ -1731,10 +1731,8 @@ SLOW_TO_FAST_CONVERTERS = {
|
||||
"OpenAIGPTTokenizer": OpenAIGPTConverter,
|
||||
"PegasusTokenizer": PegasusConverter,
|
||||
"Qwen2Tokenizer": Qwen2Converter,
|
||||
"RealmTokenizer": BertConverter,
|
||||
"ReformerTokenizer": ReformerConverter,
|
||||
"RemBertTokenizer": RemBertConverter,
|
||||
"RetriBertTokenizer": BertConverter,
|
||||
"RobertaTokenizer": RobertaConverter,
|
||||
"RoFormerTokenizer": RoFormerConverter,
|
||||
"SeamlessM4TTokenizer": SeamlessM4TConverter,
|
||||
|
||||
602
src/transformers/core_model_loading.py
Normal file
602
src/transformers/core_model_loading.py
Normal file
@ -0,0 +1,602 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Core helpers for loading model checkpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping, MutableSet, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
|
||||
from .utils import is_torch_greater_or_equal, logging
|
||||
|
||||
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
|
||||
if _is_dtensor_available:
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .quantizers import HfQuantizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
|
||||
"""
|
||||
Convert a glob with '*' into a regex *source* string. We don't use `glob.translate`
|
||||
'*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing.
|
||||
"""
|
||||
star = r"(\d+)" if digits_only else r"(.+)"
|
||||
return glob.replace(r"\*", star)
|
||||
|
||||
|
||||
def build_glob_alt(
|
||||
globs: list[str],
|
||||
) -> tuple[re.Pattern, dict[str, str]]:
|
||||
r"""
|
||||
Build one compiled regex alternation with a named group per glob. This allows to run a single
|
||||
re.match and get the correct group name to finally get which pattern matched.
|
||||
Returns (compiled_regex, name->glob map).
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"])
|
||||
>>> print(reg)
|
||||
(re.compile(r'(?P<g0>.*mlp\.(\d+)\.w1)|(?P<g1>.*mlp\.(\d+)\.w2)', re.UNICODE),
|
||||
>>> print(map_)
|
||||
{'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'})
|
||||
>>> match_ = reg.match("model.layers.0.mlp.0.w1.weight")
|
||||
>>> print(match_.lastgroup)
|
||||
'g0'
|
||||
>>> print(map_[match_.lastgroup])
|
||||
mlp.*.w1
|
||||
```
|
||||
"""
|
||||
name_map: dict[str, str] = {}
|
||||
parts: list[str] = []
|
||||
|
||||
for i, g in enumerate(globs):
|
||||
name = f"g{i}"
|
||||
name_map[name] = g
|
||||
pat_src = _glob_to_regex_src(g)
|
||||
prefix_src = ""
|
||||
if pat_src.startswith("*"):
|
||||
prefix_src = "."
|
||||
elif not pat_src.startswith(r"\^") and not pat_src.startswith(r".*"):
|
||||
prefix_src = ".*"
|
||||
|
||||
parts.append(f"(?P<{name}>{prefix_src}{pat_src}.*)")
|
||||
|
||||
alt_src = "|".join(parts).replace("\\^", "^").replace("\\.", r"\.")
|
||||
try:
|
||||
reg = re.compile(alt_src)
|
||||
except re.error as e:
|
||||
logger.error(f"Error compiling regex for alternation: {alt_src}")
|
||||
raise e
|
||||
|
||||
return reg, name_map
|
||||
|
||||
|
||||
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Match the key against the alternation; return the original glob string that matched.
|
||||
"""
|
||||
m = alt.match(key)
|
||||
if not m:
|
||||
return None
|
||||
return name_map.get(m.lastgroup)
|
||||
|
||||
|
||||
class ConversionOps:
|
||||
"""Base class for weight conversion operations."""
|
||||
|
||||
# The inverse operation class, will be used when saving the checkpoint
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Chunk(ConversionOps):
|
||||
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
|
||||
if chunks is None and sizes is None:
|
||||
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
|
||||
if chunks is not None and chunks <= 0:
|
||||
raise ValueError("`chunks` must be a strictly positive integer.")
|
||||
self.dim = dim
|
||||
self.chunks = chunks
|
||||
self.sizes = list(sizes) if sizes is not None else None
|
||||
self.reverse_op = Concatenate
|
||||
|
||||
def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
|
||||
# chunk requires a single tensor input
|
||||
if len(value) != 1 or len(value[0]) != 1:
|
||||
raise ValueError("Chunk operation requires a single tensor input.")
|
||||
return list(torch.chunk(value[0][0], self.chunks, dim=self.dim))
|
||||
|
||||
|
||||
class Concatenate(ConversionOps):
|
||||
"""Concatenate tensors along `dim` using a reusable buffer."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
self.dim = dim
|
||||
self.reverse_op = Chunk
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
|
||||
if isinstance(value[0], list):
|
||||
value = [v[0] for v in value]
|
||||
tensors = value
|
||||
if not tensors:
|
||||
raise ValueError("Fuse requires at least one tensor to concatenate.")
|
||||
|
||||
return torch.cat(tuple(tensors), dim=self.dim)
|
||||
|
||||
|
||||
class MergeModulelist(Concatenate):
|
||||
"""
|
||||
Merge a list of tensors into a single tensor along the first dimension.
|
||||
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
super().__init__(dim=dim)
|
||||
self.reverse_op = SplitModulelist
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
|
||||
merged = []
|
||||
for group in value:
|
||||
if not isinstance(group, Sequence) or len(group) == 0:
|
||||
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
|
||||
group = [k for k in group if k.ndim]
|
||||
merged.append(torch.stack(group, dim=self.dim))
|
||||
return merged
|
||||
|
||||
|
||||
class SplitModulelist(ConversionOps):
|
||||
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
|
||||
|
||||
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
|
||||
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
|
||||
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
|
||||
self.sizes = [list(sub) for sub in sizes]
|
||||
self.dim = dim
|
||||
self.reverse_op = MergeModulelist
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
|
||||
if not isinstance(value, Sequence):
|
||||
raise TypeError("SplitModulelist expects a sequence of tensors.")
|
||||
if len(value) != len(self.sizes):
|
||||
raise ValueError("Number of tensors does not match the provided split specifications.")
|
||||
|
||||
result: list[list[torch.Tensor]] = []
|
||||
for tensor, split_sizes in zip(value, self.sizes):
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
|
||||
splits = torch.split(tensor, split_sizes, dim=self.dim)
|
||||
result.append(list(splits))
|
||||
return result
|
||||
|
||||
|
||||
class PermuteForRope(ConversionOps):
|
||||
"""
|
||||
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
dim1, dim2 = tensor.shape
|
||||
n_heads = self.config.getattr("num_attention_heads", 1)
|
||||
|
||||
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
|
||||
return tensor
|
||||
|
||||
@torch.no_grad
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config
|
||||
) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]:
|
||||
self.config = config
|
||||
out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value]
|
||||
return out
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WeightConverter:
|
||||
r"""
|
||||
A weight convert that acts on a pattern of source keys.
|
||||
The keys need to be collected based on the target keys.
|
||||
|
||||
With wild card, glob patterns are matched, so you have to be detailed with what to match. If you match:
|
||||
`model.layers.*.experts.*` -> it will act on all of them
|
||||
{"model.layers.*.experts.*": []}
|
||||
but
|
||||
`experts.*.mlp` will be layer specific.
|
||||
{"model.layers.1.experts.*": [], }
|
||||
- source_keys: str | list[str] (wildcards '*' match digits)
|
||||
- target_keys: str | list[str] | None
|
||||
- distributed_operation / operations / quantization_operations are ALWAYS lists.
|
||||
|
||||
TODO: for BNB we need to collect model.weight.quant_state_keys
|
||||
"""
|
||||
|
||||
source_keys: Union[str, list[str]]
|
||||
target_keys: Optional[Union[str, list[str]]] = None
|
||||
operations: list[ConversionOps] = field(default_factory=list, repr=False)
|
||||
|
||||
distributed_operation: Optional[TensorParallelLayer] = None
|
||||
quantization_operation: Optional[ConversionOps] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.source_keys, list):
|
||||
self.source_keys = [self.source_keys]
|
||||
targets_were_none = False
|
||||
if not isinstance(self.target_keys, list):
|
||||
if self.target_keys is None:
|
||||
self.target_keys = list(self.source_keys)
|
||||
targets_were_none = True
|
||||
else:
|
||||
self.target_keys = [self.target_keys]
|
||||
|
||||
if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2:
|
||||
raise ValueError(
|
||||
f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConversionEntry:
|
||||
weight_converter: WeightConverter
|
||||
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
|
||||
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
|
||||
|
||||
|
||||
def _materialize_copy(tensor, device=None, dtype=None):
|
||||
tensor = tensor[...]
|
||||
if dtype is not None or device is not None:
|
||||
tensor = tensor.to(device=device, dtype=dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
|
||||
def _job():
|
||||
return _materialize_copy(tensor, device, dtype)
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
|
||||
def _job():
|
||||
return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def dot_natural_key(s: str):
|
||||
parts = s.split(".")
|
||||
for i, p in enumerate(parts):
|
||||
# whole-segment digits -> int; otherwise leave as str
|
||||
if p.isdigit():
|
||||
parts[i] = int(p)
|
||||
return parts
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_to_misc(
|
||||
layer_name: str,
|
||||
misc: MutableMapping[str, str],
|
||||
extras: Any = None,
|
||||
op: Union[list[ConversionOps], ConversionOps, None] = None,
|
||||
):
|
||||
# A simple helper to handle errors with contextual messages.
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
|
||||
def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
|
||||
if curr_op is None:
|
||||
return None
|
||||
if isinstance(curr_op, (list, tuple, set)):
|
||||
names = [o.__class__.__name__ for o in curr_op if o is not None]
|
||||
if not names:
|
||||
return None
|
||||
return ", ".join(names)
|
||||
return curr_op.__class__.__name__
|
||||
|
||||
op_name = _format_op_name(op)
|
||||
if isinstance(extras, tuple) and len(extras) == 2:
|
||||
values, target_keys = extras
|
||||
descriptor = f"{op_name} " if op_name else ""
|
||||
misc[layer_name] = (
|
||||
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
|
||||
)
|
||||
elif isinstance(extras, str):
|
||||
suffix = f" via {op_name}" if op_name else ""
|
||||
misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
|
||||
elif extras is None and op_name:
|
||||
misc[layer_name] = f"{op_name}: {e}"
|
||||
else:
|
||||
misc[layer_name] = f"{extras} |Error: {e}"
|
||||
raise SkipLayer()
|
||||
|
||||
|
||||
def set_param_for_module(
|
||||
model: PreTrainedModel,
|
||||
layer_name: str,
|
||||
param_value: torch.Tensor,
|
||||
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
|
||||
missing_keys: MutableSet[str],
|
||||
misc: MutableMapping[str, Any],
|
||||
distributed_operation: Optional[TensorParallelLayer],
|
||||
hf_quantizer: HfQuantizer,
|
||||
):
|
||||
with log_to_misc(layer_name, misc, layer_name):
|
||||
module_path, _, param_name = layer_name.rpartition(".")
|
||||
module_obj = model.get_submodule(module_path) if module_path else model
|
||||
if isinstance(param_value, list):
|
||||
param_value = param_value[0]
|
||||
elif not isinstance(param_value, torch.nn.Parameter):
|
||||
param_value = param_value[...]
|
||||
ref = getattr(module_obj, param_name)
|
||||
|
||||
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
|
||||
if not isinstance(param_value, torch.nn.Parameter):
|
||||
if distributed_operation is not None:
|
||||
param_value = DTensor.from_local(
|
||||
param_value,
|
||||
distributed_operation.device_mesh,
|
||||
getattr(distributed_operation, "shard", Replicate()),
|
||||
run_check=False,
|
||||
shape=ref.size(),
|
||||
stride=ref.stride(),
|
||||
)
|
||||
if not use_dtensor:
|
||||
# we convert to local
|
||||
param_value = param_value.to_local()
|
||||
if param_name not in module_obj._buffers:
|
||||
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
|
||||
|
||||
# Remove from missing keys (it's either mismatched, or all good)
|
||||
missing_keys.discard(layer_name)
|
||||
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
|
||||
mismatch_keys.add((layer_name, param_value.shape, ref.shape))
|
||||
module_obj.param_name._is_hf_initialized = False # Needs to be initialized
|
||||
else:
|
||||
param_value._is_hf_initialized = True # super important otherwise _init_weight re-initi if bias is missing
|
||||
setattr(module_obj, param_name, param_value)
|
||||
|
||||
|
||||
class SkipLayer(Exception):
|
||||
"""Control-flow sentinel: abort processing of the current layer only."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def convert_and_load_state_dict_in_model(
|
||||
model: PreTrainedModel,
|
||||
state_dict: dict[str, Any],
|
||||
weight_mapping: dict[str, WeightConverter] | None,
|
||||
tp_plan: dict[str, str] | None,
|
||||
hf_quantizer: HfQuantizer | None,
|
||||
dtype: torch.dtype | None = None,
|
||||
device_map: dict | None = None,
|
||||
dtype_plan: dict | None = None,
|
||||
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
|
||||
):
|
||||
"""
|
||||
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
|
||||
collecting tensors per *layer instance* (the concrete indices captured from '*').
|
||||
"""
|
||||
|
||||
prefix = model.base_model_prefix
|
||||
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
|
||||
device_map = device_map or {"": "cpu"} # {exact_target_key: device}
|
||||
device_map_regex = re.compile(
|
||||
"|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
|
||||
)
|
||||
dtype_plan = dtype_plan or {} # {glob_pattern: dtype}
|
||||
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
|
||||
meta_model_state_dict = model.state_dict()
|
||||
missing_keys = set(meta_model_state_dict.keys())
|
||||
|
||||
misc = {}
|
||||
mismatch_keys = set()
|
||||
unexpected_keys = set()
|
||||
# Global thread_pool
|
||||
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
|
||||
|
||||
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
|
||||
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
|
||||
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
|
||||
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
|
||||
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
|
||||
|
||||
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
|
||||
# 1. Create the conversion entries
|
||||
by_conversion_pattern: dict[str, ConversionEntry] = {}
|
||||
for original_key, tensor in state_dict:
|
||||
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
|
||||
if matched_pattern is not None:
|
||||
converter = source_to_target[matched_pattern] # TODO make sure its the ref
|
||||
sub_with_extractor = partial(re.sub, matched_pattern.replace("*", r"(\d+)"), string=original_key)
|
||||
entry_key = "|".join(converter.target_keys)
|
||||
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
|
||||
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
|
||||
converter_key = sub_with_extractor(matched_pattern)
|
||||
else:
|
||||
converter = WeightConverter(original_key)
|
||||
converter_key = entry_key = target_key = original_key
|
||||
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
|
||||
|
||||
_dtype = dtype
|
||||
new_target_key = [] # test_load_with_mismatched_shapes for AutoModel.from_pretrained(AutoForCausal, vocab=10)
|
||||
for t in target_key.split("|"):
|
||||
if t.startswith(prefix) and meta_model_state_dict.get(re.sub(f"^{prefix}.", "", t, count=1)) is not None:
|
||||
t = re.sub(f"^{prefix}.", "", t, count=1)
|
||||
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
|
||||
t = f"{prefix}.{t}"
|
||||
new_target_key.append(t)
|
||||
empty_param = meta_model_state_dict.get(t)
|
||||
# If it does not exist, it's unexpected
|
||||
if empty_param is None:
|
||||
unexpected_keys.add(t)
|
||||
continue
|
||||
|
||||
if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
|
||||
converter.quantization_operation = hf_quantizer.get_quantize_ops()
|
||||
_dtype = dtype
|
||||
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
|
||||
if matched_dtype_pattern is not None:
|
||||
_dtype = dtype_plan[matched_dtype_pattern]
|
||||
elif empty_param.dtype != _dtype:
|
||||
_dtype = empty_param.dtype
|
||||
|
||||
first_target_key = new_target_key[0]
|
||||
target_key = "|".join(new_target_key)
|
||||
|
||||
future = None
|
||||
if device_mesh:
|
||||
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
|
||||
empty_param = meta_model_state_dict.get(first_target_key)
|
||||
if getattr(converter, "distributed_operation", {}) is None:
|
||||
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
|
||||
converter.distributed_operation = tp_layer(
|
||||
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
|
||||
)
|
||||
# VERY IMPORTANT: this tells us wether we collected stuffs or not.
|
||||
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
|
||||
future = spawn_tp_materialize(
|
||||
thread_pool,
|
||||
tensor,
|
||||
_dtype,
|
||||
converter.distributed_operation,
|
||||
shard_index,
|
||||
)
|
||||
|
||||
if future is None: # If not TP, async materialize the tensors. TODO handle disk offload?
|
||||
device_match = device_map_regex.match(first_target_key)
|
||||
param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
|
||||
future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
|
||||
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
|
||||
|
||||
# 2. Actually convert the ckpt
|
||||
inverse_converters = {}
|
||||
keys = list(by_conversion_pattern.keys())
|
||||
|
||||
with logging.tqdm(total=len(keys), desc="Loading weights") as pbar:
|
||||
for key in keys[::-1]: # revert to process simple keys first
|
||||
group = by_conversion_pattern.pop(key)
|
||||
converter = group.weight_converter
|
||||
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
|
||||
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
|
||||
pbar.update(1)
|
||||
pbar.set_postfix({"Materializing param": layer_name})
|
||||
pbar.refresh()
|
||||
concrete_target_keys = layer_name.split("|")
|
||||
try:
|
||||
if bool(set(concrete_target_keys) - unexpected_keys):
|
||||
with log_to_misc(layer_name, misc):
|
||||
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
|
||||
|
||||
for op in operations:
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
values = op.convert(values, model.config)
|
||||
|
||||
values = [values] if not isinstance(values, list) else values
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
realized_value = {
|
||||
k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
|
||||
}
|
||||
|
||||
for k in list(realized_value.keys()).copy():
|
||||
if op := converter.quantization_operation:
|
||||
with log_to_misc(layer_name, misc, op=op):
|
||||
realized_value.update(
|
||||
op.convert({k: realized_value.pop(k)}, model=model, missing_keys=missing_keys)
|
||||
)
|
||||
|
||||
for k, output_value in realized_value.items():
|
||||
for src in converter.source_keys: # what should happen to k when we meet k at saving
|
||||
inverse_converters[k] = {src: converter}
|
||||
set_param_for_module(
|
||||
model,
|
||||
k,
|
||||
output_value,
|
||||
mismatch_keys,
|
||||
missing_keys,
|
||||
misc,
|
||||
converter.distributed_operation,
|
||||
hf_quantizer,
|
||||
)
|
||||
|
||||
except SkipLayer:
|
||||
continue
|
||||
del group
|
||||
|
||||
model.inverse_converters = inverse_converters
|
||||
thread_pool.shutdown(wait=False)
|
||||
return missing_keys, unexpected_keys, mismatch_keys, misc
|
||||
|
||||
|
||||
# TODO this is not done yet!
|
||||
def revert_weight_conversion(model, state_dict):
|
||||
mapping = getattr(model, "_checkpoint_conversion_mapping", {}) # IDK why but setting this will fail all llava.
|
||||
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
|
||||
original_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
for pattern, inverse_converter in reverse_key_mapping:
|
||||
# TODO FIXME you name it
|
||||
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
|
||||
replacement = re.sub(r"\(.*\)", "", replacement)
|
||||
key, n_replace = re.subn(pattern, replacement, key)
|
||||
# Early exit of the loop
|
||||
if n_replace > 0:
|
||||
break
|
||||
original_state_dict[key] = value
|
||||
state_dict = original_state_dict
|
||||
return state_dict
|
||||
@ -723,7 +723,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
|
||||
if self.mask_replace_prob < 1:
|
||||
warnings.warn(
|
||||
"Random token replacement is not supported with whole word masking.",
|
||||
"Random token replacement is not supported with whole word masking. "
|
||||
"Setting mask_replace_prob to 1.",
|
||||
)
|
||||
self.mask_replace_prob = 1
|
||||
|
||||
@ -82,7 +82,7 @@ class GlueDataset(Dataset):
|
||||
cache_dir: Optional[str] = None,
|
||||
):
|
||||
warnings.warn(
|
||||
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
||||
"This dataset will be removed from the library soon, preprocessing should be handled with the Hugging Face Datasets "
|
||||
"library. You can have a look at this example script for pointers: "
|
||||
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
|
||||
FutureWarning,
|
||||
|
||||
@ -21,7 +21,7 @@ if is_sklearn_available():
|
||||
|
||||
|
||||
DEPRECATION_WARNING = (
|
||||
"This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
|
||||
"This metric will be removed from the library soon, metrics should be handled with the Hugging Face Evaluate "
|
||||
"library. You can have a look at this example script for pointers: "
|
||||
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
|
||||
)
|
||||
|
||||
@ -28,7 +28,7 @@ from .utils import DataProcessor, InputExample, InputFeatures
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
DEPRECATION_WARNING = (
|
||||
"This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
||||
"This {0} will be removed from the library soon, preprocessing should be handled with the Hugging Face Datasets "
|
||||
"library. You can have a look at this example script for pointers: "
|
||||
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
|
||||
)
|
||||
|
||||
@ -47,8 +47,8 @@ deps = {
|
||||
"psutil": "psutil",
|
||||
"pyyaml": "pyyaml>=5.1",
|
||||
"pydantic": "pydantic>=2",
|
||||
"pytest": "pytest>=7.2.0",
|
||||
"pytest-asyncio": "pytest-asyncio",
|
||||
"pytest": "pytest>=7.2.0,<9.0.0",
|
||||
"pytest-asyncio": "pytest-asyncio>=1.2.0",
|
||||
"pytest-rerunfailures": "pytest-rerunfailures<16.0",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
|
||||
@ -39,6 +39,7 @@ from .utils import (
|
||||
is_torch_dtype,
|
||||
logging,
|
||||
requires_backends,
|
||||
safe_load_json_file,
|
||||
)
|
||||
from .utils.hub import cached_file
|
||||
|
||||
@ -427,35 +428,42 @@ class FeatureExtractionMixin(PushToHubMixin):
|
||||
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
resolved_feature_extractor_file = pretrained_model_name_or_path
|
||||
resolved_processor_file = None
|
||||
is_local = True
|
||||
elif is_remote_url(pretrained_model_name_or_path):
|
||||
feature_extractor_file = pretrained_model_name_or_path
|
||||
resolved_processor_file = None
|
||||
resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
feature_extractor_file = FEATURE_EXTRACTOR_NAME
|
||||
try:
|
||||
# Load from local folder or from cache or download from model Hub and cache
|
||||
resolved_feature_extractor_files = [
|
||||
resolved_file
|
||||
for filename in [feature_extractor_file, PROCESSOR_NAME]
|
||||
if (
|
||||
resolved_file := cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
)
|
||||
is not None
|
||||
]
|
||||
resolved_feature_extractor_file = resolved_feature_extractor_files[0]
|
||||
resolved_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
resolved_feature_extractor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=feature_extractor_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
except OSError:
|
||||
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
||||
# the original exception.
|
||||
@ -469,19 +477,24 @@ class FeatureExtractionMixin(PushToHubMixin):
|
||||
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load feature_extractor dict
|
||||
with open(resolved_feature_extractor_file, encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
feature_extractor_dict = json.loads(text)
|
||||
if "audio_processor" in feature_extractor_dict:
|
||||
feature_extractor_dict = feature_extractor_dict["audio_processor"]
|
||||
else:
|
||||
feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict)
|
||||
# Load feature_extractor dict. Priority goes as (nested config if found -> image processor config)
|
||||
# We are downloading both configs because almost all models have a `processor_config.json` but
|
||||
# not all of these are nested. We need to check if it was saved recebtly as nested or if it is legacy style
|
||||
feature_extractor_dict = None
|
||||
if resolved_processor_file is not None:
|
||||
processor_dict = safe_load_json_file(resolved_processor_file)
|
||||
if "feature_extractor" in processor_dict or "audio_processor" in processor_dict:
|
||||
feature_extractor_dict = processor_dict.get("feature_extractor", processor_dict.get("audio_processor"))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
if resolved_feature_extractor_file is not None and feature_extractor_dict is None:
|
||||
feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file)
|
||||
|
||||
if feature_extractor_dict is None:
|
||||
raise OSError(
|
||||
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
|
||||
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
||||
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a {feature_extractor_file} file"
|
||||
)
|
||||
|
||||
if is_local:
|
||||
|
||||
@ -918,7 +918,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
else:
|
||||
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
||||
|
||||
if kwargs.get("return_unused_kwargs") is True:
|
||||
if kwargs.get("_from_model_config", False):
|
||||
return cls.from_model_config(config_dict)
|
||||
elif kwargs.get("return_unused_kwargs") is True:
|
||||
config, unused_kwargs = cls.from_dict(config_dict, **kwargs)
|
||||
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
||||
return config, unused_kwargs
|
||||
@ -1084,19 +1086,19 @@ class GenerationConfig(PushToHubMixin):
|
||||
writer.write(self.to_json_string(use_diff=use_diff))
|
||||
|
||||
@classmethod
|
||||
def from_model_config(cls, model_config: PreTrainedConfig) -> "GenerationConfig":
|
||||
def from_model_config(cls, model_config: PreTrainedConfig | dict) -> "GenerationConfig":
|
||||
"""
|
||||
Instantiates a [`GenerationConfig`] from a [`PreTrainedConfig`]. This function is useful to convert legacy
|
||||
[`PreTrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
|
||||
|
||||
Args:
|
||||
model_config (`PreTrainedConfig`):
|
||||
model_config (`PreTrainedConfig | dict`):
|
||||
The model config that will be used to instantiate the generation config.
|
||||
|
||||
Returns:
|
||||
[`GenerationConfig`]: The configuration object instantiated from those parameters.
|
||||
"""
|
||||
config_dict = model_config.to_dict()
|
||||
config_dict = model_config.to_dict() if not isinstance(model_config, dict) else model_config
|
||||
config_dict.pop("_from_model_config", None)
|
||||
|
||||
# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
|
||||
@ -1106,14 +1108,15 @@ class GenerationConfig(PushToHubMixin):
|
||||
|
||||
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
||||
# generation config (which in turn is defined from the outer attributes of model config).
|
||||
decoder_config = model_config.get_text_config(decoder=True)
|
||||
if decoder_config is not model_config:
|
||||
default_generation_config = GenerationConfig()
|
||||
decoder_config_dict = decoder_config.to_dict()
|
||||
for attr in generation_config.to_dict():
|
||||
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
|
||||
if attr in decoder_config_dict and is_unset:
|
||||
setattr(generation_config, attr, decoder_config_dict[attr])
|
||||
if not isinstance(model_config, dict):
|
||||
decoder_config = model_config.get_text_config(decoder=True)
|
||||
if decoder_config is not model_config:
|
||||
default_generation_config = GenerationConfig()
|
||||
decoder_config_dict = decoder_config.to_dict()
|
||||
for attr in generation_config.to_dict():
|
||||
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
|
||||
if attr in decoder_config_dict and is_unset:
|
||||
setattr(generation_config, attr, decoder_config_dict[attr])
|
||||
|
||||
# If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
|
||||
if generation_config.return_dict_in_generate is False:
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import deque
|
||||
from math import floor, gcd, sqrt
|
||||
from typing import Optional
|
||||
|
||||
@ -21,8 +20,8 @@ import torch
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...generation.configuration_utils import GenerationConfig
|
||||
from ...utils.metrics import attach_tracer, traced
|
||||
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
|
||||
from .requests import get_device_and_memory_breakdown, logger
|
||||
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
|
||||
from .requests import RequestState, get_device_and_memory_breakdown, logger
|
||||
|
||||
|
||||
def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
|
||||
@ -32,7 +31,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
|
||||
- All groups have the same number of layers
|
||||
|
||||
For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
|
||||
We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
|
||||
We would get four groups: [0, 3], [1, 2], [4,5] and [6,7].
|
||||
"""
|
||||
# If the config has no layer_type attribute, it means all layers are the same attention type
|
||||
layer_types = getattr(config, "layer_types", None)
|
||||
@ -116,7 +115,6 @@ class PagedAttentionCache:
|
||||
for the sliding-attention group, although it is not needed.
|
||||
"""
|
||||
|
||||
# TODO: this init is quite long, maybe a refactor is in order
|
||||
def __init__(
|
||||
self,
|
||||
config: PreTrainedConfig,
|
||||
@ -124,8 +122,10 @@ class PagedAttentionCache:
|
||||
device: torch.device,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
tp_size: Optional[int] = None,
|
||||
allow_prefix_sharing: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a paged attention cache for efficient memory usage.
|
||||
"""Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
|
||||
only full attention layers.
|
||||
|
||||
Args:
|
||||
config: Model configuration
|
||||
@ -133,6 +133,7 @@ class PagedAttentionCache:
|
||||
device: Device for the cache tensors
|
||||
dtype: Data type of the cache
|
||||
tp_size: Tensor parallelism size
|
||||
allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers.
|
||||
"""
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
@ -173,10 +174,12 @@ class PagedAttentionCache:
|
||||
page_size = self.head_dim * self.num_key_value_heads
|
||||
|
||||
if "flash" in self.config._attn_implementation:
|
||||
num_attention_masks = 1 # only used to compute the default meme args
|
||||
else:
|
||||
num_attention_masks = 0 # only used to compute the default memory footprint args
|
||||
elif "sliding_attention" in group_types:
|
||||
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
|
||||
num_attention_masks = 2 if "sliding_attention" in group_types else 1
|
||||
num_attention_masks = 2
|
||||
else:
|
||||
num_attention_masks = 1
|
||||
|
||||
memory_handler = PagedAttentionMemoryHandler(
|
||||
block_size=self.block_size,
|
||||
@ -189,7 +192,9 @@ class PagedAttentionCache:
|
||||
num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
|
||||
num_blocks=getattr(generation_config, "num_blocks", None),
|
||||
max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
|
||||
max_memory_percent=getattr(generation_config, "max_memory", 0.9),
|
||||
max_memory_percent=getattr(
|
||||
generation_config, "max_memory", 0.8
|
||||
), # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
|
||||
cache_dtype=self.dtype,
|
||||
)
|
||||
|
||||
@ -216,7 +221,6 @@ class PagedAttentionCache:
|
||||
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
|
||||
|
||||
# Block management data structures
|
||||
self._free_blocks = deque(range(num_blocks))
|
||||
self.group_cache_managers: list[CacheAllocator] = []
|
||||
for i, group_type in enumerate(group_types):
|
||||
if group_type == "full_attention":
|
||||
@ -227,13 +231,19 @@ class PagedAttentionCache:
|
||||
raise ValueError(f"Invalid group type: {group_type}")
|
||||
self.group_cache_managers.append(cm)
|
||||
|
||||
# We only use prefix sharing if the whole model has only full attention layers
|
||||
self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
|
||||
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
|
||||
self.blocks_to_complete: dict[str, int] = {}
|
||||
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests
|
||||
|
||||
@traced
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
|
||||
def allocate_blocks(self, n_blocks: int, state: RequestState) -> int:
|
||||
"""Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
|
||||
managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
|
||||
max_allocated = 0
|
||||
for cm in self.group_cache_managers:
|
||||
allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
|
||||
allocated = cm.allocate_blocks(n_blocks, state.request_id, self._block_manager)
|
||||
if allocated is None:
|
||||
return None
|
||||
max_allocated = max(max_allocated, allocated)
|
||||
@ -244,11 +254,11 @@ class PagedAttentionCache:
|
||||
"""Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
|
||||
by the cache managers."""
|
||||
for cm in self.group_cache_managers:
|
||||
cm.free_blocks(request_id, self._free_blocks)
|
||||
cm.free_blocks(request_id, self._block_manager)
|
||||
|
||||
def get_num_free_blocks(self) -> int:
|
||||
"""Get the current number of unallocated blocks available for new requests."""
|
||||
return len(self._free_blocks)
|
||||
return self._block_manager.num_free_blocks
|
||||
|
||||
@traced
|
||||
def extend_read_indices(
|
||||
@ -335,6 +345,44 @@ class PagedAttentionCache:
|
||||
# Return the new KV values
|
||||
return key_states_with_cache, value_states_with_cache
|
||||
|
||||
def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
|
||||
"""Searches for a prefix match in the cache for the given (prompts_ids). If one is found, we reference the
|
||||
matching blocks in the (request_id), increase the reference count of the blocks and return the number of blocks
|
||||
that match. If no prefix match is found, we return 0."""
|
||||
current_hash = None
|
||||
allocated_blocks = []
|
||||
for b in range(len(prompt_ids) // self.block_size):
|
||||
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
|
||||
current_hash = self._block_manager.compute_hash(current_hash, tokens)
|
||||
block_id = self._block_manager._hash_to_id.get(current_hash)
|
||||
if block_id is not None:
|
||||
allocated_blocks.append(block_id)
|
||||
self._block_manager.increase_ref_count(block_id)
|
||||
else:
|
||||
break
|
||||
# If we found a matching prefix, we reference the blocks in the request
|
||||
if allocated_blocks:
|
||||
logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
|
||||
cm = self.group_cache_managers[0]
|
||||
cm.block_table[request_id] = allocated_blocks
|
||||
|
||||
prefix_length = len(allocated_blocks) * self.block_size
|
||||
self._total_prefix_length += prefix_length
|
||||
return prefix_length
|
||||
|
||||
def mark_blocks_as_complete(self, state: RequestState) -> None:
|
||||
"""Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is
|
||||
a no-op."""
|
||||
num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
|
||||
if num_complete_blocks == 0:
|
||||
return None
|
||||
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
|
||||
self._block_manager.mark_blocks_as_complete(
|
||||
num_complete_blocks=num_complete_blocks,
|
||||
allocated_blocks=cm.block_table[state.request_id],
|
||||
prompt_ids=(state.full_prompt_ids + state.static_outputs),
|
||||
)
|
||||
|
||||
|
||||
# TODO: rework computation with the groups and their sizes
|
||||
class PagedAttentionMemoryHandler:
|
||||
@ -414,7 +462,7 @@ class PagedAttentionMemoryHandler:
|
||||
self,
|
||||
num_blocks: Optional[int] = None,
|
||||
max_batch_tokens: Optional[int] = None,
|
||||
max_memory_percent: float = 0.9,
|
||||
max_memory_percent: float = 0.8, # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
|
||||
cache_dtype: torch.dtype = torch.float16,
|
||||
) -> tuple[int, int]:
|
||||
"""Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
|
||||
@ -454,7 +502,7 @@ class PagedAttentionMemoryHandler:
|
||||
|
||||
def compute_num_blocks_and_max_batch_tokens(
|
||||
self,
|
||||
max_memory_percent: float = 0.9,
|
||||
max_memory_percent: float,
|
||||
cache_dtype: torch.dtype = torch.float16,
|
||||
m: float = 0.01,
|
||||
) -> tuple[int, int]:
|
||||
@ -469,6 +517,8 @@ class PagedAttentionMemoryHandler:
|
||||
2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
|
||||
m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
|
||||
])
|
||||
|
||||
If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
|
||||
"""
|
||||
cache_memory = self.get_available_memory(max_memory_percent)
|
||||
logger.info(f"Cache memory: {cache_memory}")
|
||||
@ -480,11 +530,16 @@ class PagedAttentionMemoryHandler:
|
||||
c = -cache_memory
|
||||
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
|
||||
|
||||
# Compute discriminant and greatest solution
|
||||
discriminant = b**2 - 4 * a * c
|
||||
if discriminant < 0:
|
||||
raise ValueError(f"Discriminant is negative: {discriminant = }")
|
||||
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
|
||||
# If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
|
||||
if self.num_attention_masks == 0:
|
||||
greatest_solution = -c / b
|
||||
# Otherwise, we solve the quadratic equation
|
||||
else:
|
||||
discriminant = b**2 - 4 * a * c
|
||||
if discriminant < 0:
|
||||
raise ValueError(f"Discriminant is negative: {discriminant = }")
|
||||
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
|
||||
|
||||
if greatest_solution < 0:
|
||||
raise ValueError(f"Greatest solution is negative: {greatest_solution = }")
|
||||
|
||||
@ -503,7 +558,7 @@ class PagedAttentionMemoryHandler:
|
||||
def compute_max_batch_tokens(
|
||||
self,
|
||||
num_blocks: int,
|
||||
max_memory_percent: float = 0.9,
|
||||
max_memory_percent: float,
|
||||
cache_dtype: torch.dtype = torch.float16,
|
||||
) -> int:
|
||||
"""Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
|
||||
@ -531,7 +586,7 @@ class PagedAttentionMemoryHandler:
|
||||
def compute_num_blocks(
|
||||
self,
|
||||
max_batch_tokens: int,
|
||||
max_memory_percent: float = 0.9,
|
||||
max_memory_percent: float,
|
||||
cache_dtype: torch.dtype = torch.float16,
|
||||
) -> int:
|
||||
"""Calculate number of cache blocks N given a fixed maximum token per token M. The formula for N is given by:
|
||||
|
||||
@ -14,29 +14,211 @@
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Iterator
|
||||
from math import ceil
|
||||
from typing import Optional
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from .requests import logger
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
|
||||
index = len(xs) - 1
|
||||
for x in xs[::-1]:
|
||||
yield index, x
|
||||
index -= 1
|
||||
|
||||
|
||||
class Block:
|
||||
"""A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
|
||||
cache it points to is fully computed. A block can have a parent, which is the block that came before in the
|
||||
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block and
|
||||
its parent's hash (if there is a parent)."""
|
||||
|
||||
def __init__(self, id_: int, parent_id: int | None) -> None:
|
||||
self.id: int = id_
|
||||
self.parent_id: int | None = parent_id
|
||||
self.hash: int | None = None
|
||||
self.ref_count: int = 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
return self.hash is not None
|
||||
|
||||
|
||||
class BlockManager:
|
||||
"""A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
|
||||
simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
|
||||
- in use: one or more requests references this block, thus it cannot be written over. The number of requests
|
||||
referencing this block is stored as ref_count in the Block object.
|
||||
- un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
|
||||
be given as free blocks to new requests without any overhead.
|
||||
- initialized: the block is complete and was used by one or more request that are finished. It contains KV cache
|
||||
data and its hash is stored in the hash table. If a new request needs a block with the same hash, we increase
|
||||
the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
|
||||
Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
|
||||
hash table.
|
||||
There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
|
||||
it is in use.
|
||||
"""
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
|
||||
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
|
||||
can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
|
||||
layers."""
|
||||
self.num_blocks = num_blocks
|
||||
self.block_size = block_size
|
||||
self._uninit_block_ids = deque(range(num_blocks))
|
||||
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
|
||||
self._use_prefix_sharing = use_prefix_sharing
|
||||
self._hash_to_id: dict[int, int] = {}
|
||||
self._id_to_block: dict[int, Block] = {}
|
||||
|
||||
@property
|
||||
def num_free_blocks(self) -> int:
|
||||
"""Returns the number of free blocks left. Both initialized and uninitialized blocks are considered free."""
|
||||
return len(self._uninit_block_ids) + len(self._init_block_ids)
|
||||
|
||||
def has_enough_free_blocks(self, n_blocks: int) -> bool:
|
||||
"""Checks if there are enough free blocks to allocate the requested number of blocks (n_blocks). If there are
|
||||
not enough uninitialized blocks, we uninitialize the required number of initialized blocks."""
|
||||
# Exit early if there are enough uninitialized blocks
|
||||
if len(self._uninit_block_ids) >= n_blocks:
|
||||
return True
|
||||
# Exit early if even after uninitializing all initialized blocks, there are not enough free blocks
|
||||
block_to_unintialize = n_blocks - len(self._uninit_block_ids)
|
||||
if len(self._init_block_ids) < block_to_unintialize:
|
||||
return False
|
||||
# Uninitialize the required amount of blocks
|
||||
for _ in range(block_to_unintialize):
|
||||
id_to_unintialize = self._init_block_ids.popitem()[0]
|
||||
block = self._id_to_block[id_to_unintialize]
|
||||
self._hash_to_id.pop(block.hash)
|
||||
self._uninit_block_ids.append(id_to_unintialize)
|
||||
return True
|
||||
|
||||
def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
|
||||
"""Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. One
|
||||
can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
|
||||
the parent block. If the manager cannot find enough free blocks, it returns None."""
|
||||
if not self.has_enough_free_blocks(n_blocks):
|
||||
return None
|
||||
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
|
||||
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
|
||||
if self._use_prefix_sharing:
|
||||
for block_id in allocated_block_ids:
|
||||
block = Block(block_id, last_block_id)
|
||||
self._id_to_block[block_id] = block
|
||||
last_block_id = block_id
|
||||
# In both cases, we return the allocated block ids
|
||||
return allocated_block_ids
|
||||
|
||||
def increase_ref_count(self, block_id: int) -> None:
|
||||
"""Increases the reference count of a given (block_id)."""
|
||||
block = self._id_to_block[block_id]
|
||||
block.ref_count += 1
|
||||
if block.ref_count == 1:
|
||||
self._init_block_ids.pop(block_id)
|
||||
|
||||
def decrease_ref_count(self, block_id: int) -> None:
|
||||
"""Decreases the reference count of a given (block_id). If the reference count reaches 0, the block is no longer
|
||||
in use, and becomes initialized (if it was complete) or uninitialized (if it was incomplete)."""
|
||||
block = self._id_to_block[block_id]
|
||||
block.ref_count -= 1
|
||||
if block.ref_count == 0:
|
||||
if block.is_complete:
|
||||
self._init_block_ids[block_id] = None
|
||||
else:
|
||||
self._id_to_block.pop(block_id)
|
||||
self._uninit_block_ids.append(block_id)
|
||||
|
||||
def free_blocks(self, blocks: list[int]) -> None:
|
||||
"""Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized
|
||||
blocks queue. Otherwise, their new state depends on whether they are complete."""
|
||||
if self._use_prefix_sharing:
|
||||
for block_id in blocks:
|
||||
self.decrease_ref_count(block_id)
|
||||
else:
|
||||
self._uninit_block_ids.extend(blocks)
|
||||
|
||||
def mark_blocks_as_complete(
|
||||
self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int]
|
||||
) -> None:
|
||||
"""Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
|
||||
of (prompt_ids) is used to compute the hash of the new block."""
|
||||
# Look for the first complete block, starting from the last block in the sequence
|
||||
parent_hash = None
|
||||
incomplete_blocks: list[Block] = []
|
||||
for i, block_id in reverse_enumerate(allocated_blocks):
|
||||
block = self._id_to_block[block_id]
|
||||
if block.is_complete:
|
||||
parent_hash = block.hash
|
||||
break
|
||||
incomplete_blocks.append((i, block))
|
||||
|
||||
# Now go through the incomplete blocks and updated them
|
||||
new_parent_id = None
|
||||
while incomplete_blocks:
|
||||
i, block = incomplete_blocks.pop()
|
||||
|
||||
# If the parent id has been updated, we apply the change
|
||||
if new_parent_id is not None:
|
||||
block.parent_id = new_parent_id
|
||||
new_parent_id = None
|
||||
|
||||
# If we have set the hash for all complete blocks, we can stop
|
||||
if num_complete_blocks == 0:
|
||||
break
|
||||
|
||||
# Otherwise, we compute the hash
|
||||
num_complete_blocks -= 1
|
||||
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
|
||||
block.hash = self.compute_hash(parent_hash, tokens)
|
||||
|
||||
existing_block_id = self._hash_to_id.get(block.hash)
|
||||
# If the block hash is already in the hash to id mapping, we reference the existing block instead
|
||||
if existing_block_id is not None:
|
||||
logger.debug(f"Found existing block {existing_block_id} for block {block.id}")
|
||||
allocated_blocks[i] = existing_block_id
|
||||
self._id_to_block[existing_block_id].ref_count += 1
|
||||
new_parent_id = existing_block_id
|
||||
self.free_blocks([block.id])
|
||||
|
||||
# Otherwise, we add the completed block to the hash table
|
||||
else:
|
||||
self._hash_to_id[block.hash] = block.id
|
||||
|
||||
# Update loop variables
|
||||
parent_hash = block.hash
|
||||
|
||||
def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
|
||||
"""Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
|
||||
parent, the parent hash is None."""
|
||||
return hash((parent_hash, tuple(tokens)))
|
||||
|
||||
|
||||
class CacheAllocator(ABC):
|
||||
"""Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
|
||||
when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""
|
||||
|
||||
_index: int
|
||||
_block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
|
||||
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
|
||||
|
||||
@abstractmethod
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
"""Allocates n_blocks for a given request_id. Returns the num of blocks allocated if successful and None
|
||||
otherwise."""
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocates (n_blocks) for a given (request_id) using the (block_manager). Returns the num of blocks allocated
|
||||
if successful and None otherwise."""
|
||||
|
||||
def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
|
||||
"""Frees all blocks associated with a request_id."""
|
||||
if request_id in self._block_table:
|
||||
blocks_to_free = self._block_table.pop(request_id)
|
||||
free_blocks.extend(blocks_to_free)
|
||||
def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
|
||||
"""Frees all blocks associated with a (request_id) using the (block_manager)."""
|
||||
if request_id in self.block_table:
|
||||
blocks_to_free = self.block_table.pop(request_id)
|
||||
block_manager.free_blocks(blocks_to_free)
|
||||
else:
|
||||
logger.warning(
|
||||
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
|
||||
@ -66,23 +248,30 @@ class FullAttentionCacheAllocator(CacheAllocator):
|
||||
"""
|
||||
self._index = index
|
||||
self.block_size = block_size
|
||||
self._block_table = {}
|
||||
self.block_table = {}
|
||||
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
|
||||
otherwise. For group of full attention layers, we always allocate the number of requested blocks."""
|
||||
if len(free_blocks) < n_blocks:
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
|
||||
allocated if successful and None otherwise. For group of full attention layers, we always allocate the number of
|
||||
requested blocks."""
|
||||
# Make sure the request_id is in the block table and get the first block id
|
||||
if request_id not in self.block_table:
|
||||
self.block_table[request_id] = [] # TODO: check the impact of making this a deque
|
||||
last_block_id = None
|
||||
else:
|
||||
last_block_id = self.block_table[request_id][-1]
|
||||
# Actual allocation, return early if failed
|
||||
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
|
||||
if allocated_blocks is None:
|
||||
return None
|
||||
if request_id not in self._block_table:
|
||||
self._block_table[request_id] = []
|
||||
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
|
||||
self.block_table[request_id].extend(allocated_blocks)
|
||||
return n_blocks
|
||||
|
||||
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
"""Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
|
||||
first write the new cache to the cache tensor and then read the entire cache from the beginning to the end."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Compute the physical indices
|
||||
@ -97,7 +286,7 @@ class FullAttentionCacheAllocator(CacheAllocator):
|
||||
def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
"""Returns the physical indices for writing to the cache. For a group of full attention layers, we write the new
|
||||
cache as a continuation of the existing cache for the same request."""
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Compute the physical indices
|
||||
@ -129,25 +318,26 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
self.block_size = block_size
|
||||
self.sliding_window = sliding_window
|
||||
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
|
||||
self._block_table = {}
|
||||
self.block_table = {}
|
||||
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
|
||||
otherwise. For group of sliding window attention layers, we only allocate up to the point where we can fit an
|
||||
entire sliding window in the cache tensor."""
|
||||
if request_id not in self._block_table:
|
||||
self._block_table[request_id] = []
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
|
||||
allocated otherwise. For group of sliding window attention layers, we only allocate up to the point where we can
|
||||
fit an entire sliding window in the cache tensor."""
|
||||
if request_id not in self.block_table:
|
||||
self.block_table[request_id] = []
|
||||
# Early return if we are already at the max number of blocks per request
|
||||
already_allocated = len(self._block_table[request_id])
|
||||
already_allocated = len(self.block_table[request_id])
|
||||
if already_allocated == self._max_blocks_per_request:
|
||||
return 0
|
||||
# Compute actual number of blocks to allocate
|
||||
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
|
||||
actual_n_blocks = after_allocation - already_allocated
|
||||
# Classic allocation
|
||||
if len(free_blocks) < actual_n_blocks:
|
||||
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
|
||||
if allocated_blocks is None:
|
||||
return None
|
||||
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
|
||||
self.block_table[request_id].extend(allocated_blocks)
|
||||
return actual_n_blocks
|
||||
|
||||
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
@ -157,7 +347,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
sliding_window - 1 cache page and then manually add the new key / values states after. Hence the -1 indices
|
||||
which indicate where to store the new key or values indices."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Apply sliding window
|
||||
@ -178,7 +368,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
sliding window attention layers, we write the new cache in rolling-buffer kind of way: if we reach the end of
|
||||
the allocated physical cache, we start writing from the beginning of the physical cache again."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Apply sliding window
|
||||
@ -201,22 +391,3 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
|
||||
seqlens_k = query_length + min(past_length, self.sliding_window - 1)
|
||||
return "sliding_attention", seqlens_k
|
||||
|
||||
|
||||
# TODO: test the impact of this
|
||||
# def get_read_indices(self, request_id: str, past_length: int) -> list[int]:
|
||||
# # Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
# block_table = self._block_table.get(request_id)
|
||||
# if block_table is None:
|
||||
# raise ValueError(f"No block table found for request {request_id}")
|
||||
# # Compute the physical indices
|
||||
# physical_indices = []
|
||||
# n_left = past_length
|
||||
# for block_idx in block_table:
|
||||
# block_physical_index = block_idx * self.block_size
|
||||
# pages_used = min(self.block_size, n_left)
|
||||
# physical_indices.extend(block_physical_index + i for i in range(pages_used))
|
||||
# n_left -= pages_used
|
||||
# if n_left == 0:
|
||||
# return physical_indices
|
||||
# raise ValueError(f"Request {request_id} required too many indices: {past_length = } and {len(block_table) = }")
|
||||
|
||||
@ -16,12 +16,13 @@
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from itertools import count
|
||||
from math import ceil
|
||||
from time import perf_counter
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -446,10 +447,7 @@ class ContinuousBatchProcessor:
|
||||
cumulative_seqlens_q = [0]
|
||||
logits_indices = []
|
||||
|
||||
if isinstance(self.cumulative_seqlens_k, dict):
|
||||
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
|
||||
else:
|
||||
cumulative_seqlens_k = [0]
|
||||
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
|
||||
|
||||
read_index = [[] for _ in range(self.cache.num_groups)]
|
||||
write_index = [[] for _ in range(self.cache.num_groups)]
|
||||
@ -498,10 +496,7 @@ class ContinuousBatchProcessor:
|
||||
self.metrics.record_kv_cache_memory_metrics(self.cache)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if isinstance(self.cumulative_seqlens_k, dict):
|
||||
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
|
||||
else:
|
||||
ck = cumulative_seqlens_k[-1]
|
||||
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
|
||||
logger.debug(
|
||||
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
|
||||
f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
|
||||
@ -517,7 +512,7 @@ class ContinuousBatchProcessor:
|
||||
read_index: list[list[int]],
|
||||
write_index: list[list[int]],
|
||||
cumulative_seqlens_q: list[int],
|
||||
cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
|
||||
cumulative_seqlens_k: dict[str, list[int]],
|
||||
logits_indices: list[int],
|
||||
) -> None:
|
||||
"""Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
|
||||
@ -561,9 +556,7 @@ class ContinuousBatchProcessor:
|
||||
@traced
|
||||
def _maybe_send_output(self, state: RequestState) -> None:
|
||||
"""Send output to the queue based on streaming mode and request state."""
|
||||
if state.streaming:
|
||||
self.output_queue.put(state.to_generation_output())
|
||||
elif state.status == RequestStatus.FINISHED:
|
||||
if state.streaming or state.status == RequestStatus.FINISHED:
|
||||
self.output_queue.put(state.to_generation_output())
|
||||
|
||||
@traced
|
||||
@ -571,17 +564,27 @@ class ContinuousBatchProcessor:
|
||||
"""Update request states based on generated tokens."""
|
||||
out_tokens = self._sync()
|
||||
for i, state in enumerate(self.requests_in_batch):
|
||||
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
|
||||
if len(state.remaining_prompt_ids) == 0:
|
||||
self.metrics.record_ttft_metric(state.created_time, state.request_id)
|
||||
state.status = RequestStatus.DECODING
|
||||
token = out_tokens[self.logits_indices[i]]
|
||||
state.prompt_ids = [token]
|
||||
if state.update_with_token(token):
|
||||
# Update the request and stop if it is complete
|
||||
is_finished = state.update_and_check_completion(token)
|
||||
# We mark the completed blocks as such
|
||||
self.cache.mark_blocks_as_complete(state)
|
||||
if is_finished:
|
||||
self.metrics.record_request_completion(state.created_time, state.request_id)
|
||||
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
|
||||
self._maybe_send_output(state)
|
||||
# Otherwise, the request is still prefilling, but the prefill has been split
|
||||
elif state.status == RequestStatus.PREFILLING_SPLIT:
|
||||
self.cache.mark_blocks_as_complete(state)
|
||||
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
||||
else:
|
||||
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")
|
||||
|
||||
if self.cache.get_num_free_blocks() == 0:
|
||||
raise ValueError("No more free blocks")
|
||||
|
||||
@ -726,6 +729,7 @@ class ContinuousBatchingManager:
|
||||
max_queue_size: int = 0,
|
||||
num_q_cuda_graphs: int = 0,
|
||||
num_kv_cuda_graphs: int = 0,
|
||||
allow_prefix_sharing: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the continuous batching manager.
|
||||
|
||||
@ -735,6 +739,7 @@ class ContinuousBatchingManager:
|
||||
max_queue_size: Maximum size of the request queue (0 = unlimited)
|
||||
num_q_cuda_graphs: (optional) Number of CUDA graphs to use for the query dimension
|
||||
num_kv_cuda_graphs: (optional) Number of CUDA graphs to use for the keys/values dimension
|
||||
allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
|
||||
"""
|
||||
if "paged|" not in model.config._attn_implementation:
|
||||
attn_implementation = f"paged|{model.config._attn_implementation}"
|
||||
@ -767,6 +772,8 @@ class ContinuousBatchingManager:
|
||||
self.manual_eviction = manual_eviction
|
||||
self.batch_processor: Optional[ContinuousBatchProcessor] = None
|
||||
|
||||
self._allow_prefix_sharing = allow_prefix_sharing
|
||||
|
||||
# If a number of cuda graphs was specified for either Q or KV, we activate cuda graphs
|
||||
if num_q_cuda_graphs > 0 or num_kv_cuda_graphs > 0:
|
||||
self.use_cuda_graph = True
|
||||
@ -799,7 +806,6 @@ class ContinuousBatchingManager:
|
||||
logger.warning("Manager thread is already running.")
|
||||
return
|
||||
|
||||
self._result_queue = queue.Queue()
|
||||
self._generation_thread = threading.Thread(target=self._run_generation_loop)
|
||||
self._generation_thread.start()
|
||||
|
||||
@ -807,25 +813,38 @@ class ContinuousBatchingManager:
|
||||
"""Check if the background generation thread is running."""
|
||||
return self._generation_thread is not None and self._generation_thread.is_alive()
|
||||
|
||||
def stop(self, block: bool = False, timeout: Optional[float] = None) -> None:
|
||||
def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
|
||||
"""Signal the background thread to stop.
|
||||
|
||||
Args:
|
||||
block: Whether to wait for the thread to stop
|
||||
timeout: Maximum time to wait for the thread to stop
|
||||
"""
|
||||
if self.batch_processor is None:
|
||||
logger.warning("\nBatch processor was not initialized.")
|
||||
else:
|
||||
if self.batch_processor.cache.use_prefix_sharing:
|
||||
logger.warning(
|
||||
f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
|
||||
)
|
||||
else:
|
||||
logger.warning("\nPrefix sharing was off.")
|
||||
|
||||
if self._generation_thread is None:
|
||||
logger.warning("Manager not started.")
|
||||
return
|
||||
|
||||
stop_trigger_time = perf_counter()
|
||||
if not self.stop_event.is_set():
|
||||
self.stop_event.set()
|
||||
logger.info("Stopping continuous batching manager...")
|
||||
|
||||
if block:
|
||||
self.join(timeout)
|
||||
self.join(stop_trigger_time, timeout)
|
||||
|
||||
def join(self, timeout: Optional[float] = None) -> None:
|
||||
self.batch_processor = None
|
||||
|
||||
def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
|
||||
"""Wait for the background thread to finish.
|
||||
|
||||
Args:
|
||||
@ -834,9 +853,10 @@ class ContinuousBatchingManager:
|
||||
if self._generation_thread is not None:
|
||||
self._generation_thread.join(timeout=timeout)
|
||||
if self._generation_thread.is_alive():
|
||||
logger.warning("Generation thread did not exit after join timeout.")
|
||||
logger.warning(f"Generation thread did not exit after join timeout ({timeout}).")
|
||||
else:
|
||||
logger.info("Continuous Batching Manager stopped.")
|
||||
end = perf_counter()
|
||||
logger.info(f"Continuous Batching Manager stopped after {end - stop_trigger_time:.2f}s.")
|
||||
self._generation_thread = None
|
||||
|
||||
def add_request(
|
||||
@ -877,9 +897,11 @@ class ContinuousBatchingManager:
|
||||
self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg?
|
||||
return request_id
|
||||
|
||||
def add_requests(self, inputs: list[list[int]], max_new_tokens: Optional[int] = None) -> None:
|
||||
def add_requests(
|
||||
self, inputs: list[list[int]], max_new_tokens: Optional[int] = None, streaming: bool = False
|
||||
) -> None:
|
||||
for input_ids in inputs:
|
||||
self.add_request(input_ids, max_new_tokens=max_new_tokens)
|
||||
self.add_request(input_ids, max_new_tokens=max_new_tokens, streaming=streaming)
|
||||
|
||||
def cancel_request(self, request_id: str) -> None:
|
||||
"""Cancel a request by its ID.
|
||||
@ -890,6 +912,7 @@ class ContinuousBatchingManager:
|
||||
if self.batch_processor is not None:
|
||||
self.batch_processor.scheduler.set_request_cancellation(request_id)
|
||||
|
||||
# TODO:handle benchmarking properly when updating / fixing the requeue logic
|
||||
def get_result(
|
||||
self, request_id: Optional[str] = None, timeout: Optional[float] = None
|
||||
) -> Optional[GenerationOutput]:
|
||||
@ -905,6 +928,7 @@ class ContinuousBatchingManager:
|
||||
return None
|
||||
try:
|
||||
result = self.output_queue.get(block=True, timeout=timeout)
|
||||
# NOTE: requeue logic here
|
||||
if request_id is not None and result.request_id != request_id:
|
||||
self.output_queue.put(result)
|
||||
return None
|
||||
@ -931,20 +955,6 @@ class ContinuousBatchingManager:
|
||||
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)
|
||||
|
||||
@traced
|
||||
def warmup(self, batch_processor: ContinuousBatchProcessor) -> None:
|
||||
stream = torch.cuda.Stream(device=self.model.device)
|
||||
stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(stream):
|
||||
# Warmup the model with a dummy forward pass
|
||||
self._generation_step(batch_processor)
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph, stream=stream):
|
||||
self._generation_step(batch_processor)
|
||||
|
||||
@traced
|
||||
# @torch.compile
|
||||
def _generation_step(self) -> None:
|
||||
"""Perform a single generation step. This is cuda graphed"""
|
||||
self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)
|
||||
@ -960,6 +970,7 @@ class ContinuousBatchingManager:
|
||||
self.model.device,
|
||||
self.model.dtype,
|
||||
tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
|
||||
allow_prefix_sharing=self._allow_prefix_sharing,
|
||||
)
|
||||
logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")
|
||||
|
||||
@ -1051,6 +1062,15 @@ class ContinuousBatchingManager:
|
||||
class ContinuousMixin:
|
||||
"""Mixin class for models to add continuous batching capabilities."""
|
||||
|
||||
@contextmanager
|
||||
def continuous_batching_context_manager(self, **kwargs) -> Generator[ContinuousBatchingManager]:
|
||||
manager = self.init_continuous_batching(**kwargs)
|
||||
manager.start()
|
||||
try:
|
||||
yield manager
|
||||
finally:
|
||||
manager.stop(block=True)
|
||||
|
||||
def init_continuous_batching(
|
||||
self,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
@ -1058,6 +1078,7 @@ class ContinuousMixin:
|
||||
max_queue_size: int = 0,
|
||||
num_q_cuda_graphs: int = 0,
|
||||
num_kv_cuda_graphs: int = 0,
|
||||
allow_prefix_sharing: bool = True,
|
||||
) -> ContinuousBatchingManager:
|
||||
"""Initialize a manager for continuous batching inference.
|
||||
|
||||
@ -1090,8 +1111,10 @@ class ContinuousMixin:
|
||||
max_queue_size=max_queue_size,
|
||||
num_q_cuda_graphs=num_q_cuda_graphs,
|
||||
num_kv_cuda_graphs=num_kv_cuda_graphs,
|
||||
allow_prefix_sharing=allow_prefix_sharing,
|
||||
)
|
||||
|
||||
# TODO: support streaming
|
||||
@traced
|
||||
@torch.inference_mode()
|
||||
def generate_batch(
|
||||
@ -1148,7 +1171,7 @@ class ContinuousMixin:
|
||||
result = manager.get_result(timeout=1)
|
||||
if result:
|
||||
req_id = result.request_id
|
||||
if result.status == RequestStatus.FINISHED:
|
||||
if result.is_finished():
|
||||
results[req_id] = result
|
||||
finished_count += 1
|
||||
pbar.update(1)
|
||||
@ -1160,5 +1183,6 @@ class ContinuousMixin:
|
||||
except Exception as e:
|
||||
logger.error(f"Error during batch generation: {e}", exc_info=True)
|
||||
finally:
|
||||
logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show.
|
||||
manager.stop(block=True, timeout=5.0)
|
||||
return results
|
||||
|
||||
@ -19,6 +19,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import is_torch_xpu_available
|
||||
from ...utils.logging import logging
|
||||
from ...utils.metrics import traced
|
||||
|
||||
@ -35,6 +36,13 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
|
||||
total_memory = torch.cuda.get_device_properties(device).total_memory
|
||||
reserved_memory = torch.cuda.memory_reserved(device)
|
||||
allocated_memory = torch.cuda.memory_allocated(device)
|
||||
elif is_torch_xpu_available():
|
||||
device = torch.device("xpu")
|
||||
torch.xpu.empty_cache()
|
||||
torch.xpu.synchronize()
|
||||
total_memory = torch.xpu.get_device_properties(device).total_memory
|
||||
reserved_memory = torch.xpu.memory_reserved(device)
|
||||
allocated_memory = torch.xpu.memory_allocated(device)
|
||||
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
device = torch.device("mps")
|
||||
# MPS memory reporting (PyTorch 2.0+)
|
||||
@ -83,6 +91,9 @@ class GenerationOutput:
|
||||
status: RequestStatus = RequestStatus.PENDING
|
||||
created_time: float = field(default_factory=time.time)
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return self.status == RequestStatus.FINISHED
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestState:
|
||||
@ -105,10 +116,10 @@ class RequestState:
|
||||
error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
|
||||
"""
|
||||
|
||||
# Required fields
|
||||
# Required fields # TODO: come up with better names / not sure prompt_ids and such are not redundant
|
||||
request_id: str
|
||||
full_prompt_ids: Optional[list[int]] = None # Full initial prompt
|
||||
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated)
|
||||
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed
|
||||
remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
|
||||
static_outputs: list[int] = field(default_factory=list) # Generated tokens
|
||||
allocated_blocks: int = 0 # Number of blocks allocated to the request
|
||||
@ -153,7 +164,7 @@ class RequestState:
|
||||
|
||||
# TODO: this logic seems one token off, check it out
|
||||
@traced
|
||||
def update_with_token(self, token_id: int) -> bool:
|
||||
def update_and_check_completion(self, token_id: int) -> bool:
|
||||
"""Update the request with a newly generated token and check for completion.
|
||||
|
||||
Args:
|
||||
|
||||
@ -104,7 +104,7 @@ class Scheduler(ABC):
|
||||
)
|
||||
|
||||
@traced
|
||||
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
|
||||
def _allocate_blocks_if_needed(self, state: RequestState) -> bool:
|
||||
"""Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
|
||||
accommodate the next tokens. It calculates how many blocks are needed based on the request's current
|
||||
cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
|
||||
@ -113,10 +113,11 @@ class Scheduler(ABC):
|
||||
# 1. we check that the occupancy is less than the requested length
|
||||
# 2. we allocate enough blocks to cover the requested length
|
||||
current_len = state.current_len()
|
||||
len_next_tokens = len(state.prompt_ids)
|
||||
occupancy = state.allocated_blocks * self.cache.block_size - current_len
|
||||
if occupancy < len_next_tokens or state.allocated_blocks == 0:
|
||||
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
|
||||
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
|
||||
allocated = self.cache.allocate_blocks(blocks_needed, state)
|
||||
if allocated is None:
|
||||
return False
|
||||
state.allocated_blocks += allocated
|
||||
@ -125,11 +126,29 @@ class Scheduler(ABC):
|
||||
@traced(span_name="prepare_request")
|
||||
def _prepare_request_for_processing(
|
||||
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
|
||||
):
|
||||
"""Prepares a request for processing in the current batch."""
|
||||
request_tokens = (
|
||||
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
|
||||
)
|
||||
) -> None:
|
||||
"""Prepares a request for processing in the current batch. If prefix sharing is enabled, and the request was
|
||||
pending, this is where we look for a prefix match and split the request if found."""
|
||||
# If prefix sharing is enabled, we look for a prefix match and split the request if found
|
||||
if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
|
||||
prefill_length = self.cache.search_prefix_match(state.request_id, state.prompt_ids)
|
||||
if prefill_length > 0:
|
||||
self.active_requests[state.request_id] = state
|
||||
request_ids_to_remove_from_waiting.add(state.request_id)
|
||||
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
||||
# Even if we match the whole request, we keep at least 1 token to start decoding
|
||||
prefill_length = min(prefill_length, len(state.prompt_ids) - 1)
|
||||
state.remaining_prompt_ids = state.prompt_ids[prefill_length:]
|
||||
state.prompt_ids = state.prompt_ids[prefill_length:]
|
||||
state.position_offset += prefill_length
|
||||
|
||||
# If the request has a split prefill, the tokens to process are the remaining prompt ids
|
||||
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
|
||||
request_tokens = state.remaining_prompt_ids
|
||||
# Otherwise, the tokens to process are the prompt ids, which are the full prompt or the last predicted tokens
|
||||
else:
|
||||
request_tokens = state.prompt_ids
|
||||
|
||||
if len(request_tokens) < token_budget:
|
||||
# Can process the entire prompt/remainder
|
||||
if state.status == RequestStatus.PENDING:
|
||||
@ -152,6 +171,7 @@ class Scheduler(ABC):
|
||||
state.prompt_ids = request_tokens[:token_budget]
|
||||
|
||||
|
||||
# TODO: further common-ize the two classes
|
||||
@attach_tracer()
|
||||
class FIFOScheduler(Scheduler):
|
||||
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
|
||||
@ -195,30 +215,31 @@ class FIFOScheduler(Scheduler):
|
||||
|
||||
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
|
||||
request_len = len(state.prompt_ids)
|
||||
if not self._allocate_blocks_if_needed(
|
||||
state, len(state.prompt_ids)
|
||||
): # don't schedule if we can't allocate blocks
|
||||
if len(self.cache._free_blocks) == 0:
|
||||
# If we can't allocate blocks, do not schedule the request and break if the cache is full
|
||||
if not self._allocate_blocks_if_needed(state):
|
||||
if self.cache.get_num_free_blocks() == 0:
|
||||
break
|
||||
continue
|
||||
|
||||
@traced
|
||||
def _add_to_scheduled_requests(state: RequestState):
|
||||
scheduled_requests.append(state)
|
||||
|
||||
_add_to_scheduled_requests(state)
|
||||
# Add the request to the scheduled requests
|
||||
scheduled_requests.append(state)
|
||||
|
||||
# Update the token budget
|
||||
token_budget -= request_len
|
||||
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
|
||||
if self.cache.use_prefix_sharing:
|
||||
tokens_in_current_block = state.current_len() % self.cache.block_size
|
||||
tokens_after_forward = tokens_in_current_block + request_len
|
||||
complete_blocks = tokens_after_forward // self.cache.block_size
|
||||
self.cache.blocks_to_complete[state.request_id] = complete_blocks
|
||||
|
||||
@traced
|
||||
def _remove_from_waiting_requests(state: RequestState):
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
_remove_from_waiting_requests(state)
|
||||
# Remove the request from the waiting queue and mark it as removed
|
||||
req_id = state.request_id
|
||||
was_waiting = self.waiting_requests.pop(req_id, None) is not None
|
||||
if was_waiting:
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
# Early exit of the loop if we have no token budget left
|
||||
if token_budget == 0:
|
||||
break
|
||||
|
||||
@ -249,6 +270,7 @@ class PrefillFirstScheduler(Scheduler):
|
||||
elif state.status == RequestStatus.DECODING:
|
||||
second_priority_states.append(state)
|
||||
|
||||
# Add waiting requests to second priority
|
||||
for req_id in self.waiting_requests_order:
|
||||
second_priority_states.append(self.waiting_requests[req_id])
|
||||
|
||||
@ -259,30 +281,31 @@ class PrefillFirstScheduler(Scheduler):
|
||||
for state in candidates:
|
||||
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
|
||||
request_len = len(state.prompt_ids)
|
||||
if not self._allocate_blocks_if_needed(
|
||||
state, len(state.prompt_ids)
|
||||
): # don't schedule if we can't allocate blocks
|
||||
if len(self.cache._free_blocks) == 0:
|
||||
# If we can't allocate blocks, do not schedule the request and break if the cache is full
|
||||
if not self._allocate_blocks_if_needed(state):
|
||||
if self.cache.get_num_free_blocks() == 0:
|
||||
break
|
||||
continue
|
||||
|
||||
@traced
|
||||
def _add_to_scheduled_requests(state: RequestState):
|
||||
scheduled_requests.append(state)
|
||||
|
||||
_add_to_scheduled_requests(state)
|
||||
# Add the request to the scheduled requests
|
||||
scheduled_requests.append(state)
|
||||
|
||||
# Update the token budget
|
||||
token_budget -= request_len
|
||||
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
|
||||
if self.cache.use_prefix_sharing:
|
||||
tokens_in_current_block = state.current_len() % self.cache.block_size
|
||||
tokens_after_forward = tokens_in_current_block + request_len
|
||||
complete_blocks = tokens_after_forward // self.cache.block_size
|
||||
self.cache.blocks_to_complete[state.request_id] = complete_blocks
|
||||
|
||||
@traced
|
||||
def _remove_from_waiting_requests(state: RequestState):
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
_remove_from_waiting_requests(state)
|
||||
# Remove the request from the waiting queue and mark it as removed
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
# Early exit of the loop if we have no token budget left
|
||||
if token_budget == 0:
|
||||
break
|
||||
|
||||
|
||||
@ -410,8 +410,16 @@ class GenerationMixin(ContinuousMixin):
|
||||
logger.info(
|
||||
"Generation config file not found, using a generation config created from the model config."
|
||||
)
|
||||
self.generation_config = GenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
config_file_name="config.json",
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
_from_model_config=True,
|
||||
**repo_loading_kwargs,
|
||||
)
|
||||
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
|
||||
if hasattr(self, "load_custom_generate"):
|
||||
if hasattr(self, "load_custom_generate") and trust_remote_code:
|
||||
try:
|
||||
custom_generate = self.load_custom_generate(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
|
||||
@ -608,7 +616,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
use_cache = kwargs.get("use_cache")
|
||||
if use_cache is None:
|
||||
use_cache = getattr(self.config, "use_cache", False)
|
||||
if past_key_values is None or use_cache:
|
||||
if past_key_values is not None or use_cache:
|
||||
# TODO (joao): handle the case where cache length == input_ids length. The function below results in an
|
||||
# exception because we get empty input_ids after slicing. In essence, we need to roll back the cache 1
|
||||
# token to recompute the logits for the first token to be generated (but not all caches support roll backs)
|
||||
@ -1635,7 +1643,12 @@ class GenerationMixin(ContinuousMixin):
|
||||
|
||||
# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
|
||||
if (
|
||||
value is not None
|
||||
and key not in model_args
|
||||
and key not in TransformersKwargs.__optional_keys__
|
||||
and key != "debug_io"
|
||||
):
|
||||
unused_model_args.append(key)
|
||||
|
||||
if unused_model_args:
|
||||
@ -1773,14 +1786,12 @@ class GenerationMixin(ContinuousMixin):
|
||||
):
|
||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||
if new_generation_config != self.generation_config: # 4)
|
||||
warnings.warn(
|
||||
"You have modified the pretrained model configuration to control generation. This is a"
|
||||
" deprecated strategy to control generation and will be removed in v5."
|
||||
raise ValueError(
|
||||
"You have modified the pretrained model configuration to control generation."
|
||||
" This strategy to control generation is not supported anymore. "
|
||||
" Please use and modify the model generation configuration (see"
|
||||
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
|
||||
UserWarning,
|
||||
)
|
||||
self.generation_config = new_generation_config
|
||||
|
||||
generation_config = self.generation_config
|
||||
using_model_generation_config = True
|
||||
@ -2170,7 +2181,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
return False
|
||||
|
||||
# Base logic
|
||||
valid_hardware = self.device.type == "cuda" or bool(
|
||||
valid_hardware = self.device.type in ["cuda", "xpu"] or bool(
|
||||
generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
|
||||
)
|
||||
using_compilable_cache = (
|
||||
|
||||
@ -23,6 +23,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import BCELoss
|
||||
|
||||
from .. import initialization as init
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from ..utils import ModelOutput, logging
|
||||
from .configuration_utils import PreTrainedConfig, WatermarkingConfig
|
||||
@ -383,10 +384,11 @@ class BayesianDetectorModel(PreTrainedModel):
|
||||
)
|
||||
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Parameter):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def _compute_posterior(
|
||||
self,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user