Mirror of https://github.com/huggingface/transformers.git
Synced 2025-11-15 07:04:49 +08:00

Compare commits: vb/fix-art ... flaky_gene (55 commits)
| SHA1 |
|---|
| 8af4ddf5ad |
| 5c6d6bed4d |
| 80134e6e66 |
| ce40ca0d4c |
| 6408d3b01a |
| f40ef03214 |
| 5150dac727 |
| 27c3807991 |
| ffb35fe142 |
| 1fd63dd532 |
| 240d19f4a3 |
| ba938fa590 |
| 6744ebe745 |
| 1709ed96e4 |
| fd36275be2 |
| 922e85487b |
| f9e668abf3 |
| 7951105d69 |
| 58a3f8caac |
| fcea1e1fe0 |
| 563f2ffb21 |
| 6f479d5d75 |
| d012f34e0d |
| e76364d5c1 |
| 2b8068c306 |
| 33c60a5254 |
| fa22b56903 |
| f30c22500b |
| 496c283615 |
| df45a92cea |
| 3ff0e69f84 |
| 31839d741a |
| 2072f3059e |
| 3760afb21c |
| 3c0b2b101e |
| e869e9df54 |
| 37d48bbb48 |
| 21913b2e10 |
| f028e9340c |
| 4dd4a8fafe |
| 03538a80be |
| 700c48a29f |
| 18a19dea61 |
| dba6aeb1e3 |
| 1c9077f66d |
| 756742354b |
| 926c37aaf4 |
| f5630f9b1a |
| e8a6eb3304 |
| 370fc65ee5 |
| f065e402fc |
| 91d250efb1 |
| 7cb4280112 |
| 144c8ce280 |
| 069684ef87 |
7
.github/workflows/benchmark.yml
vendored
@ -32,16 +32,15 @@ jobs:
options: --gpus all --privileged --ipc host
steps:
- name: Get repo
uses: actions/checkout@v4
uses: actions/checkout@v5
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
fetch-depth: 1

- name: Install benchmark script dependencies
run: python3 -m pip install -r benchmark_v2/requirements.txt kernels

- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
working-directory: /transformers
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]" && python3 -m pip uninstall -y torchvision # temp fix
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]"

- name: Run benchmark
run: |

@ -20,4 +20,4 @@ jobs:
contents: read
with:
workflow_name: ${{ inputs.workflow_name }}
run_count: ${{ fromJSON(inputs.run_count) }}
run_count: ${{ fromJSON(inputs.run_count) }}
3
.github/workflows/get-pr-info.yml
vendored
@ -87,9 +87,6 @@ jobs:
|
||||
PR_FILES: ${{ steps.pr_info.outputs.files }}
|
||||
if: ${{ inputs.pr_number != '' }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Extract PR details
|
||||
id: pr_info
|
||||
uses: actions/github-script@v6
|
||||
|
||||
3
.github/workflows/get-pr-number.yml
vendored
@ -13,9 +13,6 @@ jobs:
|
||||
outputs:
|
||||
PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Get PR number
|
||||
shell: bash
|
||||
env:
|
||||
|
||||
@ -13,9 +13,6 @@ jobs:
|
||||
name: Notify new model
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
12
.github/workflows/pr_build_doc_with_comment.yml
vendored
@ -35,9 +35,6 @@ jobs:
|
||||
PR_MERGE_COMMIT_DATE: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_DATE }}
|
||||
PR_MERGE_COMMIT_TIMESTAMP: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_TIMESTAMP }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- run: |
|
||||
COMMENT_TIMESTAMP=$(date -d "${COMMENT_DATE}" +"%s")
|
||||
echo "COMMENT_DATE: $COMMENT_DATE"
|
||||
@ -57,9 +54,6 @@ jobs:
|
||||
statuses: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Create Run
|
||||
id: create_run
|
||||
env:
|
||||
@ -83,9 +77,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Reply to the comment
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -121,9 +112,6 @@ jobs:
|
||||
GITHUB_RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
|
||||
STATUS_OK: ${{ contains(fromJSON('["skipped", "success"]'), needs.create_run.result) }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Get `build-doc` job status
|
||||
run: |
|
||||
echo "${{ needs.build-doc.result }}"
|
||||
|
||||
8
.github/workflows/pr_slow_ci_suggestion.yml
vendored
@ -23,10 +23,6 @@ jobs:
|
||||
outputs:
|
||||
jobs: ${{ steps.get_jobs.outputs.jobs_to_run }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
# This checkout to the main branch
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@ -93,10 +89,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Check and update comment if needed
|
||||
uses: actions/github-script@v7
|
||||
env:
|
||||
|
||||
4
.github/workflows/push-important-models.yml
vendored
@ -11,10 +11,6 @@ jobs:
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
4
.github/workflows/release-conda.yml
vendored
@ -18,10 +18,6 @@ jobs:
|
||||
shell: bash -l {0}
|
||||
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
24
.github/workflows/self-comment-ci.yml
vendored
@ -46,10 +46,6 @@ jobs:
|
||||
PR_HEAD_SHA: ${{ needs.get-pr-info.outputs.PR_HEAD_SHA }}
|
||||
PR_MERGE_SHA: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_SHA }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Verify `merge_commit` timestamp is older than the issue comment timestamp
|
||||
env:
|
||||
COMMENT_DATE: ${{ github.event.comment.created_at }}
|
||||
@ -71,10 +67,6 @@ jobs:
|
||||
models: ${{ steps.models_to_run.outputs.models }}
|
||||
quantizations: ${{ steps.models_to_run.outputs.quantizations }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: "0"
|
||||
@ -117,10 +109,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Reply to the comment
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -143,10 +131,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Reply to the comment
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -168,10 +152,6 @@ jobs:
|
||||
statuses: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Create Run
|
||||
id: create_run
|
||||
env:
|
||||
@ -230,10 +210,6 @@ jobs:
|
||||
if: ${{ always() && needs.create_run.result == 'success' }}
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Show reports from jobs
|
||||
env:
|
||||
MODEL_REPORT: ${{ needs.model-ci.outputs.report }}
|
||||
|
||||
4
.github/workflows/self-nightly-caller.yml
vendored
@ -30,10 +30,6 @@ jobs:
|
||||
name: Setup
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Setup
|
||||
run: |
|
||||
mkdir "setup_values"
|
||||
|
||||
@ -14,9 +14,6 @@ jobs:
|
||||
outputs:
|
||||
run_number: ${{ steps.get_number.outputs.run_number }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Get number
|
||||
id: get_number
|
||||
run: |
|
||||
|
||||
@ -10,9 +10,5 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Trigger scheduled AMD CI via workflow_run
|
||||
run: echo "Trigger scheduled AMD CI via workflow_run"
|
||||
|
||||
3
.github/workflows/self-scheduled-caller.yml
vendored
@ -32,9 +32,6 @@ jobs:
|
||||
name: Setup
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Setup
|
||||
env:
|
||||
prev_workflow_run_id: ${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}
|
||||
|
||||
@ -32,10 +32,6 @@ jobs:
|
||||
name: Setup
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Setup
|
||||
run: |
|
||||
mkdir "setup_values"
|
||||
|
||||
16
.github/workflows/self-scheduled-intel-gaudi.yml
vendored
@ -38,10 +38,6 @@ jobs:
|
||||
folder_slices: ${{ steps.set-matrix.outputs.folder_slices }}
|
||||
quantization_matrix: ${{ steps.set-matrix.outputs.quantization_matrix }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@ -126,10 +122,6 @@ jobs:
|
||||
--cap-add=sys_nice
|
||||
--shm-size=64G
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@ -199,10 +191,6 @@ jobs:
|
||||
--cap-add=sys_nice
|
||||
--shm-size=64G
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@ -275,10 +263,6 @@ jobs:
|
||||
--cap-add=sys_nice
|
||||
--shm-size=64G
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
21
.github/workflows/self-scheduled.yml
vendored
@ -78,9 +78,6 @@ jobs:
|
||||
slice_ids: ${{ steps.set-matrix.outputs.slice_ids }}
|
||||
quantization_matrix: ${{ steps.set-matrix-quantization.outputs.quantization_matrix }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
@ -187,9 +184,6 @@ jobs:
|
||||
image: huggingface/transformers-all-latest-gpu
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
@ -262,9 +256,6 @@ jobs:
|
||||
image: huggingface/transformers-all-latest-gpu
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
@ -338,9 +329,6 @@ jobs:
|
||||
image: ${{ inputs.docker }}
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: ${{ inputs.working-directory-prefix }}/transformers
|
||||
env:
|
||||
@ -446,9 +434,6 @@ jobs:
|
||||
image: huggingface/transformers-quantization-latest-gpu
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Echo folder ${{ matrix.folders }}
|
||||
shell: bash
|
||||
env:
|
||||
@ -533,9 +518,6 @@ jobs:
|
||||
image: ${{ inputs.docker }}
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
@ -606,9 +588,6 @@ jobs:
|
||||
steps:
|
||||
# Checkout in order to run `utils/extract_warnings.py`. Avoid **explicit** checkout (i.e. don't specify `ref`) for
|
||||
# security reason.
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Checkout transformers
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
4
.github/workflows/slack-report.yml
vendored
@ -38,10 +38,6 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
if: always() && !cancelled()
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Preliminary job status
|
||||
shell: bash
|
||||
# For the meaning of these environment variables, see the job `Setup`
|
||||
|
||||
8
.github/workflows/ssh-runner.yml
vendored
@ -30,10 +30,6 @@ jobs:
|
||||
outputs:
|
||||
RUNNER: ${{ steps.set_runner.outputs.RUNNER }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Get runner to use
|
||||
shell: bash
|
||||
env:
|
||||
@ -62,10 +58,6 @@ jobs:
|
||||
container:
|
||||
image: ${{ github.event.inputs.docker_image }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
|
||||
27
.github/workflows/stale.yml
vendored
@ -14,21 +14,16 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
pip install PyGithub
|
||||
- name: Close stale issues
|
||||
run: |
|
||||
python scripts/stale.py
|
||||
- name: Install requirements
|
||||
run: |
|
||||
pip install PyGithub
|
||||
- name: Close stale issues
|
||||
run: |
|
||||
python scripts/stale.py
|
||||
|
||||
4
.github/workflows/trufflehog.yml
vendored
@ -10,10 +10,6 @@ jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
4
.github/workflows/update_metdata.yml
vendored
@ -14,10 +14,6 @@ jobs:
|
||||
shell: bash -l {0}
|
||||
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup environment
|
||||
|
||||
@ -1,6 +1,5 @@
gpustat==1.1.1
psutil==6.0.0
psycopg2==2.9.9
torch>=2.4.0
hf_xet
pandas>=1.5.0
pandas>=1.5.0
@ -36,6 +36,7 @@ class BenchmarkConfig:
|
||||
warmup_iterations: int = 5,
|
||||
measurement_iterations: int = 20,
|
||||
gpu_monitoring: bool = True, # NOTE: you may want to disable this at times as we have observed it could heavily slow down benchmarks on AMD
|
||||
continuous_batching: bool = False,
|
||||
batch_size: int = 1,
|
||||
sequence_length: int = 128,
|
||||
num_tokens_to_generate: int = 128,
|
||||
@ -51,6 +52,7 @@ class BenchmarkConfig:
|
||||
self.warmup_iterations = warmup_iterations
|
||||
self.measurement_iterations = measurement_iterations
|
||||
self.gpu_monitoring = gpu_monitoring
|
||||
self.continuous_batching = continuous_batching
|
||||
# Input parameters
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
@ -85,6 +87,22 @@ class BenchmarkConfig:
|
||||
if is_fa:
|
||||
logger.warning("Flash attention does not support compile mode. Turning off compile mode.")
|
||||
self.compile_mode = None
|
||||
# Handle SDPA backend if not determined by the config (needs to be done before skipping duplicates)
|
||||
if self.attn_implementation == "sdpa" and self.sdpa_backend is None:
|
||||
default_backend = "flash_attention" # FIXME: torch has a _cur_sdpa_kernel_backends but it fails
|
||||
logger.warning(f"No SDPA backend provided, using {default_backend} instead.")
|
||||
self.sdpa_backend = default_backend
|
||||
if self.continuous_batching:
|
||||
if self.attn_implementation == "flex_attention":
|
||||
logger.error(
|
||||
"disabling continuous batching because of invalid configuration: flex attention is not supported"
|
||||
)
|
||||
self.continuous_batching = False
|
||||
elif self.attn_implementation == "sdpa" and self.sdpa_backend is not None:
|
||||
logger.warning(
|
||||
"when continuous batching is enabled, sdpa_backend must be None because of the attention mask, setting it to None"
|
||||
)
|
||||
self.sdpa_backend = "math"
|
||||
|
||||
@property
|
||||
def hash(self) -> str:
|
||||
@ -100,6 +118,7 @@ class BenchmarkConfig:
|
||||
attn_code += f"_{self.sdpa_backend}" if self.attn_implementation == "sdpa" else ""
|
||||
compile_str = f"compiled_{self.compile_mode}" if self.compile_mode is not None else "uncompiled"
|
||||
kernelize_str = "kernelized" if self.kernelize else "unkernelized"
|
||||
continuous_batching_str = "cb" if self.continuous_batching else "generate"
|
||||
sep = "-"
|
||||
else:
|
||||
iter_str = f"{self.warmup_iterations} warmup, {self.measurement_iterations} iterations"
|
||||
@ -109,8 +128,11 @@ class BenchmarkConfig:
|
||||
attn_code += f" with {self.sdpa_backend} backend" if self.attn_implementation == "sdpa" else ""
|
||||
compile_str = "compiled" if self.compile_mode is not None else "not compiled"
|
||||
kernelize_str = "kernelized" if self.kernelize else "not kernelized"
|
||||
continuous_batching_str = "continuous batching" if self.continuous_batching else "regular generate"
|
||||
sep = ", "
|
||||
return sep.join([iter_str, gpu_monitor_str, dimensions_str, attn_code, compile_str, kernelize_str])
|
||||
return sep.join(
|
||||
[iter_str, gpu_monitor_str, dimensions_str, attn_code, compile_str, kernelize_str, continuous_batching_str]
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
@ -118,6 +140,7 @@ class BenchmarkConfig:
|
||||
"warmup_iterations": self.warmup_iterations,
|
||||
"measurement_iterations": self.measurement_iterations,
|
||||
"gpu_monitoring": self.gpu_monitoring,
|
||||
"continuous_batching": self.continuous_batching,
|
||||
"batch_size": self.batch_size,
|
||||
"sequence_length": self.sequence_length,
|
||||
"num_tokens_to_generate": self.num_tokens_to_generate,
|
||||
@ -134,6 +157,7 @@ class BenchmarkConfig:
|
||||
warmup_iterations=data.get("warmup_iterations", 5),
|
||||
measurement_iterations=data.get("measurement_iterations", 20),
|
||||
gpu_monitoring=data.get("gpu_monitoring", False),
|
||||
continuous_batching=data.get("continuous_batching", False),
|
||||
batch_size=data.get("batch_size", 1),
|
||||
sequence_length=data.get("sequence_length", 128),
|
||||
num_tokens_to_generate=data.get("num_tokens_to_generate", 128),
|
||||
@ -191,15 +215,17 @@ def get_config_by_level(level: int) -> list[BenchmarkConfig]:
|
||||
# Usually there is not much to gain by compiling with other modes, but we allow it for level 4
|
||||
compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
|
||||
for cm in compile_modes:
|
||||
for kernelize_on in [False, KERNELIZATION_AVAILABLE]:
|
||||
configs.append(
|
||||
BenchmarkConfig(
|
||||
attn_implementation=attn_implementation,
|
||||
sdpa_backend=sdpa_backend,
|
||||
compile_mode=cm,
|
||||
kernelize=kernelize_on,
|
||||
for kernelize_on in {False, KERNELIZATION_AVAILABLE}:
|
||||
for cb_on in [False, True]:
|
||||
configs.append(
|
||||
BenchmarkConfig(
|
||||
attn_implementation=attn_implementation,
|
||||
sdpa_backend=sdpa_backend,
|
||||
compile_mode=cm,
|
||||
kernelize=kernelize_on,
|
||||
continuous_batching=cb_on,
|
||||
)
|
||||
)
|
||||
)
|
||||
return configs
|
||||
# Otherwise, we add the configs for the given level
|
||||
if level >= 0:
|
||||
@ -207,8 +233,10 @@ def get_config_by_level(level: int) -> list[BenchmarkConfig]:
|
||||
if level >= 1:
|
||||
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
|
||||
configs.append(BenchmarkConfig(attn_implementation="eager", compile_mode="default"))
|
||||
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", continuous_batching=True))
|
||||
if level >= 2:
|
||||
configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
|
||||
configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
|
||||
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
|
||||
configs.append(BenchmarkConfig(attn_implementation="paged|sdpa", continuous_batching=True))
|
||||
return configs
|
||||
|
||||
@ -117,8 +117,6 @@ def flush_memory():
|
||||
# Clear CUDA cache
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
|
||||
@ -234,8 +232,9 @@ class BenchmarkRunner:
|
||||
self.logger.info(f"Running benchmark scenario: {config.name}")
|
||||
|
||||
# Quick validation: try one measurement first to see if this scenario works
|
||||
generate_fn = self.time_generate_batch if config.continuous_batching else self.time_generate
|
||||
flush_memory()
|
||||
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
|
||||
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = generate_fn(
|
||||
max_new_tokens=1, gpu_monitor=None
|
||||
)
|
||||
if e2e_latency < 0:
|
||||
@ -245,14 +244,14 @@ class BenchmarkRunner:
|
||||
# Warmup runs
|
||||
self.logger.info(f"Warming up with {config.warmup_iterations} iterations...")
|
||||
for _ in trange(config.warmup_iterations):
|
||||
_ = self.time_generate(max_new_tokens=config.num_tokens_to_generate)
|
||||
_ = generate_fn(max_new_tokens=config.num_tokens_to_generate)
|
||||
self.logger.info("Warmup over.")
|
||||
|
||||
# Measurement runs
|
||||
result = BenchmarkResult()
|
||||
self.logger.info(f"Benchmarking with {config.measurement_iterations} iterations.")
|
||||
for _ in trange(config.measurement_iterations):
|
||||
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = self.time_generate(
|
||||
e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = generate_fn(
|
||||
max_new_tokens=config.num_tokens_to_generate,
|
||||
gpu_monitor=(GPUMonitor(logger=self.logger) if config.gpu_monitoring else None),
|
||||
)
|
||||
@ -274,6 +273,58 @@ class BenchmarkRunner:
|
||||
"config": config,
|
||||
}
|
||||
|
||||
# TODO: refactor `generate_batch` to handle streaming so we can use it here
|
||||
def time_generate_batch(
|
||||
self,
|
||||
max_new_tokens: int,
|
||||
gpu_monitor: GPUMonitor | None = None,
|
||||
) -> tuple[float, list[float], str, GPURawMetrics | None]:
|
||||
if gpu_monitor is not None:
|
||||
gpu_monitor.start()
|
||||
config = GenerationConfig(
|
||||
max_new_tokens=max_new_tokens,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
)
|
||||
manager = self.model.init_continuous_batching(config)
|
||||
manager.start()
|
||||
try:
|
||||
first_req_results = []
|
||||
timestamps = []
|
||||
wall_time_0 = time.perf_counter()
|
||||
inputs = self.inputs["input_ids"].tolist()
|
||||
manager.add_requests(inputs, max_new_tokens=max_new_tokens, streaming=True)
|
||||
first_req_id = None
|
||||
num_requests = len(inputs)
|
||||
finished_requests = 0
|
||||
while finished_requests < num_requests:
|
||||
# NOTE: I don't like having the extra if stmt here, but hopefully won't degrade perf too much
|
||||
result = manager.get_result()
|
||||
if result:
|
||||
timestamps.append(time.perf_counter() - wall_time_0)
|
||||
if result.is_finished():
|
||||
finished_requests += 1
|
||||
if first_req_id is None:
|
||||
first_req_id = result.request_id
|
||||
if result.request_id == first_req_id:
|
||||
first_req_results.append(result)
|
||||
else:
|
||||
if not manager.is_running():
|
||||
raise RuntimeError("Generation thread exited unexpectedly")
|
||||
wall_time_1 = time.perf_counter()
|
||||
gpu_metrics = gpu_monitor.stop_and_collect() if gpu_monitor is not None else None
|
||||
decoded_output = self.tokenizer.decode(
|
||||
[res.generated_tokens[0] for res in first_req_results], skip_special_tokens=True
|
||||
)
|
||||
shape_and_decoded_output = f"{(1, len(first_req_results))} | {decoded_output}"
|
||||
e2e_latency = wall_time_1 - wall_time_0
|
||||
return e2e_latency, timestamps, shape_and_decoded_output, gpu_metrics
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
manager.stop()
|
||||
|
||||
def time_generate(
|
||||
self,
|
||||
max_new_tokens: int,
|
||||
@ -339,12 +390,6 @@ class BenchmarkRunner:
|
||||
|
||||
n_configs = len(benchmark_configs)
|
||||
for i, config in enumerate(benchmark_configs):
|
||||
# Handle SDPA backend if not determined by the config (needs to be done before skipping duplicates)
|
||||
if config.attn_implementation == "sdpa" and config.sdpa_backend is None:
|
||||
default_backend = "flash_attention" # FIXME: torch has a _cur_sdpa_kernel_backends but it fails
|
||||
self.logger.warning(f"No SDPA backend provided, using {default_backend} instead.")
|
||||
config.sdpa_backend = default_backend
|
||||
|
||||
# Skip if already run
|
||||
if config.hash in all_results:
|
||||
self.logger.info(f"Skipping duplicate config {config.name} for model {model_id} ({i + 1}/{n_configs})")
|
||||
@ -368,21 +413,27 @@ class BenchmarkRunner:
|
||||
self.cleanup()
|
||||
self.save_results(model_id, all_results, timestamp=timestamp)
|
||||
|
||||
if len(all_results) < 1:
|
||||
raise RuntimeError("No benchmark was run successfully")
|
||||
|
||||
if pretty_print_summary:
|
||||
print()
|
||||
print("=" * 100)
|
||||
print(f"Finished benchmarks in {time.perf_counter() - start_time:.2f} seconds")
|
||||
print(f"Total number of benchmarks: {len(all_results)}")
|
||||
if len(all_results) > 0:
|
||||
print("First run metadata:")
|
||||
first_key = list(all_results.keys())[0]
|
||||
first_metadata = all_results[first_key]["metadata"].to_dict()
|
||||
hardware_info = first_metadata.pop("hardware_info")
|
||||
pretty_print_dict(first_metadata | hardware_info, tabs=1)
|
||||
print("First run metadata:")
|
||||
first_key = list(all_results.keys())[0]
|
||||
first_metadata = all_results[first_key]["metadata"].to_dict()
|
||||
hardware_info = first_metadata.pop("hardware_info")
|
||||
pretty_print_dict(first_metadata | hardware_info, tabs=1)
|
||||
for result in all_results.values():
|
||||
print("=" * 100)
|
||||
print(f"Config: {result['config'].infer_name(compact=False)}\n")
|
||||
result["measurements"].pprint(batch_size=result["config"].batch_size, tabs=1)
|
||||
result["measurements"].pprint(
|
||||
batch_size=result["config"].batch_size,
|
||||
num_generated_tokens=result["config"].num_tokens_to_generate,
|
||||
tabs=1,
|
||||
)
|
||||
print("=" * 100)
|
||||
|
||||
return (timestamp, all_results)
|
||||
|
||||
@ -36,16 +36,17 @@ def add_unit_to_duration(stats: dict[str, float]) -> dict[str, str]:
|
||||
return stats
|
||||
|
||||
|
||||
def equalize_lengths_and_collate(stats: list[dict[str, str]]) -> list[str]:
|
||||
def equalize_lengths_and_collate(stats: dict[str, dict[str, str]]) -> dict[str, str]:
|
||||
"""Note: This operation is destructive as it will update values in place before returning a new correctly formatted dict"""
|
||||
keys = ["avg", "std", "min", "med", "max", "p95"]
|
||||
for key in keys:
|
||||
max_length = max(len(stat[key]) for stat in stats)
|
||||
for stat in stats:
|
||||
max_length = max(len(stat[key]) for stat in stats.values())
|
||||
for stat in stats.values():
|
||||
stat[key] = stat[key].ljust(max_length, " ")
|
||||
return [" ".join([f"{key}={stat[key]}" for key in keys]) for stat in stats]
|
||||
return {name: " ".join([f"{key}={stat[key]}" for key in keys]) for name, stat in stats.items()}
|
||||
|
||||
|
||||
def pretty_print_dict(data: dict[str, Any], tabs: int = 0) -> None:
|
||||
def pretty_print_dict(data: dict[str, str], tabs: int = 0) -> None:
|
||||
max_key_length = max([len(key) for key in data.keys()])
|
||||
for key, value in data.items():
|
||||
tabs_str = " " * tabs
|
||||
@ -141,27 +142,19 @@ class BenchmarkResult:
|
||||
def get_measured_itl(self) -> list[float]:
|
||||
return [(dt[-1] - dt[0]) / (len(dt) - 1) for dt in self.token_generation_times if len(dt) > 1]
|
||||
|
||||
def get_throughput(self, batch_size: int) -> float:
|
||||
return [
|
||||
batch_size * len(dt) / e2e_latency
|
||||
for e2e_latency, dt in zip(self.e2e_latency, self.token_generation_times)
|
||||
]
|
||||
def get_throughput(self, total_generated_tokens: int) -> list[float]:
|
||||
return [total_generated_tokens / e2e_latency for e2e_latency in self.e2e_latency]
|
||||
|
||||
def pprint(self, batch_size: int = 0, tabs: int = 0) -> None:
|
||||
stats_to_collate = [
|
||||
add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
|
||||
add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
|
||||
add_unit_to_duration(compute_basic_statistics(self.get_measured_itl())),
|
||||
]
|
||||
if batch_size > 0:
|
||||
throughput_stats = compute_basic_statistics(self.get_throughput(batch_size))
|
||||
stats_to_collate.append({key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()})
|
||||
collated_stats = equalize_lengths_and_collate(stats_to_collate)
|
||||
dict_to_pprint = {
|
||||
"E2E Latency": collated_stats[0],
|
||||
"Time to First Token": collated_stats[1],
|
||||
"Inter-Token Latency": collated_stats[2],
|
||||
def pprint(self, batch_size: int = 0, num_generated_tokens: int = 0, tabs: int = 0) -> None:
|
||||
measurements = {
|
||||
"E2E Latency": add_unit_to_duration(compute_basic_statistics(self.e2e_latency)),
|
||||
"Time to First Token": add_unit_to_duration(compute_basic_statistics(self.get_measured_ttft())),
|
||||
}
|
||||
itl_values = self.get_measured_itl()
|
||||
if len(itl_values) > 0:
|
||||
measurements["Inter-Token Latency"] = add_unit_to_duration(compute_basic_statistics(itl_values))
|
||||
if batch_size > 0:
|
||||
dict_to_pprint["Throughput"] = collated_stats[3]
|
||||
throughput_stats = compute_basic_statistics(self.get_throughput(batch_size * num_generated_tokens))
|
||||
measurements["Throughput"] = {key: f"{value:.2f}tok/s" for key, value in throughput_stats.items()}
|
||||
dict_to_pprint = equalize_lengths_and_collate(measurements)
|
||||
pretty_print_dict(dict_to_pprint, tabs=tabs)
|
||||
|
||||
@ -2,6 +2,5 @@ numpy>=1.21.0
psutil>=5.8.0
gpustat>=1.0.0
torch>=2.0.0
transformers>=4.30.0
datasets>=2.10.0
huggingface_hub>=0.16.0
@ -80,6 +80,10 @@ if __name__ == "__main__":
|
||||
logger.info(f"Benchmark run UUID: {benchmark_run_uuid}")
|
||||
logger.info(f"Output directory: {args.output_dir}")
|
||||
|
||||
# We cannot compute ITL if we don't have at least two measurements
|
||||
if any(n <= 1 for n in args.num_tokens_to_generate):
|
||||
raise ValueError("--num_tokens_to_generate arguments should be larger than 1")
|
||||
|
||||
# Error out if one of the arguments is not provided
|
||||
if len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 0:
|
||||
raise ValueError(
|
||||
|
||||
@ -1008,6 +1008,8 @@
title: AltCLIP
- local: model_doc/aria
title: Aria
- local: model_doc/audioflamingo3
title: AudioFlamingo3
- local: model_doc/aya_vision
title: AyaVision
- local: model_doc/blip
402
docs/source/en/model_doc/audioflamingo3.md
Normal file
@ -0,0 +1,402 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
*This model was released on 2025-07-10 and added to Hugging Face Transformers on 2025-11-11.*
|
||||
|
||||
# Audio Flamingo 3
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
|
||||
Audio Flamingo 3 (AF3) is a fully open large audio–language model designed for robust understanding and reasoning over speech, environmental sounds, and music. AF3 pairs a Whisper-style audio encoder with a causal language model and performs replace-in-place audio–text fusion: the processor aligns post-pool audio frames to a dedicated placeholder token and the model replaces those token slots with projected audio embeddings during the forward pass.
|
||||
|
||||
The model checkpoint is available at: [nvidia/audio-flamingo-3-hf](https://huggingface.co/nvidia/audio-flamingo-3-hf)
|
||||
|
||||
Highlights:
|
||||
|
||||
- Unified audio encoder across speech, sound, and music.
|
||||
- **Long-audio support via windowing and post-pool alignment (up to 10 minutes maximum).** The model processes audio in 30-second windows with a hard limit of 20 windows (10 minutes total). Audio longer than 10 minutes will be truncated.
|
||||
- Deterministic fusion that preserves sequence length by replacing audio placeholder tokens with audio embeddings.
|
||||
|
||||
This model was contributed by [Lasha Koroshinadze](https://huggingface.co/lashahub) and [Eric Bezzam](https://huggingface.co/bezzam).
|
||||
|
||||
### Paper
|
||||
|
||||
[Audio Flamingo 3](https://huggingface.co/papers/2507.08128): Advancing Audio Intelligence with Fully Open Large Audio Language Models
|
||||
A. Goel, S. Ghosh, J. Kim, S. Kumar, Z. Kong, S. Lee, C.-H. H. Yang, R. Duraiswami, D. Manocha, R. Valle, B. Catanzaro
|
||||
NVIDIA and University of Maryland
|
||||
Project: https://research.nvidia.com/labs/adlr/AF3/
|
||||
|
||||
## Usage
|
||||
|
||||
### Audio Instruct Mode
|
||||
|
||||
The model supports audio-text instructions, including multi-turn interactions, all processed in batches.
|
||||
|
||||
➡️ audio + text instruction
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Transcribe the input speech."},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ multi-turn:
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Instruction: How does the tone of female speech change throughout the audio? Choose the correct option among the options below: (A) Sad to happy (B) Happy to sad (C) Neutral to happy (D) Happy to neutral.",
|
||||
},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/000000786159.31.wav"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "(A) Sad to happy"}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Why do you think so?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ text only:
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is the capital of France?"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ audio only:
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ batched inference!
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
conversations = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Transcribe the input speech."},
|
||||
{
|
||||
"type": "audio",
|
||||
"path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
|
||||
},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
|
||||
],
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversations,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
).to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
➡️ Training:
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
model.train()
|
||||
|
||||
conversation = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Transcribe the input speech."},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "The transcription of the audio is 'summer follows spring the days grow longer and the nights are warm'."}],
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
|
||||
},
|
||||
{"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "The transcription of the audio is 'some transcription of the audio'."}],
|
||||
}
|
||||
|
||||
]
|
||||
]
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
output_labels=True,
|
||||
).to(model.device)
|
||||
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
```
|
||||
|
||||
➡️ transcription shortcut
|
||||
|
||||
```python
|
||||
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_id = "nvidia/audio-flamingo-3-hf"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
inputs = processor.apply_transcription_request(audio="https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav").to(model.device)
|
||||
|
||||
outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True, strip_prefix=True)
|
||||
|
||||
print(decoded_outputs)
|
||||
```
|
||||
|
||||
The model is trained to emit transcriptions prefixed with assistant framing such as `The spoken content of the audio is "<text>".`. Use `strip_prefix=True` (as shown above) to remove the fixed assistant sentence and surrounding quotes so that only the transcription remains.
|
||||
|
||||
## How the model works
|
||||
|
||||
### Architecture
|
||||
|
||||
* **AudioFlamingo3Encoder**
|
||||
Whisper-style feature extractor + encoder → average-pool over time (stride 2) → LayerNorm.
|
||||
Produces per-frame hidden states at the post-pool rate.
|
||||
|
||||
* **AudioFlamingo3MultiModalProjector**
|
||||
A small MLP that maps encoder features to the language model’s hidden size.
|
||||
|
||||
* **AudioFlamingo3ForConditionalGeneration**
|
||||
A causal language model that accepts text embeddings where each audio placeholder token slot is replaced, in place, by an audio frame embedding. No sequence-length change is introduced by fusion.
|
||||
|
||||
### Processor-level alignment
|
||||
|
||||
1. Each raw waveform is split into fixed-length windows based on the feature extractor’s `chunk_length` (seconds) and `sampling_rate` (Hz).
|
||||
2. For each window, the processor computes the number of post-pool frames `post_pool_len` that the encoder will output (matching the conv/pool schedule).
|
||||
3. The processor expands the audio placeholder token by the total number of post-pool frames across all windows.
|
||||
4. The model later replaces those token positions with the corresponding projected audio embeddings.
|
||||
|
||||
## Usage patterns

### Transcription shortcut

For automatic speech recognition you can skip writing the default instruction each time and call
[`~transformers.AudioFlamingo3Processor.apply_transcription_request`]:

```python
inputs = processor.apply_transcription_request(audio=audio_array)
```

Pass `prompt="Transcribe the input speech."` (or a list of prompts for batch audio) to customize the instruction while
keeping the audio placeholder handling.

`audio` accepts in-memory arrays, local file paths, or URLs. Any processor kwargs (`text_kwargs`, `audio_kwargs`, etc.)
are forwarded, so you can tweak padding or tensor formats just like when calling `processor(...)`.
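A short sketch of a custom instruction for a batch of clips, assuming the same `processor` and `model` objects as in the examples above; the file names are placeholders for illustration:

```python
# Sketch: batch transcription with a custom instruction, following the
# documented `audio` / `prompt` arguments. File names are hypothetical.
inputs = processor.apply_transcription_request(
    audio=["clip_a.wav", "clip_b.wav"],      # in-memory arrays, local paths, or URLs
    prompt="Transcribe the input speech.",   # one instruction, or one per clip
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=500)
# Decode only the new tokens; strip_prefix removes the fixed assistant framing.
print(processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True, strip_prefix=True))
```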
## Long audio and windowing

**Important: Maximum audio length is 10 minutes.** Audio longer than this will be truncated.

* The default setup processes 30-second windows at 16 kHz mono.
* **The processor enforces a hard limit of 20 windows per sample, resulting in a maximum of 10 minutes of audio (20 windows × 30 seconds).**
* For each window:

  * `mel_len` is the padded mel length.
  * A conv stack reduces time as `conv_output_len = (mel_len - 1) // 2 + 1`.
  * Post-pool frames per window: `post_pool_len = (conv_output_len - 2) // 2 + 1`.
* An audio placeholder token is expanded to the sum of `post_pool_len` across all windows.
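The frame arithmetic above can be checked with a minimal sketch. The hop length and per-window mel length below are assumptions based on a typical Whisper-style front end (16 kHz, 160-sample hop, each window padded to the full 30 seconds); the two formulas are the ones listed above.

```python
import math

# Assumed front-end constants (not taken from this PR).
SAMPLING_RATE = 16_000
CHUNK_LENGTH_S = 30
HOP_LENGTH = 160
MAX_WINDOWS = 20  # hard limit: 20 windows x 30 s = 10 minutes

def placeholder_tokens(duration_s: float) -> int:
    num_windows = min(math.ceil(duration_s / CHUNK_LENGTH_S), MAX_WINDOWS)
    mel_len = CHUNK_LENGTH_S * SAMPLING_RATE // HOP_LENGTH  # 3000 padded mel frames per window
    conv_output_len = (mel_len - 1) // 2 + 1                # stride-2 conv stack -> 1500
    post_pool_len = (conv_output_len - 2) // 2 + 1          # stride-2 average pool -> 750
    return num_windows * post_pool_len

print(placeholder_tokens(45))    # 2 windows  -> 1500 placeholder tokens
print(placeholder_tokens(3600))  # capped at 20 windows -> 15000 tokens
```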
## Padding, attention, and caching

* **Left padding vs right padding**
  For generation with mixed prompt lengths in a batch, left padding is usually preferable.
  For training, right padding is common; AF3’s fusion mechanism itself is padding-agnostic because it replaces in place.
* **Attention masks**
  The processor returns `attention_mask` (text) and `input_features_mask` (audio). The model builds an internal 4-D mask on the encoder’s pre-pool axis with negative infinity at pad positions.
* **Caching**
  During generation, `input_features` and `input_features_mask` are only passed on the first step. Subsequent steps use cached keys/values from the language model.
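As a minimal illustration of the left-padding advice, the sketch below reuses the batched example from the usage section. It assumes the processor exposes its underlying Hugging Face tokenizer as `processor.tokenizer` with the usual `padding_side` attribute; adjust if the API differs.

```python
# Sketch: batched generation with mixed prompt lengths, assuming `processor`,
# `model`, and `conversations` are set up as in the batched example above.
processor.tokenizer.padding_side = "left"  # avoid generating after pad tokens

inputs = processor.apply_chat_template(
    conversations,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=500)
# Decode only the newly generated tokens, skipping the (left-padded) prompt.
new_tokens = outputs[:, inputs.input_ids.shape[1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True))
```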
## Troubleshooting
|
||||
|
||||
* Empty or truncated outputs when batching
|
||||
Use left padding for batched generation and decode only the new tokens after the prompt length, as shown in the quickstart.
|
||||
|
||||
## AudioFlamingo3Config
|
||||
|
||||
[[autodoc]] AudioFlamingo3Config
|
||||
|
||||
## AudioFlamingo3EncoderConfig
|
||||
|
||||
[[autodoc]] AudioFlamingo3EncoderConfig
|
||||
|
||||
## AudioFlamingo3Processor
|
||||
|
||||
[[autodoc]] AudioFlamingo3Processor
|
||||
|
||||
## AudioFlamingo3Encoder
|
||||
|
||||
[[autodoc]] AudioFlamingo3Encoder
|
||||
- forward
|
||||
|
||||
## AudioFlamingo3ForConditionalGeneration
|
||||
|
||||
[[autodoc]] AudioFlamingo3ForConditionalGeneration
|
||||
- forward
|
||||
@ -169,6 +169,9 @@ print("Pooled output shape:", pooled_output.shape)
[[autodoc]] DINOv3ViTModel
- forward

## DINOv3ViTBackbone
[[autodoc]] DINOv3ViTBackbone

## DINOv3ConvNextModel

[[autodoc]] DINOv3ConvNextModel
@ -159,7 +159,7 @@ conversation3 = [

conversations = [conversation1, conversation2, conversation3]
inputs = processor.apply_chat_template(
conversation,
conversations,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
@ -329,7 +329,7 @@ from torchao.dtypes import Int4XPULayout
|
||||
from torchao.quantization.quant_primitives import ZeroPointDomain
|
||||
|
||||
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT)
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT, int4_packing_format="plain_int32")
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
@ -342,7 +342,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
@ -395,7 +395,7 @@ from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
from torchao.dtypes import Int4CPULayout
|
||||
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout(), int4_packing_format="opaque")
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
@ -422,7 +422,7 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
#### 1. Skip quantization for certain layers
|
||||
|
||||
With `ModuleFqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
|
||||
With `FqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@ -430,11 +430,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||
|
||||
model_id = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig
|
||||
from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig
|
||||
config = Int4WeightOnlyConfig(group_size=128)
|
||||
|
||||
# set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
|
||||
quant_config = ModuleFqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
|
||||
quant_config = FqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16, quantization_config=quantization_config)
|
||||
# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
|
||||
@@ -459,7 +459,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_id = "facebook/opt-125m"

from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType

weight_dtype = torch.int8
granularity = PerAxis(0)
@@ -470,7 +470,7 @@ embedding_config = IntxWeightOnlyConfig(
    mapping_type=mapping_type,
)
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
quant_config = ModuleFqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
quant_config = FqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
# set `include_embedding` to True in order to include embedding in quantization
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
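A natural follow-up to the embedding-plus-linear configuration above is persisting the quantized model. The sketch below is an illustrative addition rather than part of the diff: it assumes the `quantized_model` and `tokenizer` objects from the surrounding example, and the `safe_serialization=False` flag reflects the usual recommendation for torchao tensor subclasses, so verify it against your torchao and transformers versions.

```py
# Illustrative sketch (not part of the diff): persist the torchao-quantized model.
# torchao weights are tensor subclasses, so safetensors serialization is typically
# disabled when saving.
save_dir = "opt-125m-torchao-quantized"  # hypothetical output directory
quantized_model.save_pretrained(save_dir, safe_serialization=False)
tokenizer.save_pretrained(save_dir)  # assumes a tokenizer was loaded alongside the model
```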
@@ -521,7 +521,7 @@ from torchao.quantization import (
    IntxWeightOnlyConfig,
    PerRow,
    PerAxis,
    ModuleFqnToConfig,
    FqnToConfig,
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
@@ -550,7 +550,7 @@ qconfig_dict = {

    "_default": intxwo,
}
quant_config = ModuleFqnToConfig(qconfig_dict)
quant_config = FqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
@@ -127,7 +127,7 @@ def parse_args():
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",

@@ -132,7 +132,7 @@ def parse_args():
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        help="If passed, will use a slow tokenizer (not backed by the Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",

@@ -130,7 +130,7 @@ def parse_args():
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",

@@ -128,7 +128,7 @@ def parse_args():
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        help="If passed, will use a slow tokenizer (not backed by the HuggingFace Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",

@@ -151,7 +151,7 @@ def parse_args():
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",

@@ -223,7 +223,7 @@ def parse_args():
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
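For orientation, these `--use_slow_tokenizer` flags are typically consumed when the example scripts build their tokenizer. The helper below is an illustrative reconstruction of that pattern, not part of the diff; the argument names mirror the scripts above but are assumptions.

```py
from argparse import Namespace

from transformers import AutoTokenizer


def load_tokenizer(args: Namespace):
    # Slow tokenizers are the pure-Python implementations; fast ones are backed by the
    # Hugging Face Tokenizers (Rust) library, hence use_fast is the inverse of the flag.
    return AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
```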
@@ -74,6 +74,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdir)

    @slow
    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
    def test_run_glue_no_trainer(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
@@ -147,6 +148,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))

    @slow
    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
    def test_run_ner_no_trainer(self):
        # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
@@ -175,6 +177,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))

    @slow
    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
    def test_run_squad_no_trainer(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
@@ -203,6 +206,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))

    @slow
    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
    def test_run_swag_no_trainer(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
@@ -305,6 +309,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
        result = get_results(tmp_dir)
        self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)

    @slow
    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
    def test_run_image_classification_no_trainer(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
@@ -374,6 +374,7 @@ class ExamplesTests(TestCasePlus):
        result = get_results(tmp_dir)
        self.assertGreaterEqual(result["eval_bleu"], 30)

    @slow
    def test_run_image_classification(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
@@ -403,6 +404,7 @@ class ExamplesTests(TestCasePlus):
        result = get_results(tmp_dir)
        self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    @slow
    def test_run_speech_recognition_ctc(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
@@ -573,6 +575,7 @@ class ExamplesTests(TestCasePlus):
        model = ViTMAEForPreTraining.from_pretrained(tmp_dir)
        self.assertIsNotNone(model)

    @slow
    def test_run_semantic_segmentation(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
@@ -597,6 +600,7 @@ class ExamplesTests(TestCasePlus):
        result = get_results(tmp_dir)
        self.assertGreaterEqual(result["eval_overall_accuracy"], 0.1)

    @slow
    @patch.dict(os.environ, {"WANDB_DISABLED": "true"})
    def test_run_object_detection(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
@@ -624,6 +628,7 @@ class ExamplesTests(TestCasePlus):
        result = get_results(tmp_dir)
        self.assertGreaterEqual(result["test_map"], 0.1)

    @slow
    @patch.dict(os.environ, {"WANDB_DISABLED": "true"})
    def test_run_instance_segmentation(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
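The `@slow` markers added in the hunks above gate these example tests behind an opt-in environment flag. The sketch below approximates how such a marker is commonly implemented; the real decorator lives in `transformers.testing_utils` and may parse the flag differently.

```py
import os
import unittest


def slow(test_case):
    """Approximate sketch of a `slow` marker: skip unless RUN_SLOW is set to a truthy value."""
    run_slow = os.environ.get("RUN_SLOW", "0").lower() in {"1", "true", "yes"}
    return unittest.skipUnless(run_slow, "test is slow")(test_case)
```

With such a marker in place, the newly decorated tests only run when the flag is exported, e.g. `RUN_SLOW=1 python -m pytest ...`.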
@@ -120,7 +120,7 @@ def parse_args():
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",

@@ -212,7 +212,7 @@ def parse_args():
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
        help="If passed, will use a slow tokenizer (not backed by the Hugging Face Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
@ -50,6 +50,7 @@ checkpoint: 检查点
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.co/models"><img alt="Checkpoints on Hub" src="https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen"></a>
|
||||
<a href="https://circleci.com/gh/huggingface/transformers"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/main"></a>
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue"></a>
|
||||
<a href="https://huggingface.co/docs/transformers/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/transformers/index.svg?down_color=red&down_message=offline&up_message=online"></a>
|
||||
@ -60,7 +61,7 @@ checkpoint: 检查点
|
||||
|
||||
<h4 align="center">
|
||||
<p>
|
||||
<a href="https://github.com/huggingface/transformers/">English</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/README.md">English</a> |
|
||||
<b>简体中文</b> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_zh-hant.md">繁體中文</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ko.md">한국어</a> |
|
||||
@ -68,7 +69,7 @@ checkpoint: 检查点
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_hd.md">हिन्दी</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ru.md">Русский</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Рortuguês</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Português</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_te.md">తెలుగు</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_de.md">Deutsch</a> |
|
||||
@ -81,182 +82,258 @@ checkpoint: 检查点
|
||||
</h4>
|
||||
|
||||
<h3 align="center">
|
||||
<p>为 Jax、PyTorch 和 TensorFlow 打造的先进的自然语言处理函数库</p>
|
||||
<p>为文本、视觉、音频、视频与多模态提供推理与训练的先进预训练模型</p>
|
||||
</h3>
|
||||
|
||||
<h3 align="center">
|
||||
<a href="https://hf.co/course"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/course_banner.png"></a>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_as_a_model_definition.png"/>
|
||||
</h3>
|
||||
|
||||
🤗 Transformers 提供了数以千计的预训练模型,支持 100 多种语言的文本分类、信息抽取、问答、摘要、翻译、文本生成。它的宗旨是让最先进的 NLP 技术人人易用。
|
||||
Transformers 充当跨文本、计算机视觉、音频、视频与多模态的最先进机器学习模型的「模型定义框架」,同时覆盖推理与训练。
|
||||
|
||||
🤗 Transformers 提供了便于快速下载和使用的API,让你可以把预训练模型用在给定文本、在你的数据集上微调然后通过 [model hub](https://huggingface.co/models) 与社区共享。同时,每个定义的 Python 模块都是完全独立的,便于修改和快速进行研究实验。
|
||||
它将模型的定义集中化,使整个生态系统对该定义达成一致。`transformers` 是跨框架的枢纽:一旦某模型定义被支持,它通常就能兼容多数训练框架(如 Axolotl、Unsloth、DeepSpeed、FSDP、PyTorch‑Lightning 等)、推理引擎(如 vLLM、SGLang、TGI 等),以及依赖 `transformers` 模型定义的相关库(如 llama.cpp、mlx 等)。
|
||||
|
||||
🤗 Transformers 支持三个最热门的深度学习库: [Jax](https://jax.readthedocs.io/en/latest/), [PyTorch](https://pytorch.org/) 以及 [TensorFlow](https://www.tensorflow.org/) — 并与之无缝整合。你可以直接使用一个框架训练你的模型然后用另一个加载和推理。
|
||||
我们的目标是持续支持新的最先进模型,并通过让模型定义保持简单、可定制且高效来普及其使用。
|
||||
|
||||
## 在线演示
|
||||
|
||||
你可以直接在模型页面上测试大多数 [model hub](https://huggingface.co/models) 上的模型。 我们也提供了 [私有模型托管、模型版本管理以及推理API](https://huggingface.co/pricing)。
|
||||
|
||||
这里是一些例子:
|
||||
- [用 BERT 做掩码填词](https://huggingface.co/google-bert/bert-base-uncased?text=Paris+is+the+%5BMASK%5D+of+France)
|
||||
- [用 Electra 做命名实体识别](https://huggingface.co/dbmdz/electra-large-discriminator-finetuned-conll03-english?text=My+name+is+Sarah+and+I+live+in+London+city)
|
||||
- [用 GPT-2 做文本生成](https://huggingface.co/openai-community/gpt2?text=A+long+time+ago%2C+)
|
||||
- [用 RoBERTa 做自然语言推理](https://huggingface.co/FacebookAI/roberta-large-mnli?text=The+dog+was+lost.+Nobody+lost+any+animal)
|
||||
- [用 BART 做文本摘要](https://huggingface.co/facebook/bart-large-cnn?text=The+tower+is+324+metres+%281%2C063+ft%29+tall%2C+about+the+same+height+as+an+81-storey+building%2C+and+the+tallest+structure+in+Paris.+Its+base+is+square%2C+measuring+125+metres+%28410+ft%29+on+each+side.+During+its+construction%2C+the+Eiffel+Tower+surpassed+the+Washington+Monument+to+become+the+tallest+man-made+structure+in+the+world%2C+a+title+it+held+for+41+years+until+the+Chrysler+Building+in+New+York+City+was+finished+in+1930.+It+was+the+first+structure+to+reach+a+height+of+300+metres.+Due+to+the+addition+of+a+broadcasting+aerial+at+the+top+of+the+tower+in+1957%2C+it+is+now+taller+than+the+Chrysler+Building+by+5.2+metres+%2817+ft%29.+Excluding+transmitters%2C+the+Eiffel+Tower+is+the+second+tallest+free-standing+structure+in+France+after+the+Millau+Viaduct)
|
||||
- [用 DistilBERT 做问答](https://huggingface.co/distilbert/distilbert-base-uncased-distilled-squad?text=Which+name+is+also+used+to+describe+the+Amazon+rainforest+in+English%3F&context=The+Amazon+rainforest+%28Portuguese%3A+Floresta+Amaz%C3%B4nica+or+Amaz%C3%B4nia%3B+Spanish%3A+Selva+Amaz%C3%B3nica%2C+Amazon%C3%ADa+or+usually+Amazonia%3B+French%3A+For%C3%AAt+amazonienne%3B+Dutch%3A+Amazoneregenwoud%29%2C+also+known+in+English+as+Amazonia+or+the+Amazon+Jungle%2C+is+a+moist+broadleaf+forest+that+covers+most+of+the+Amazon+basin+of+South+America.+This+basin+encompasses+7%2C000%2C000+square+kilometres+%282%2C700%2C000+sq+mi%29%2C+of+which+5%2C500%2C000+square+kilometres+%282%2C100%2C000+sq+mi%29+are+covered+by+the+rainforest.+This+region+includes+territory+belonging+to+nine+nations.+The+majority+of+the+forest+is+contained+within+Brazil%2C+with+60%25+of+the+rainforest%2C+followed+by+Peru+with+13%25%2C+Colombia+with+10%25%2C+and+with+minor+amounts+in+Venezuela%2C+Ecuador%2C+Bolivia%2C+Guyana%2C+Suriname+and+French+Guiana.+States+or+departments+in+four+nations+contain+%22Amazonas%22+in+their+names.+The+Amazon+represents+over+half+of+the+planet%27s+remaining+rainforests%2C+and+comprises+the+largest+and+most+biodiverse+tract+of+tropical+rainforest+in+the+world%2C+with+an+estimated+390+billion+individual+trees+divided+into+16%2C000+species)
|
||||
- [用 T5 做翻译](https://huggingface.co/google-t5/t5-base?text=My+name+is+Wolfgang+and+I+live+in+Berlin)
|
||||
|
||||
**[Write With Transformer](https://transformer.huggingface.co)**,由 Hugging Face 团队打造,是一个文本生成的官方 demo。
|
||||
|
||||
## 如果你在寻找由 Hugging Face 团队提供的定制化支持服务
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/support">
|
||||
<img alt="HuggingFace Expert Acceleration Program" src="https://huggingface.co/front/thumbnails/support.png" style="max-width: 600px; border: 1px solid #eee; border-radius: 4px; box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);">
|
||||
</a><br>
|
||||
|
||||
## 快速上手
|
||||
|
||||
我们为快速使用模型提供了 `pipeline` API。Pipeline 聚合了预训练模型和对应的文本预处理。下面是一个快速使用 pipeline 去判断正负面情绪的例子:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# 使用情绪分析 pipeline
|
||||
>>> classifier = pipeline('sentiment-analysis')
|
||||
>>> classifier('We are very happy to introduce pipeline to the transformers repository.')
|
||||
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]
|
||||
```
|
||||
|
||||
第二行代码下载并缓存了 pipeline 使用的预训练模型,而第三行代码则在给定的文本上进行了评估。这里的答案"正面" (positive) 具有 99 的置信度。
|
||||
|
||||
许多的 NLP 任务都有开箱即用的预训练 `pipeline`。比如说,我们可以轻松的从给定文本中抽取问题答案:
|
||||
|
||||
``` python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# 使用问答 pipeline
|
||||
>>> question_answerer = pipeline('question-answering')
|
||||
>>> question_answerer({
|
||||
... 'question': 'What is the name of the repository ?',
|
||||
... 'context': 'Pipeline has been included in the huggingface/transformers repository'
|
||||
... })
|
||||
{'score': 0.30970096588134766, 'start': 34, 'end': 58, 'answer': 'huggingface/transformers'}
|
||||
|
||||
```
|
||||
|
||||
除了给出答案,预训练模型还给出了对应的置信度分数、答案在词符化 (tokenized) 后的文本中开始和结束的位置。你可以从[这个教程](https://huggingface.co/docs/transformers/task_summary)了解更多 `pipeline` API 支持的任务。
|
||||
|
||||
要在你的任务上下载和使用任意预训练模型也很简单,只需三行代码。这里是 PyTorch 版的示例:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
```
|
||||
这里是等效的 TensorFlow 代码:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, TFAutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="tf")
|
||||
>>> outputs = model(**inputs)
|
||||
```
|
||||
|
||||
词符化器 (tokenizer) 为所有的预训练模型提供了预处理,并可以直接对单个字符串进行调用(比如上面的例子)或对列表 (list) 调用。它会输出一个你可以在下游代码里使用或直接通过 `**` 解包表达式传给模型的词典 (dict)。
|
||||
|
||||
模型本身是一个常规的 [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) 或 [TensorFlow `tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model)(取决于你的后端),可以常规方式使用。 [这个教程](https://huggingface.co/transformers/training.html)解释了如何将这样的模型整合到经典的 PyTorch 或 TensorFlow 训练循环中,或是如何使用我们的 `Trainer` (训练器)API 来在一个新的数据集上快速微调。
|
||||
|
||||
## 为什么要用 transformers?
|
||||
|
||||
1. 便于使用的先进模型:
|
||||
- NLU 和 NLG 上表现优越
|
||||
- 对教学和实践友好且低门槛
|
||||
- 高级抽象,只需了解三个类
|
||||
- 对所有模型统一的API
|
||||
|
||||
1. 更低计算开销,更少的碳排放:
|
||||
- 研究人员可以分享已训练的模型而非每次从头开始训练
|
||||
- 工程师可以减少计算用时和生产环境开销
|
||||
- 数十种模型架构、两千多个预训练模型、100多种语言支持
|
||||
|
||||
1. 对于模型生命周期的每一个部分都面面俱到:
|
||||
- 训练先进的模型,只需 3 行代码
|
||||
- 模型在不同深度学习框架间任意转移,随你心意
|
||||
- 为训练、评估和生产选择最适合的框架,衔接无缝
|
||||
|
||||
1. 为你的需求轻松定制专属模型和用例:
|
||||
- 我们为每种模型架构提供了多个用例来复现原论文结果
|
||||
- 模型内部结构保持透明一致
|
||||
- 模型文件可单独使用,方便修改和快速实验
|
||||
|
||||
## 什么情况下我不该用 transformers?
|
||||
|
||||
- 本库并不是模块化的神经网络工具箱。模型文件中的代码特意呈若璞玉,未经额外抽象封装,以便研究人员快速迭代修改而不致溺于抽象和文件跳转之中。
|
||||
- `Trainer` API 并非兼容任何模型,只为本库之模型优化。若是在寻找适用于通用机器学习的训练循环实现,请另觅他库。
|
||||
- 尽管我们已尽力而为,[examples 目录](https://github.com/huggingface/transformers/tree/main/examples)中的脚本也仅为用例而已。对于你的特定问题,它们并不一定开箱即用,可能需要改几行代码以适之。
|
||||
目前在 [Hugging Face Hub](https://huggingface.com/models) 上有超过 1M+ 使用 `transformers` 的[模型检查点](https://huggingface.co/models?library=transformers&sort=trending),可随取随用。
|
||||
|
||||
今天就去探索 Hub,找到一个模型,并用 Transformers 立刻开始吧。
|
||||
|
||||
## 安装
|
||||
|
||||
### 使用 pip
|
||||
Transformers 支持 Python 3.9+,以及 [PyTorch](https://pytorch.org/get-started/locally/) 2.1+。
|
||||
|
||||
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下经过测试。
|
||||
使用 [venv](https://docs.python.org/3/library/venv.html) 或 [uv](https://docs.astral.sh/uv/)(一个基于 Rust 的快速 Python 包与项目管理器)创建并激活虚拟环境:
|
||||
|
||||
你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
||||
|
||||
首先,用你打算使用的版本的 Python 创建一个虚拟环境并激活。
|
||||
|
||||
然后,你需要安装 Flax、PyTorch 或 TensorFlow 其中之一。关于在你使用的平台上安装这些框架,请参阅 [TensorFlow 安装页](https://www.tensorflow.org/install/), [PyTorch 安装页](https://pytorch.org/get-started/locally/#start-locally) 或 [Flax 安装页](https://github.com/google/flax#quick-install)。
|
||||
|
||||
当这些后端之一安装成功后, 🤗 Transformers 可依此安装:
|
||||
|
||||
```bash
|
||||
pip install transformers
|
||||
```py
|
||||
# venv
|
||||
python -m venv .my-env
|
||||
source .my-env/bin/activate
|
||||
# uv
|
||||
uv venv .my-env
|
||||
source .my-env/bin/activate
|
||||
```
|
||||
|
||||
如果你想要试试用例或者想在正式发布前使用最新的开发中代码,你得[从源代码安装](https://huggingface.co/docs/transformers/installation#installing-from-source)。
|
||||
在虚拟环境中安装 Transformers:
|
||||
|
||||
### 使用 conda
|
||||
```py
|
||||
# pip
|
||||
pip install "transformers[torch]"
|
||||
|
||||
🤗 Transformers 可以通过 conda 依此安装:
|
||||
|
||||
```shell script
|
||||
conda install conda-forge::transformers
|
||||
# uv
|
||||
uv pip install "transformers[torch]"
|
||||
```
|
||||
|
||||
> **_笔记:_** 从 `huggingface` 渠道安装 `transformers` 已被废弃。
|
||||
如果你需要库中的最新改动或计划参与贡献,可从源码安装(注意:最新版可能不稳定;如遇错误,欢迎在 [issues](https://github.com/huggingface/transformers/issues) 中反馈):
|
||||
|
||||
要通过 conda 安装 Flax、PyTorch 或 TensorFlow 其中之一,请参阅它们各自安装页的说明。
|
||||
```shell
|
||||
git clone https://github.com/huggingface/transformers.git
|
||||
cd transformers
|
||||
|
||||
## 模型架构
|
||||
# pip
|
||||
pip install '.[torch]'
|
||||
|
||||
🤗 Transformers 支持的[**所有的模型检查点**](https://huggingface.co/models)由[用户](https://huggingface.co/users)和[组织](https://huggingface.co/organizations)上传,均与 huggingface.co [model hub](https://huggingface.co) 无缝整合。
|
||||
# uv
|
||||
uv pip install '.[torch]'
|
||||
```
|
||||
|
||||
目前的检查点数量: 
|
||||
## 快速上手
|
||||
|
||||
🤗 Transformers 目前支持如下的架构: 模型概述请阅[这里](https://huggingface.co/docs/transformers/model_summary).
|
||||
使用 [Pipeline](https://huggingface.co/docs/transformers/pipeline_tutorial) API 一步上手。`Pipeline` 是一个高级推理类,支持文本、音频、视觉与多模态任务,负责输入预处理并返回适配的输出。
|
||||
|
||||
要检查某个模型是否已有 Flax、PyTorch 或 TensorFlow 的实现,或其是否在 🤗 Tokenizers 库中有对应词符化器(tokenizer),敬请参阅[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
|
||||
实例化一个用于文本生成的 pipeline,指定使用的模型。模型会被下载并缓存,方便复用。最后传入文本作为提示:
|
||||
|
||||
这些实现均已于多个数据集测试(请参看用例脚本)并应于原版实现表现相当。你可以在用例文档的[此节](https://huggingface.co/docs/transformers/examples)中了解表现的细节。
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="Qwen/Qwen2.5-1.5B")
|
||||
pipeline("the secret to baking a really good cake is ")
|
||||
[{'generated_text': 'the secret to baking a really good cake is 1) to use the right ingredients and 2) to follow the recipe exactly. the recipe for the cake is as follows: 1 cup of sugar, 1 cup of flour, 1 cup of milk, 1 cup of butter, 1 cup of eggs, 1 cup of chocolate chips. if you want to make 2 cakes, how much sugar do you need? To make 2 cakes, you will need 2 cups of sugar.'}]
|
||||
```
|
||||
|
||||
要与模型进行「聊天」,用法也一致。唯一不同是需要构造一段「聊天历史」(即 `Pipeline` 的输入):
|
||||
|
||||
> [!TIP]
|
||||
> 你也可以直接在命令行与模型聊天:
|
||||
> ```shell
|
||||
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
> ```
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."},
|
||||
{"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"}
|
||||
]
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", dtype=torch.bfloat16, device_map="auto")
|
||||
response = pipeline(chat, max_new_tokens=512)
|
||||
print(response[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
展开下方示例,查看 `Pipeline` 在不同模态与任务中的用法。
|
||||
|
||||
<details>
|
||||
<summary>自动语音识别</summary>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="automatic-speech-recognition", model="openai/whisper-large-v3")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
|
||||
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>图像分类</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="image-classification", model="facebook/dinov2-small-imagenet1k-1-layer")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
|
||||
[{"label": "macaw", "score": 0.997848391532898},
|
||||
{"label": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
|
||||
"score": 0.0016551691805943847},
|
||||
{"label": "lorikeet", "score": 0.00018523589824326336},
|
||||
{"label": "African grey, African gray, Psittacus erithacus",
|
||||
"score": 7.85409429227002e-05},
|
||||
{"label": "quail", "score": 5.502637941390276e-05}]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>视觉问答</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="visual-question-answering", model="Salesforce/blip-vqa-base")
|
||||
pipeline(
|
||||
image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg",
|
||||
question="What is in the image?",
|
||||
)
|
||||
[{"answer": "statue of liberty"}]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 为什么要用 Transformers?
|
||||
|
||||
1. 易于使用的最先进模型:
|
||||
- 在自然语言理解与生成、计算机视觉、音频、视频与多模态任务上表现优越。
|
||||
- 对研究者、工程师与开发者友好且低门槛。
|
||||
- 少量用户侧抽象,仅需学习三个类。
|
||||
- 统一的 API,使用所有预训练模型体验一致。
|
||||
|
||||
1. 更低计算开销与更小碳足迹:
|
||||
- 共享已训练的模型,而非每次从零开始训练。
|
||||
- 减少计算时间与生产环境成本。
|
||||
- 覆盖数十种模型架构,跨所有模态提供 1M+ 预训练检查点。
|
||||
|
||||
1. 在模型生命周期的每个阶段都可以选用合适的框架:
|
||||
- 3 行代码即可训练最先进模型。
|
||||
- 在 PyTorch/JAX/TF2.0 间自由迁移同一个模型。
|
||||
- 为训练、评估与生产挑选最合适的框架。
|
||||
|
||||
1. 轻松定制模型或用例:
|
||||
- 为每个架构提供示例以复现原论文结果。
|
||||
- 尽可能一致地暴露模型内部。
|
||||
- 模型文件可独立于库使用,便于快速实验。
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/enterprise">
|
||||
<img alt="Hugging Face Enterprise Hub" src="https://github.com/user-attachments/assets/247fb16d-d251-4583-96c4-d3d76dda4925">
|
||||
</a><br>
|
||||
|
||||
## 为什么我不该用 Transformers?
|
||||
|
||||
- 该库不是一个可自由拼搭的神经网络模块化工具箱。模型文件中的代码刻意减少额外抽象,以便研究者能快速在各个模型上迭代,而无需深入更多抽象或文件跳转。
|
||||
- 训练 API 优化用于 Transformers 提供的 PyTorch 模型。若需要通用的机器学习训练循环,请使用其它库,如 [Accelerate](https://huggingface.co/docs/accelerate)。
|
||||
- [示例脚本](https://github.com/huggingface/transformers/tree/main/examples)只是「示例」。它们不一定能直接适配你的具体用例,需要你进行必要的改动。
|
||||
|
||||
|
||||
## 了解更多
|
||||
## 100 个使用 Transformers 的项目
|
||||
|
||||
| 章节 | 描述 |
|
||||
|-|-|
|
||||
| [文档](https://huggingface.co/docs/transformers/) | 完整的 API 文档和教程 |
|
||||
| [任务总结](https://huggingface.co/docs/transformers/task_summary) | 🤗 Transformers 支持的任务 |
|
||||
| [预处理教程](https://huggingface.co/docs/transformers/preprocessing) | 使用 `Tokenizer` 来为模型准备数据 |
|
||||
| [训练和微调](https://huggingface.co/docs/transformers/training) | 在 PyTorch/TensorFlow 的训练循环或 `Trainer` API 中使用 🤗 Transformers 提供的模型 |
|
||||
| [快速上手:微调和用例脚本](https://github.com/huggingface/transformers/tree/main/examples) | 为各种任务提供的用例脚本 |
|
||||
| [模型分享和上传](https://huggingface.co/docs/transformers/model_sharing) | 和社区上传和分享你微调的模型 |
|
||||
| [迁移](https://huggingface.co/docs/transformers/migration) | 从 `pytorch-transformers` 或 `pytorch-pretrained-bert` 迁移到 🤗 Transformers |
|
||||
Transformers 不止是一个使用预训练模型的工具包,它还是围绕 Hugging Face Hub 构建的项目社区。我们希望 Transformers 能助力开发者、研究人员、学生、老师、工程师与任何人构建理想项目。
|
||||
|
||||
为庆祝 Transformers 获得 100,000 颗星,我们制作了 [awesome-transformers](./awesome-transformers.md) 页面,展示了 100 个由社区构建的优秀项目。
|
||||
|
||||
如果你拥有或使用某个项目,认为它应该在列表中出现,欢迎提交 PR 添加它!
|
||||
|
||||
## 示例模型
|
||||
|
||||
你可以直接在它们的 [Hub 模型页](https://huggingface.co/models) 上测试我们的多数模型。
|
||||
|
||||
展开每个模态以查看不同用例中的部分示例模型。
|
||||
|
||||
<details>
|
||||
<summary>音频</summary>
|
||||
|
||||
- 使用 [Whisper](https://huggingface.co/openai/whisper-large-v3-turbo) 进行音频分类
|
||||
- 使用 [Moonshine](https://huggingface.co/UsefulSensors/moonshine) 进行自动语音识别
|
||||
- 使用 [Wav2Vec2](https://huggingface.co/superb/wav2vec2-base-superb-ks) 进行关键词检索
|
||||
- 使用 [Moshi](https://huggingface.co/kyutai/moshiko-pytorch-bf16) 进行语音到语音生成
|
||||
- 使用 [MusicGen](https://huggingface.co/facebook/musicgen-large) 文本到音频生成
|
||||
- 使用 [Bark](https://huggingface.co/suno/bark) 文本到语音生成
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>计算机视觉</summary>
|
||||
|
||||
- 使用 [SAM](https://huggingface.co/facebook/sam-vit-base) 自动生成掩码
|
||||
- 使用 [DepthPro](https://huggingface.co/apple/DepthPro-hf) 进行深度估计
|
||||
- 使用 [DINO v2](https://huggingface.co/facebook/dinov2-base) 进行图像分类
|
||||
- 使用 [SuperPoint](https://huggingface.co/magic-leap-community/superpoint) 进行关键点检测
|
||||
- 使用 [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor) 进行关键点匹配
|
||||
- 使用 [RT-DETRv2](https://huggingface.co/PekingU/rtdetr_v2_r50vd) 进行目标检测
|
||||
- 使用 [VitPose](https://huggingface.co/usyd-community/vitpose-base-simple) 进行姿态估计
|
||||
- 使用 [OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_swin_large) 进行通用分割
|
||||
- 使用 [VideoMAE](https://huggingface.co/MCG-NJU/videomae-large) 进行视频分类
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>多模态</summary>
|
||||
|
||||
- 使用 [Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B) 实现音频或文本到文本
|
||||
- 使用 [LayoutLMv3](https://huggingface.co/microsoft/layoutlmv3-base) 进行文档问答
|
||||
- 使用 [Qwen-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) 实现图像或文本到文本
|
||||
- 使用 [BLIP-2](https://huggingface.co/Salesforce/blip2-opt-2.7b) 进行图文描述
|
||||
- 使用 [GOT-OCR2](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf) 进行基于 OCR 的文档理解
|
||||
- 使用 [TAPAS](https://huggingface.co/google/tapas-base) 进行表格问答
|
||||
- 使用 [Emu3](https://huggingface.co/BAAI/Emu3-Gen) 进行统一的多模态理解与生成
|
||||
- 使用 [Llava-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) 视觉到文本
|
||||
- 使用 [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf) 进行视觉问答
|
||||
- 使用 [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224) 进行视觉指代表达分割
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>NLP</summary>
|
||||
|
||||
- 使用 [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base) 进行掩码词填充
|
||||
- 使用 [Gemma](https://huggingface.co/google/gemma-2-2b) 进行命名实体识别(NER)
|
||||
- 使用 [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) 进行问答
|
||||
- 使用 [BART](https://huggingface.co/facebook/bart-large-cnn) 进行摘要
|
||||
- 使用 [T5](https://huggingface.co/google-t5/t5-base) 进行翻译
|
||||
- 使用 [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B) 进行文本生成
|
||||
- 使用 [Qwen](https://huggingface.co/Qwen/Qwen2.5-0.5B) 进行文本分类
|
||||
|
||||
</details>
|
||||
|
||||
## 引用
|
||||
|
||||
|
||||
@ -14,43 +14,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
<!---
|
||||
A useful guide for English-Traditional Chinese translation of Hugging Face documentation
|
||||
- Add space around English words and numbers when they appear between Chinese characters. E.g., 共 100 多種語言; 使用 transformers 函式庫。
|
||||
- Use square quotes, e.g.,「引用」
|
||||
- Some of terms in the file can be found at National Academy for Educational Research (https://terms.naer.edu.tw/), an official website providing bilingual translations between English and Traditional Chinese.
|
||||
|
||||
Dictionary
|
||||
|
||||
API: API (不翻譯)
|
||||
add: 加入
|
||||
checkpoint: 檢查點
|
||||
code: 程式碼
|
||||
community: 社群
|
||||
confidence: 信賴度
|
||||
dataset: 資料集
|
||||
documentation: 文件
|
||||
example: 基本翻譯為「範例」,或依語意翻為「例子」
|
||||
finetune: 微調
|
||||
Hugging Face: Hugging Face(不翻譯)
|
||||
implementation: 實作
|
||||
inference: 推論
|
||||
library: 函式庫
|
||||
module: 模組
|
||||
NLP/Natural Language Processing: 以 NLP 出現時不翻譯,以 Natural Language Processing 出現時翻譯為自然語言處理
|
||||
online demos: 線上Demo
|
||||
pipeline: pipeline(不翻譯)
|
||||
pretrained/pretrain: 預訓練
|
||||
Python data structures (e.g., list, set, dict): 翻譯為串列,集合,字典,並用括號標註原英文
|
||||
repository: repository(不翻譯)
|
||||
summary: 概覽
|
||||
token-: token-(不翻譯)
|
||||
Trainer: Trainer(不翻譯)
|
||||
transformer: transformer(不翻譯)
|
||||
tutorial: 教學
|
||||
user: 使用者
|
||||
-->
|
||||
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/transformers-logo-dark.svg">
|
||||
@ -62,6 +25,7 @@ user: 使用者
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://huggingface.com/models"><img alt="Checkpoints on Hub" src="https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen"></a>
|
||||
<a href="https://circleci.com/gh/huggingface/transformers"><img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/main"></a>
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/LICENSE"><img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue"></a>
|
||||
<a href="https://huggingface.co/docs/transformers/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/transformers/index.svg?down_color=red&down_message=offline&up_message=online"></a>
|
||||
@ -72,7 +36,7 @@ user: 使用者
|
||||
|
||||
<h4 align="center">
|
||||
<p>
|
||||
<a href="https://github.com/huggingface/transformers/">English</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/README.md">English</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_zh-hans.md">简体中文</a> |
|
||||
<b>繁體中文</b> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ko.md">한국어</a> |
|
||||
@ -80,7 +44,7 @@ user: 使用者
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ja.md">日本語</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_hd.md">हिन्दी</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_ru.md">Русский</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Рortuguês</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_pt-br.md">Português</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_te.md">తెలుగు</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_fr.md">Français</a> |
|
||||
<a href="https://github.com/huggingface/transformers/blob/main/i18n/README_de.md">Deutsch</a> |
|
||||
@ -93,186 +57,261 @@ user: 使用者
|
||||
</h4>
|
||||
|
||||
<h3 align="center">
|
||||
<p>為 Jax、PyTorch 以及 TensorFlow 打造的先進自然語言處理函式庫</p>
|
||||
<p>最先進的預訓練模型,專為推理與訓練而生</p>
|
||||
</h3>
|
||||
|
||||
<h3 align="center">
|
||||
<a href="https://hf.co/course"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/course_banner.png"></a>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_as_a_model_definition.png"/>
|
||||
</h3>
|
||||
|
||||
🤗 Transformers 提供了數以千計的預訓練模型,支援 100 多種語言的文本分類、資訊擷取、問答、摘要、翻譯、文本生成。它的宗旨是讓最先進的 NLP 技術人人易用。
|
||||
Transformers 是一個為最先進的機器學習模型(涵蓋文字、電腦視覺、音訊、影片及多模態)提供推理和訓練支援的模型定義框架。
|
||||
|
||||
🤗 Transformers 提供了便於快速下載和使用的API,讓你可以將預訓練模型用在給定文本、在你的資料集上微調然後經由 [model hub](https://huggingface.co/models) 與社群共享。同時,每個定義的 Python 模組架構均完全獨立,方便修改和快速研究實驗。
|
||||
它將模型定義集中化,使得該定義在整個生態系中能夠達成共識。`transformers` 是貫穿各個框架的樞紐:如果一個模型定義受到支援,它將與大多數訓練框架(如 Axolotl、Unsloth、DeepSpeed、FSDP、PyTorch-Lightning 等)、推理引擎(如 vLLM、SGLang、TGI 等)以及利用 `transformers` 模型定義的周邊建模函式庫(如 llama.cpp、mlx 等)相容。
|
||||
|
||||
🤗 Transformers 支援三個最熱門的深度學習函式庫: [Jax](https://jax.readthedocs.io/en/latest/), [PyTorch](https://pytorch.org/) 以及 [TensorFlow](https://www.tensorflow.org/) — 並與之完美整合。你可以直接使用其中一個框架訓練你的模型,然後用另一個載入和推論。
|
||||
我們致力於支援最新的頂尖模型,並透過使其模型定義變得簡單、可客製化且高效,來普及它們的應用。
|
||||
|
||||
## 線上Demo
|
||||
在 [Hugging Face Hub](https://huggingface.com/models) 上,有超過 100 萬個 Transformers [模型檢查點](https://huggingface.co/models?library=transformers&sort=trending) 供您使用。
|
||||
|
||||
你可以直接在 [model hub](https://huggingface.co/models) 上測試大多數的模型。我們也提供了 [私有模型託管、模型版本管理以及推論API](https://huggingface.co/pricing)。
|
||||
|
||||
這裡是一些範例:
|
||||
- [用 BERT 做遮蓋填詞](https://huggingface.co/google-bert/bert-base-uncased?text=Paris+is+the+%5BMASK%5D+of+France)
|
||||
- [用 Electra 做專有名詞辨識](https://huggingface.co/dbmdz/electra-large-discriminator-finetuned-conll03-english?text=My+name+is+Sarah+and+I+live+in+London+city)
|
||||
- [用 GPT-2 做文本生成](https://huggingface.co/openai-community/gpt2?text=A+long+time+ago%2C+)
|
||||
- [用 RoBERTa 做自然語言推論](https://huggingface.co/FacebookAI/roberta-large-mnli?text=The+dog+was+lost.+Nobody+lost+any+animal)
|
||||
- [用 BART 做文本摘要](https://huggingface.co/facebook/bart-large-cnn?text=The+tower+is+324+metres+%281%2C063+ft%29+tall%2C+about+the+same+height+as+an+81-storey+building%2C+and+the+tallest+structure+in+Paris.+Its+base+is+square%2C+measuring+125+metres+%28410+ft%29+on+each+side.+During+its+construction%2C+the+Eiffel+Tower+surpassed+the+Washington+Monument+to+become+the+tallest+man-made+structure+in+the+world%2C+a+title+it+held+for+41+years+until+the+Chrysler+Building+in+New+York+City+was+finished+in+1930.+It+was+the+first+structure+to+reach+a+height+of+300+metres.+Due+to+the+addition+of+a+broadcasting+aerial+at+the+top+of+the+tower+in+1957%2C+it+is+now+taller+than+the+Chrysler+Building+by+5.2+metres+%2817+ft%29.+Excluding+transmitters%2C+the+Eiffel+Tower+is+the+second+tallest+free-standing+structure+in+France+after+the+Millau+Viaduct)
|
||||
- [用 DistilBERT 做問答](https://huggingface.co/distilbert/distilbert-base-uncased-distilled-squad?text=Which+name+is+also+used+to+describe+the+Amazon+rainforest+in+English%3F&context=The+Amazon+rainforest+%28Portuguese%3A+Floresta+Amaz%C3%B4nica+or+Amaz%C3%B4nia%3B+Spanish%3A+Selva+Amaz%C3%B3nica%2C+Amazon%C3%ADa+or+usually+Amazonia%3B+French%3A+For%C3%AAt+amazonienne%3B+Dutch%3A+Amazoneregenwoud%29%2C+also+known+in+English+as+Amazonia+or+the+Amazon+Jungle%2C+is+a+moist+broadleaf+forest+that+covers+most+of+the+Amazon+basin+of+South+America.+This+basin+encompasses+7%2C000%2C000+square+kilometres+%282%2C700%2C000+sq+mi%29%2C+of+which+5%2C500%2C000+square+kilometres+%282%2C100%2C000+sq+mi%29+are+covered+by+the+rainforest.+This+region+includes+territory+belonging+to+nine+nations.+The+majority+of+the+forest+is+contained+within+Brazil%2C+with+60%25+of+the+rainforest%2C+followed+by+Peru+with+13%25%2C+Colombia+with+10%25%2C+and+with+minor+amounts+in+Venezuela%2C+Ecuador%2C+Bolivia%2C+Guyana%2C+Suriname+and+French+Guiana.+States+or+departments+in+four+nations+contain+%22Amazonas%22+in+their+names.+The+Amazon+represents+over+half+of+the+planet%27s+remaining+rainforests%2C+and+comprises+the+largest+and+most+biodiverse+tract+of+tropical+rainforest+in+the+world%2C+with+an+estimated+390+billion+individual+trees+divided+into+16%2C000+species)
|
||||
- [用 T5 做翻譯](https://huggingface.co/google-t5/t5-base?text=My+name+is+Wolfgang+and+I+live+in+Berlin)
|
||||
|
||||
**[Write With Transformer](https://transformer.huggingface.co)**,由 Hugging Face 團隊所打造,是一個文本生成的官方 demo。
|
||||
|
||||
## 如果你在尋找由 Hugging Face 團隊所提供的客製化支援服務
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/support">
|
||||
<img alt="HuggingFace Expert Acceleration Program" src="https://huggingface.co/front/thumbnails/support.png" style="max-width: 600px; border: 1px solid #eee; border-radius: 4px; box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);">
|
||||
</a><br>
|
||||
|
||||
## 快速上手
|
||||
|
||||
我們為快速使用模型提供了 `pipeline` API。 Pipeline 包含了預訓練模型和對應的文本預處理。下面是一個快速使用 pipeline 去判斷正負面情緒的例子:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# 使用情緒分析 pipeline
|
||||
>>> classifier = pipeline('sentiment-analysis')
|
||||
>>> classifier('We are very happy to introduce pipeline to the transformers repository.')
|
||||
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]
|
||||
```
|
||||
|
||||
第二行程式碼下載並快取 pipeline 使用的預訓練模型,而第三行程式碼則在給定的文本上進行了評估。這裡的答案“正面” (positive) 具有 99.97% 的信賴度。
|
||||
|
||||
許多的 NLP 任務都有隨選即用的預訓練 `pipeline`。例如,我們可以輕鬆地從給定文本中擷取問題答案:
|
||||
|
||||
``` python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
# 使用問答 pipeline
|
||||
>>> question_answerer = pipeline('question-answering')
|
||||
>>> question_answerer({
|
||||
... 'question': 'What is the name of the repository ?',
|
||||
... 'context': 'Pipeline has been included in the huggingface/transformers repository'
|
||||
... })
|
||||
{'score': 0.30970096588134766, 'start': 34, 'end': 58, 'answer': 'huggingface/transformers'}
|
||||
|
||||
```
|
||||
|
||||
除了提供問題解答,預訓練模型還提供了對應的信賴度分數以及解答在 tokenized 後的文本中開始和結束的位置。你可以從[這個教學](https://huggingface.co/docs/transformers/task_summary)了解更多 `pipeline` API支援的任務。
|
||||
|
||||
要在你的任務中下載和使用任何預訓練模型很簡單,只需三行程式碼。這裡是 PyTorch 版的範例:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
```
|
||||
這裡是對應的 TensorFlow 程式碼:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, TFAutoModel
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("Hello world!", return_tensors="tf")
|
||||
>>> outputs = model(**inputs)
|
||||
```
|
||||
|
||||
Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換單一字串(比如上面的例子)或串列 (list)。它會輸出一個的字典 (dict) 讓你可以在下游程式碼裡使用或直接藉由 `**` 運算式傳給模型。
|
||||
|
||||
模型本身是一個常規的 [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) 或 [TensorFlow `tf.keras.Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model)(取決於你的後端),可依常規方式使用。 [這個教學](https://huggingface.co/transformers/training.html)解釋了如何將這樣的模型整合到一般的 PyTorch 或 TensorFlow 訓練迴圈中,或是如何使用我們的 `Trainer` API 在一個新的資料集上快速進行微調。
|
||||
|
||||
## 為什麼要用 transformers?
|
||||
|
||||
1. 便於使用的先進模型:
|
||||
- NLU 和 NLG 上性能卓越
|
||||
- 對教學和實作友好且低門檻
|
||||
- 高度抽象,使用者只須學習 3 個類別
|
||||
- 對所有模型使用的制式化API
|
||||
|
||||
1. 更低的運算成本,更少的碳排放:
|
||||
- 研究人員可以分享已訓練的模型而非每次從頭開始訓練
|
||||
- 工程師可以減少計算時間以及生產成本
|
||||
- 數十種模型架構、兩千多個預訓練模型、100多種語言支援
|
||||
|
||||
1. 對於模型生命週期的每一個部分都面面俱到:
|
||||
- 訓練先進的模型,只需 3 行程式碼
|
||||
- 模型可以在不同深度學習框架之間任意轉換
|
||||
- 為訓練、評估和生產選擇最適合的框架,並完美銜接
|
||||
|
||||
1. 為你的需求輕鬆客製化專屬模型和範例:
|
||||
- 我們為每種模型架構提供了多個範例來重現原論文結果
|
||||
- 一致的模型內部架構
|
||||
- 模型檔案可單獨使用,便於修改和快速實驗
|
||||
|
||||
## 什麼情況下我不該用 transformers?
|
||||
|
||||
- 本函式庫並不是模組化的神經網絡工具箱。模型文件中的程式碼並未做額外的抽象封裝,以便研究人員快速地翻閱及修改程式碼,而不會深陷複雜的類別包裝之中。
|
||||
- `Trainer` API 並非相容任何模型,它只為本函式庫中的模型最佳化。對於一般的機器學習用途,請使用其他函式庫。
|
||||
- 儘管我們已盡力而為,[examples 目錄](https://github.com/huggingface/transformers/tree/main/examples)中的腳本也僅為範例而已。對於特定問題,它們並不一定隨選即用,可能需要修改幾行程式碼以符合需求。
|
||||
立即探索 [Hub](https://huggingface.com/),尋找合適的模型,並使用 Transformers 幫助您快速上手。
|
||||
|
||||
## 安裝
|
||||
|
||||
### 使用 pip
|
||||
Transformers 支援 Python 3.9+ 和 [PyTorch](https://pytorch.org/get-started/locally/) 2.1+。
|
||||
|
||||
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下經過測試。
|
||||
使用 [venv](https://docs.python.org/3/library/venv.html) 或基於 Rust 的高速 Python 套件及專案管理器 [uv](https://docs.astral.sh/uv/) 來建立並啟用虛擬環境。
|
||||
|
||||
你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。
|
||||
|
||||
首先,用你打算使用的版本的 Python 創建一個虛擬環境並進入。
|
||||
|
||||
然後,你需要安裝 Flax、PyTorch 或 TensorFlow 其中之一。對於該如何在你使用的平台上安裝這些框架,請參閱 [TensorFlow 安裝頁面](https://www.tensorflow.org/install/), [PyTorch 安裝頁面](https://pytorch.org/get-started/locally/#start-locally) 或 [Flax 安裝頁面](https://github.com/google/flax#quick-install)。
|
||||
|
||||
當其中一個後端安裝成功後,🤗 Transformers 可依此安裝:
|
||||
|
||||
```bash
|
||||
pip install transformers
|
||||
```py
|
||||
# venv
|
||||
python -m venv .my-env
|
||||
source .my-env/bin/activate
|
||||
# uv
|
||||
uv venv .my-env
|
||||
source .my-env/bin/activate
|
||||
```
|
||||
|
||||
如果你想要試試範例或者想在正式發布前使用最新開發中的程式碼,你必須[從原始碼安裝](https://huggingface.co/docs/transformers/installation#installing-from-source)。
|
||||
在您的虛擬環境中安裝 Transformers。
|
||||
|
||||
### 使用 conda
|
||||
```py
|
||||
# pip
|
||||
pip install "transformers[torch]"
|
||||
|
||||
🤗 Transformers 可以藉由 conda 依此安裝:
|
||||
|
||||
```shell script
|
||||
conda install conda-forge::transformers
|
||||
# uv
|
||||
uv pip install "transformers[torch]"
|
||||
```
|
||||
|
||||
> **_筆記:_** 從 `huggingface` 頻道安裝 `transformers` 已被淘汰。
|
||||
如果您想使用函式庫的最新變更或有興趣參與貢獻,可以從原始碼安裝 Transformers。然而,*最新*版本可能不穩定。如果您遇到任何錯誤,歡迎隨時提交一個 [issue](https://github.com/huggingface/transformers/issues)。
|
||||
|
||||
要藉由 conda 安裝 Flax、PyTorch 或 TensorFlow 其中之一,請參閱它們各自安裝頁面的說明。
|
||||
```shell
|
||||
git clone https://github.com/huggingface/transformers.git
|
||||
cd transformers
|
||||
|
||||
## 模型架構
|
||||
# pip
|
||||
pip install '.[torch]'
|
||||
|
||||
**🤗 Transformers 支援的[所有的模型檢查點](https://huggingface.co/models)**,由[使用者](https://huggingface.co/users)和[組織](https://huggingface.co/organizations)上傳,均與 huggingface.co [model hub](https://huggingface.co) 完美結合。
|
||||
# uv
|
||||
uv pip install '.[torch]'
|
||||
```
|
||||
|
||||
目前的檢查點數量: 
|
||||
## 快速入門
|
||||
|
||||
🤗 Transformers 目前支援以下的架構: 模型概覽請參閱[這裡](https://huggingface.co/docs/transformers/model_summary).
|
||||
透過 [Pipeline](https://huggingface.co/docs/transformers/pipeline_tutorial) API 快速開始使用 Transformers。`Pipeline` 是一個高階的推理類別,支援文字、音訊、視覺和多模態任務。它負責處理輸入資料的預處理,並回傳適當的輸出。
|
||||
|
||||
要檢查某個模型是否已有 Flax、PyTorch 或 TensorFlow 的實作,或其是否在🤗 Tokenizers 函式庫中有對應的 tokenizer,敬請參閱[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。
|
||||
實例化一個 pipeline 並指定用於文字生成的模型。該模型會被下載並快取,方便您之後輕鬆複用。最後,傳入一些文字來提示模型。
|
||||
|
||||
這些實作均已於多個資料集測試(請參閱範例腳本)並應與原版實作表現相當。你可以在範例文件的[此節](https://huggingface.co/docs/transformers/examples)中了解實作的細節。
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="Qwen/Qwen2.5-1.5B")
|
||||
pipeline("the secret to baking a really good cake is ")
|
||||
[{'generated_text': 'the secret to baking a really good cake is 1) to use the right ingredients and 2) to follow the recipe exactly. the recipe for the cake is as follows: 1 cup of sugar, 1 cup of flour, 1 cup of milk, 1 cup of butter, 1 cup of eggs, 1 cup of chocolate chips. if you want to make 2 cakes, how much sugar do you need? To make 2 cakes, you will need 2 cups of sugar.'}]
|
||||
```
|
||||
|
||||
## 了解更多
|
||||
與模型進行聊天,使用模式是相同的。唯一的區別是您需要建構一個您與系統之間的聊天歷史(作為 `Pipeline` 的輸入)。
|
||||
|
||||
| 章節 | 描述 |
|
||||
|-|-|
|
||||
| [文件](https://huggingface.co/transformers/) | 完整的 API 文件和教學 |
|
||||
| [任務概覽](https://huggingface.co/docs/transformers/task_summary) | 🤗 Transformers 支援的任務 |
|
||||
| [預處理教學](https://huggingface.co/docs/transformers/preprocessing) | 使用 `Tokenizer` 來為模型準備資料 |
|
||||
| [訓練和微調](https://huggingface.co/docs/transformers/training) | 使用 PyTorch/TensorFlow 的內建的訓練方式或於 `Trainer` API 中使用 🤗 Transformers 提供的模型 |
|
||||
| [快速上手:微調和範例腳本](https://github.com/huggingface/transformers/tree/main/examples) | 為各種任務提供的範例腳本 |
|
||||
| [模型分享和上傳](https://huggingface.co/docs/transformers/model_sharing) | 上傳並與社群分享你微調的模型 |
|
||||
| [遷移](https://huggingface.co/docs/transformers/migration) | 從 `pytorch-transformers` 或 `pytorch-pretrained-bert` 遷移到 🤗 Transformers |
|
||||
> [!TIP]
|
||||
> 你也可以直接在命令列中與模型聊天。
|
||||
> ```shell
|
||||
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||
> ```
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."},
|
||||
{"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"}
|
||||
]
|
||||
|
||||
pipeline = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", dtype=torch.bfloat16, device_map="auto")
|
||||
response = pipeline(chat, max_new_tokens=512)
|
||||
print(response[0]["generated_text"][-1]["content"])
|
||||
```
|
||||
|
||||
展開下面的範例,查看 `Pipeline` 如何在不同模態和任務上運作。
|
||||
|
||||
<details>
|
||||
<summary>自動語音辨識</summary>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="automatic-speech-recognition", model="openai/whisper-large-v3")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
|
||||
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>影像分類</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="image-classification", model="facebook/dinov2-small-imagenet1k-1-layer")
|
||||
pipeline("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
|
||||
[{'label': 'macaw', 'score': 0.997848391532898},
|
||||
{'label': 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
|
||||
'score': 0.0016551691805943847},
|
||||
{'label': 'lorikeet', 'score': 0.00018523589824326336},
|
||||
{'label': 'African grey, African gray, Psittacus erithacus',
|
||||
'score': 7.85409429227002e-05},
|
||||
{'label': 'quail', 'score': 5.502637941390276e-05}]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>視覺問答</summary>
|
||||
|
||||
<h3 align="center">
|
||||
<a><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg"></a>
|
||||
</h3>
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(task="visual-question-answering", model="Salesforce/blip-vqa-base")
|
||||
pipeline(
|
||||
image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/idefics-few-shot.jpg",
|
||||
question="What is in the image?",
|
||||
)
|
||||
[{'answer': 'statue of liberty'}]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 為什麼我應該使用 Transformers?
|
||||
|
||||
1. 易於使用的最先進模型:
|
||||
* 在自然語言理解與生成、電腦視覺、音訊、影片和多模態任務上表現卓越。
|
||||
* 為研究人員、工程師與開發者提供了低門檻的入門途徑。
|
||||
* 面向使用者的抽象層級少,只需學習三個核心類別。
|
||||
* 為所有預訓練模型提供了統一的 API 介面。
|
||||
|
||||
2. 更低的運算成本,更小的碳足跡:
|
||||
* 分享訓練好的模型,而不是從零開始訓練。
|
||||
* 減少運算時間和生產成本。
|
||||
* 擁有數十種模型架構和超過100萬個橫跨所有模態的預訓練檢查點。
|
||||
|
||||
3. 為模型的每個生命週期階段選擇合適的框架:
|
||||
* 僅用3行程式碼即可訓練最先進的模型。
|
||||
* 在PyTorch/JAX/TF2.0框架之間輕鬆切換單一模型。
|
||||
* 為訓練、評估和生產選擇最合適的框架。
|
||||
|
||||
4. 輕鬆根據您的需求客製化模型或範例:
|
||||
* 我們為每個架構提供了範例,以重現其原作者發表的結果。
|
||||
* 模型內部結構盡可能保持一致地暴露給使用者。
|
||||
* 模型檔案可以獨立於函式庫使用,便於快速實驗。
|
||||
|
||||
<a target="_blank" href="https://huggingface.co/enterprise">
|
||||
<img alt="Hugging Face Enterprise Hub" src="https://github.com/user-attachments/assets/247fb16d-d251-4583-96c4-d3d76dda4925">
|
||||
</a><br>
|
||||
|
||||
## 為什麼我不應該使用 Transformers?
|
||||
|
||||
- 本函式庫並非一個用於建構神經網路的模組化工具箱。模型檔案中的程式碼為了讓研究人員能快速在模型上迭代,而沒有進行過度的抽象重構,避免了深入額外的抽象層/檔案。
|
||||
- 訓練 API 針對 Transformers 提供的 PyTorch 模型進行了最佳化。對於通用的機器學習迴圈,您應該使用像 [Accelerate](https://huggingface.co/docs/accelerate) 這樣的其他函式庫。
|
||||
- [範例指令稿](https://github.com/huggingface/transformers/tree/main/examples)僅僅是*範例*。它們不一定能在您的特定用例上開箱即用,您可能需要修改程式碼才能使其正常運作。
|
||||
|
||||
## 100個使用 Transformers 的專案
|
||||
|
||||
Transformers 不僅僅是一個使用預訓練模型的工具包,它還是一個圍繞它和 Hugging Face Hub 建構的專案社群。我們希望 Transformers 能夠賦能開發者、研究人員、學生、教授、工程師以及其他任何人,去建構他們夢想中的專案。
|
||||
|
||||
為了慶祝 Transformers 獲得 10 萬顆星標,我們希望透過 [awesome-transformers](./awesome-transformers.md) 頁面來聚焦社群,該頁面列出了100個基於 Transformers 建構的精彩專案。
|
||||
|
||||
如果您擁有或使用一個您認為應該被列入其中的專案,請隨時提交 PR 將其加入!
|
||||
|
||||
## 範例模型
|
||||
|
||||
您可以在我們大多數模型的 [Hub 模型頁面](https://huggingface.co/models) 上直接進行測試。
|
||||
|
||||
展開下面的每個模態,查看一些用於不同用例的範例模型。
|
||||
|
||||
<details>
|
||||
<summary>音訊</summary>
|
||||
|
||||
- Audio classification with [Whisper](https://huggingface.co/openai/whisper-large-v3-turbo)
|
||||
- Automatic speech recognition with [Moonshine](https://huggingface.co/UsefulSensors/moonshine)
|
||||
- Keyword spotting with [Wav2Vec2](https://huggingface.co/superb/wav2vec2-base-superb-ks)
|
||||
- Speech to speech generation with [Moshi](https://huggingface.co/kyutai/moshiko-pytorch-bf16)
|
||||
- Text to audio with [MusicGen](https://huggingface.co/facebook/musicgen-large)
|
||||
- Text to speech with [Bark](https://huggingface.co/suno/bark)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>電腦視覺</summary>
|
||||
|
||||
- Automatic mask generation with [SAM](https://huggingface.co/facebook/sam-vit-base)
|
||||
- Depth estimation with [DepthPro](https://huggingface.co/apple/DepthPro-hf)
|
||||
- Image classification with [DINO v2](https://huggingface.co/facebook/dinov2-base)
|
||||
- Keypoint detection with [SuperPoint](https://huggingface.co/magic-leap-community/superpoint)
|
||||
- Keypoint matching with [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor)
|
||||
- Object detection with [RT-DETRv2](https://huggingface.co/PekingU/rtdetr_v2_r50vd)
|
||||
- Pose Estimation with [VitPose](https://huggingface.co/usyd-community/vitpose-base-simple)
|
||||
- Universal segmentation with [OneFormer](https://huggingface.co/shi-labs/oneformer_ade20k_swin_large)
|
||||
- Video classification with [VideoMAE](https://huggingface.co/MCG-NJU/videomae-large)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>多模態</summary>
|
||||
|
||||
- Audio or text to text with [Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B)
|
||||
- Document question answering with [LayoutLMv3](https://huggingface.co/microsoft/layoutlmv3-base)
|
||||
- Image or text to text with [Qwen-VL](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
- Image captioning [BLIP-2](https://huggingface.co/Salesforce/blip2-opt-2.7b)
|
||||
- OCR-based document understanding with [GOT-OCR2](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)
|
||||
- Table question answering with [TAPAS](https://huggingface.co/google/tapas-base)
|
||||
- Unified multimodal understanding and generation with [Emu3](https://huggingface.co/BAAI/Emu3-Gen)
|
||||
- Vision to text with [Llava-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf)
|
||||
- Visual question answering with [Llava](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
|
||||
- Visual referring expression segmentation with [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>自然語言處理 (NLP)</summary>
|
||||
|
||||
- Masked word completion with [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base)
|
||||
- Named entity recognition with [Gemma](https://huggingface.co/google/gemma-2-2b)
|
||||
- Question answering with [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
|
||||
- Summarization with [BART](https://huggingface.co/facebook/bart-large-cnn)
|
||||
- Translation with [T5](https://huggingface.co/google-t5/t5-base)
|
||||
- Text generation with [Llama](https://huggingface.co/meta-llama/Llama-3.2-1B)
|
||||
- Text classification with [Qwen](https://huggingface.co/Qwen/Qwen2.5-0.5B)
|
||||
|
||||
</details>
|
||||
|
||||
## 引用
|
||||
|
||||
我們已將此函式庫的[論文](https://www.aclweb.org/anthology/2020.emnlp-demos.6/)正式發表。如果你使用了 🤗 Transformers 函式庫,可以引用:
|
||||
現在我們有一篇可供您引用的關於 🤗 Transformers 函式庫的 [論文](https://www.aclweb.org/anthology/2020.emnlp-demos.6/):
|
||||
```bibtex
|
||||
@inproceedings{wolf-etal-2020-transformers,
|
||||
title = "Transformers: State-of-the-Art Natural Language Processing",
|
||||
@ -285,4 +324,4 @@ conda install conda-forge::transformers
|
||||
url = "https://www.aclweb.org/anthology/2020.emnlp-demos.6",
|
||||
pages = "38--45"
|
||||
}
|
||||
```
|
||||
```
|
||||
setup.py
@@ -137,8 +137,8 @@ _deps = [
    "psutil",
    "pyyaml>=5.1",
    "pydantic>=2",
    "pytest>=7.2.0",
    "pytest-asyncio",
    "pytest>=7.2.0,<9.0.0",
    "pytest-asyncio>=1.2.0",
    "pytest-rerunfailures<16.0",
    "pytest-timeout",
    "pytest-xdist",
@@ -723,7 +723,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):

        if self.mask_replace_prob < 1:
            warnings.warn(
                "Random token replacement is not supported with whole word masking.",
                "Random token replacement is not supported with whole word masking. "
                "Setting mask_replace_prob to 1.",
            )
            self.mask_replace_prob = 1
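For context on the warning reworded above: `mask_replace_prob` controls how many of the tokens selected for masked language modeling are actually replaced by the mask token, and the collator resets it to 1 when whole word masking is in use. The snippet below is an illustrative sketch of the knob in its ordinary setting, not part of the diff; the checkpoint name is an arbitrary example, and you should check the collator signature for your installed version.

```py
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

# Of the tokens chosen by mlm_probability, mask_replace_prob is the fraction that actually
# becomes the mask token (the rest stay unchanged or are randomly replaced, BERT-style).
# As the warning above notes, whole word masking forces this value back to 1.
collator = DataCollatorForLanguageModeling(
    tokenizer=tok,
    mlm=True,
    mlm_probability=0.15,
    mask_replace_prob=0.8,
)
```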
@@ -82,7 +82,7 @@ class GlueDataset(Dataset):
        cache_dir: Optional[str] = None,
    ):
        warnings.warn(
            "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
            "This dataset will be removed from the library soon, preprocessing should be handled with the Hugging Face Datasets "
            "library. You can have a look at this example script for pointers: "
            "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
            FutureWarning,
@ -21,7 +21,7 @@ if is_sklearn_available():
DEPRECATION_WARNING = (
"This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
"This metric will be removed from the library soon, metrics should be handled with the Hugging Face Evaluate "
"library. You can have a look at this example script for pointers: "
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
)
@ -28,7 +28,7 @@ from .utils import DataProcessor, InputExample, InputFeatures
logger = logging.get_logger(__name__)
DEPRECATION_WARNING = (
"This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
"This {0} will be removed from the library soon, preprocessing should be handled with the Hugging Face Datasets "
"library. You can have a look at this example script for pointers: "
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
)
@ -47,8 +47,8 @@ deps = {
"psutil": "psutil",
"pyyaml": "pyyaml>=5.1",
"pydantic": "pydantic>=2",
"pytest": "pytest>=7.2.0",
"pytest-asyncio": "pytest-asyncio",
"pytest": "pytest>=7.2.0,<9.0.0",
"pytest-asyncio": "pytest-asyncio>=1.2.0",
"pytest-rerunfailures": "pytest-rerunfailures<16.0",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
@ -39,6 +39,7 @@ from .utils import (
|
||||
is_torch_dtype,
|
||||
logging,
|
||||
requires_backends,
|
||||
safe_load_json_file,
|
||||
)
|
||||
from .utils.hub import cached_file
|
||||
|
||||
@ -427,35 +428,42 @@ class FeatureExtractionMixin(PushToHubMixin):
|
||||
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
resolved_feature_extractor_file = pretrained_model_name_or_path
|
||||
resolved_processor_file = None
|
||||
is_local = True
|
||||
elif is_remote_url(pretrained_model_name_or_path):
|
||||
feature_extractor_file = pretrained_model_name_or_path
|
||||
resolved_processor_file = None
|
||||
resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
feature_extractor_file = FEATURE_EXTRACTOR_NAME
|
||||
try:
|
||||
# Load from local folder or from cache or download from model Hub and cache
|
||||
resolved_feature_extractor_files = [
|
||||
resolved_file
|
||||
for filename in [feature_extractor_file, PROCESSOR_NAME]
|
||||
if (
|
||||
resolved_file := cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
)
|
||||
is not None
|
||||
]
|
||||
resolved_feature_extractor_file = resolved_feature_extractor_files[0]
|
||||
resolved_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
resolved_feature_extractor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=feature_extractor_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
except OSError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
|
||||
@ -469,19 +477,24 @@ class FeatureExtractionMixin(PushToHubMixin):
|
||||
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load feature_extractor dict
|
||||
with open(resolved_feature_extractor_file, encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
feature_extractor_dict = json.loads(text)
|
||||
if "audio_processor" in feature_extractor_dict:
|
||||
feature_extractor_dict = feature_extractor_dict["audio_processor"]
|
||||
else:
|
||||
feature_extractor_dict = feature_extractor_dict.get("feature_extractor", feature_extractor_dict)
|
||||
# Load feature_extractor dict. Priority goes as (nested config if found -> feature extractor config)
# We are downloading both configs because almost all models have a `processor_config.json` but
# not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
|
||||
feature_extractor_dict = None
|
||||
if resolved_processor_file is not None:
|
||||
processor_dict = safe_load_json_file(resolved_processor_file)
|
||||
if "feature_extractor" in processor_dict or "audio_processor" in processor_dict:
|
||||
feature_extractor_dict = processor_dict.get("feature_extractor", processor_dict.get("audio_processor"))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
if resolved_feature_extractor_file is not None and feature_extractor_dict is None:
|
||||
feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file)
|
||||
|
||||
if feature_extractor_dict is None:
|
||||
raise OSError(
|
||||
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
|
||||
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
||||
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a {feature_extractor_file} file"
|
||||
)
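To make the selection order above concrete, here is a rough, self-contained sketch of the priority logic with illustrative field names (the keys inside the nested dict are examples, not tied to any specific model): a nested `processor_config.json` entry wins, and the legacy flat file is only used as a fallback.

```python
# Assumed example layouts, for illustration only
nested_processor_config = {"feature_extractor": {"feature_size": 80, "sampling_rate": 16000}}
legacy_feature_extractor_config = {"feature_size": 80, "sampling_rate": 16000}

def pick_feature_extractor_dict(processor_dict=None, legacy_dict=None):
    # Mirrors the priority above: nested "feature_extractor"/"audio_processor" entry first,
    # then the legacy flat config as a fallback.
    if processor_dict and ("feature_extractor" in processor_dict or "audio_processor" in processor_dict):
        return processor_dict.get("feature_extractor", processor_dict.get("audio_processor"))
    return legacy_dict

print(pick_feature_extractor_dict(nested_processor_config))                 # nested entry wins
print(pick_feature_extractor_dict(None, legacy_feature_extractor_config))   # falls back to the legacy layout
```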
|
||||
|
||||
if is_local:
|
||||
|
||||
@ -189,7 +189,9 @@ class PagedAttentionCache:
|
||||
num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
|
||||
num_blocks=getattr(generation_config, "num_blocks", None),
|
||||
max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
|
||||
max_memory_percent=getattr(generation_config, "max_memory", 0.9),
|
||||
max_memory_percent=getattr(
|
||||
generation_config, "max_memory", 0.8
|
||||
), # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
|
||||
cache_dtype=self.dtype,
|
||||
)
|
||||
|
||||
@ -414,7 +416,7 @@ class PagedAttentionMemoryHandler:
|
||||
self,
|
||||
num_blocks: Optional[int] = None,
|
||||
max_batch_tokens: Optional[int] = None,
|
||||
max_memory_percent: float = 0.9,
|
||||
max_memory_percent: float = 0.8, # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
|
||||
cache_dtype: torch.dtype = torch.float16,
|
||||
) -> tuple[int, int]:
|
||||
"""Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
|
||||
@ -454,7 +456,7 @@ class PagedAttentionMemoryHandler:
|
||||
|
||||
def compute_num_blocks_and_max_batch_tokens(
|
||||
self,
|
||||
max_memory_percent: float = 0.9,
|
||||
max_memory_percent: float,
|
||||
cache_dtype: torch.dtype = torch.float16,
|
||||
m: float = 0.01,
|
||||
) -> tuple[int, int]:
|
||||
@ -503,7 +505,7 @@ class PagedAttentionMemoryHandler:
|
||||
def compute_max_batch_tokens(
|
||||
self,
|
||||
num_blocks: int,
|
||||
max_memory_percent: float = 0.9,
|
||||
max_memory_percent: float,
|
||||
cache_dtype: torch.dtype = torch.float16,
|
||||
) -> int:
|
||||
"""Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
|
||||
@ -531,7 +533,7 @@ class PagedAttentionMemoryHandler:
|
||||
def compute_num_blocks(
|
||||
self,
|
||||
max_batch_tokens: int,
|
||||
max_memory_percent: float = 0.9,
|
||||
max_memory_percent: float,
|
||||
cache_dtype: torch.dtype = torch.float16,
|
||||
) -> int:
|
||||
"""Calculate number of cache blocks N given a fixed maximum token per token M. The formula for N is given by:
|
||||
|
||||
@ -807,7 +807,7 @@ class ContinuousBatchingManager:
|
||||
"""Check if the background generation thread is running."""
|
||||
return self._generation_thread is not None and self._generation_thread.is_alive()
|
||||
|
||||
def stop(self, block: bool = False, timeout: Optional[float] = None) -> None:
|
||||
def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
|
||||
"""Signal the background thread to stop.
|
||||
|
||||
Args:
|
||||
@ -818,14 +818,17 @@ class ContinuousBatchingManager:
|
||||
logger.warning("Manager not started.")
|
||||
return
|
||||
|
||||
stop_trigger_time = perf_counter()
|
||||
if not self.stop_event.is_set():
|
||||
self.stop_event.set()
|
||||
logger.info("Stopping continuous batching manager...")
|
||||
|
||||
if block:
|
||||
self.join(timeout)
|
||||
self.join(stop_trigger_time, timeout)
|
||||
|
||||
def join(self, timeout: Optional[float] = None) -> None:
|
||||
self.batch_processor = None
|
||||
|
||||
def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
|
||||
"""Wait for the background thread to finish.
|
||||
|
||||
Args:
|
||||
@ -834,9 +837,10 @@ class ContinuousBatchingManager:
|
||||
if self._generation_thread is not None:
|
||||
self._generation_thread.join(timeout=timeout)
|
||||
if self._generation_thread.is_alive():
|
||||
logger.warning("Generation thread did not exit after join timeout.")
|
||||
logger.warning(f"Generation thread did not exit after join timeout ({timeout}).")
|
||||
else:
|
||||
logger.info("Continuous Batching Manager stopped.")
|
||||
end = perf_counter()
|
||||
logger.info(f"Continuous Batching Manager stopped after {end - stop_trigger_time:.2f}s.")
|
||||
self._generation_thread = None
|
||||
|
||||
def add_request(
|
||||
@ -877,9 +881,11 @@ class ContinuousBatchingManager:
|
||||
self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg?
|
||||
return request_id
|
||||
|
||||
def add_requests(self, inputs: list[list[int]], max_new_tokens: Optional[int] = None) -> None:
|
||||
def add_requests(
|
||||
self, inputs: list[list[int]], max_new_tokens: Optional[int] = None, streaming: bool = False
|
||||
) -> None:
|
||||
for input_ids in inputs:
|
||||
self.add_request(input_ids, max_new_tokens=max_new_tokens)
|
||||
self.add_request(input_ids, max_new_tokens=max_new_tokens, streaming=streaming)
|
||||
|
||||
def cancel_request(self, request_id: str) -> None:
|
||||
"""Cancel a request by its ID.
|
||||
@ -890,6 +896,7 @@ class ContinuousBatchingManager:
|
||||
if self.batch_processor is not None:
|
||||
self.batch_processor.scheduler.set_request_cancellation(request_id)
|
||||
|
||||
# TODO:handle benchmarking properly when updating / fixing the requeue logic
|
||||
def get_result(
|
||||
self, request_id: Optional[str] = None, timeout: Optional[float] = None
|
||||
) -> Optional[GenerationOutput]:
|
||||
@ -905,6 +912,7 @@ class ContinuousBatchingManager:
|
||||
return None
|
||||
try:
|
||||
result = self.output_queue.get(block=True, timeout=timeout)
|
||||
# NOTE: requeue logic here
|
||||
if request_id is not None and result.request_id != request_id:
|
||||
self.output_queue.put(result)
|
||||
return None
|
||||
@ -1092,6 +1100,7 @@ class ContinuousMixin:
|
||||
num_kv_cuda_graphs=num_kv_cuda_graphs,
|
||||
)
|
||||
|
||||
# TODO: support streaming
|
||||
@traced
|
||||
@torch.inference_mode()
|
||||
def generate_batch(
|
||||
@ -1148,7 +1157,7 @@ class ContinuousMixin:
|
||||
result = manager.get_result(timeout=1)
|
||||
if result:
|
||||
req_id = result.request_id
|
||||
if result.status == RequestStatus.FINISHED:
|
||||
if result.is_finished():
|
||||
results[req_id] = result
|
||||
finished_count += 1
|
||||
pbar.update(1)
|
||||
|
||||
@ -19,6 +19,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import is_torch_xpu_available
|
||||
from ...utils.logging import logging
|
||||
from ...utils.metrics import traced
|
||||
|
||||
@ -35,6 +36,13 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
|
||||
total_memory = torch.cuda.get_device_properties(device).total_memory
|
||||
reserved_memory = torch.cuda.memory_reserved(device)
|
||||
allocated_memory = torch.cuda.memory_allocated(device)
|
||||
elif is_torch_xpu_available():
|
||||
device = torch.device("xpu")
|
||||
torch.xpu.empty_cache()
|
||||
torch.xpu.synchronize()
|
||||
total_memory = torch.xpu.get_device_properties(device).total_memory
|
||||
reserved_memory = torch.xpu.memory_reserved(device)
|
||||
allocated_memory = torch.xpu.memory_allocated(device)
|
||||
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
device = torch.device("mps")
|
||||
# MPS memory reporting (PyTorch 2.0+)
|
||||
@ -83,6 +91,9 @@ class GenerationOutput:
|
||||
status: RequestStatus = RequestStatus.PENDING
|
||||
created_time: float = field(default_factory=time.time)
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return self.status == RequestStatus.FINISHED
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestState:
|
||||
|
||||
@ -608,7 +608,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
use_cache = kwargs.get("use_cache")
|
||||
if use_cache is None:
|
||||
use_cache = getattr(self.config, "use_cache", False)
|
||||
if past_key_values is None or use_cache:
|
||||
if past_key_values is not None or use_cache:
|
||||
# TODO (joao): handle the case where cache length == input_ids length. The function below results in an
|
||||
# exception because we get empty input_ids after slicing. In essence, we need to roll back the cache 1
|
||||
# token to recompute the logits for the first token to be generated (but not all caches support roll backs)
|
||||
@ -2170,7 +2170,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
return False
|
||||
|
||||
# Base logic
|
||||
valid_hardware = self.device.type == "cuda" or bool(
|
||||
valid_hardware = self.device.type in ["cuda", "xpu"] or bool(
|
||||
generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices
|
||||
)
|
||||
using_compilable_cache = (
|
||||
|
||||
@ -32,6 +32,7 @@ from .utils import (
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
safe_load_json_file,
|
||||
)
|
||||
from .utils.hub import cached_file
|
||||
|
||||
@ -280,35 +281,41 @@ class ImageProcessingMixin(PushToHubMixin):
|
||||
image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
resolved_image_processor_file = pretrained_model_name_or_path
|
||||
resolved_processor_file = None
|
||||
is_local = True
|
||||
elif is_remote_url(pretrained_model_name_or_path):
|
||||
image_processor_file = pretrained_model_name_or_path
|
||||
resolved_processor_file = None
|
||||
resolved_image_processor_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
image_processor_file = image_processor_filename
|
||||
try:
|
||||
# Load from local folder or from cache or download from model Hub and cache
|
||||
resolved_image_processor_files = [
|
||||
resolved_file
|
||||
for filename in [image_processor_file, PROCESSOR_NAME]
|
||||
if (
|
||||
resolved_file := cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
)
|
||||
is not None
|
||||
]
|
||||
resolved_image_processor_file = resolved_image_processor_files[0]
|
||||
resolved_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
resolved_image_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=image_processor_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
except OSError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
|
||||
@ -322,16 +329,24 @@ class ImageProcessingMixin(PushToHubMixin):
|
||||
f" directory containing a {image_processor_filename} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load image_processor dict
|
||||
with open(resolved_image_processor_file, encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
image_processor_dict = json.loads(text)
|
||||
image_processor_dict = image_processor_dict.get("image_processor", image_processor_dict)
|
||||
# Load image_processor dict. Priority goes as (nested config if found -> image processor config)
# We are downloading both configs because almost all models have a `processor_config.json` but
# not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
|
||||
image_processor_dict = None
|
||||
if resolved_processor_file is not None:
|
||||
processor_dict = safe_load_json_file(resolved_processor_file)
|
||||
if "image_processor" in processor_dict:
|
||||
image_processor_dict = processor_dict["image_processor"]
|
||||
|
||||
except json.JSONDecodeError:
|
||||
if resolved_image_processor_file is not None and image_processor_dict is None:
|
||||
image_processor_dict = safe_load_json_file(resolved_image_processor_file)
|
||||
|
||||
if image_processor_dict is None:
|
||||
raise OSError(
|
||||
f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
|
||||
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
||||
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a {image_processor_filename} file"
|
||||
)
|
||||
|
||||
if is_local:
|
||||
|
||||
@ -821,14 +821,26 @@ def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_widt
|
||||
return image
|
||||
|
||||
|
||||
def _cast_tensor_to_float(x):
|
||||
if x.is_floating_point():
|
||||
return x
|
||||
return x.float()
|
||||
|
||||
|
||||
def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
|
||||
"""Helper function to flatten a single level of nested image and batch structures and group by shape."""
|
||||
"""
|
||||
Helper function to flatten a single level of nested image and batch structures and group by shape.
|
||||
Args:
|
||||
nested_images (list):
|
||||
A list of images or a single tensor
|
||||
paired_inputs (Any, *optional*):
|
||||
Zero or more lists that mirror the structure of `nested_images` (flat list, or list of lists when
|
||||
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
|
||||
same shape key. These paired values are grouped alongside `nested_images` but are not stacked in the output, so
|
||||
they do not need to be tensors.
|
||||
is_nested (bool, *optional*, defaults to False):
|
||||
Whether the images are nested.
|
||||
Returns:
|
||||
tuple[dict, ...]:
|
||||
- A dictionary with shape as key and list of images with that shape as value
|
||||
- A dictionary with shape as key and list of paired values with that shape as value
|
||||
- A dictionary mapping original indices to (shape, index) tuples
|
||||
- A dictionary mapping original indices to (shape, index) tuples for each paired input
|
||||
"""
|
||||
grouped_images = defaultdict(list)
|
||||
grouped_images_index = {}
|
||||
paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
|
||||
@ -880,27 +892,20 @@ def _reconstruct_nested_structure(indices, processed_images):
|
||||
return result
|
||||
|
||||
|
||||
def _disable_grouping_output_nested(images, *paired_inputs):
|
||||
"""Build the disable_grouping output tuple for a single-level nested structure."""
|
||||
outer_range = range(len(images))
|
||||
inner_ranges = [range(len(images[i])) for i in outer_range]
|
||||
def _iterate_items(items, is_nested: bool):
|
||||
"""
|
||||
Helper function to iterate over items yielding (key, item) pairs.
|
||||
|
||||
# Precompute all (i, j) pairs
|
||||
ij_pairs = [(i, j) for i in outer_range for j in inner_ranges[i]]
|
||||
|
||||
images_dict = {(i, j): images[i][j].unsqueeze(0) for (i, j) in ij_pairs}
|
||||
paired_dicts = [{(i, j): paired_list[i][j].unsqueeze(0) for (i, j) in ij_pairs} for paired_list in paired_inputs]
|
||||
index_map = {(i, j): ((i, j), 0) for (i, j) in ij_pairs}
|
||||
return images_dict, *paired_dicts, index_map
|
||||
|
||||
|
||||
def _disable_grouping_output_flat(images, *paired_inputs):
|
||||
"""Build the disable_grouping output tuple for a flat list structure."""
|
||||
idx_range = range(len(images))
|
||||
images_dict = {i: images[i].unsqueeze(0) for i in idx_range}
|
||||
paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in idx_range} for paired_list in paired_inputs]
|
||||
index_map = {i: (i, 0) for i in idx_range}
|
||||
return images_dict, *paired_dicts, index_map
|
||||
For nested structures, yields ((row_index, col_index), item).
|
||||
For flat structures, yields (index, item).
|
||||
"""
|
||||
if is_nested:
|
||||
for i, row in enumerate(items):
|
||||
for j, item in enumerate(row):
|
||||
yield (i, j), item
|
||||
else:
|
||||
for i, item in enumerate(items):
|
||||
yield i, item
|
||||
|
||||
|
||||
def group_images_by_shape(
|
||||
@ -920,7 +925,7 @@ def group_images_by_shape(
|
||||
Args:
|
||||
images (Union[list["torch.Tensor"], "torch.Tensor"]):
|
||||
A list of images or a single tensor
|
||||
*paired_inputs (Any):
|
||||
paired_inputs (Any, *optional*):
|
||||
Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
|
||||
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
|
||||
same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
|
||||
@ -944,10 +949,14 @@ def group_images_by_shape(
|
||||
disable_grouping = device == "cpu"
|
||||
|
||||
if disable_grouping:
|
||||
if is_nested:
|
||||
return _disable_grouping_output_nested(images, *paired_inputs)
|
||||
else:
|
||||
return _disable_grouping_output_flat(images, *paired_inputs)
|
||||
return (
|
||||
{key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)},
|
||||
*[
|
||||
{key: item.unsqueeze(0) for key, item in _iterate_items(paired_list, is_nested)}
|
||||
for paired_list in paired_inputs
|
||||
],
|
||||
{key: (key, 0) for key, _ in _iterate_items(images, is_nested)},
|
||||
)
|
||||
|
||||
# Handle single level nested structure
|
||||
grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
|
||||
@ -990,14 +999,3 @@ def reorder_images(
|
||||
]
|
||||
|
||||
return _reconstruct_nested_structure(grouped_images_index, processed_images)
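The grouping and regrouping round-trip implemented in this file can be reduced to a short self-contained sketch (independent of the real `group_images_by_shape` / `reorder_images` signatures): images are bucketed by shape so each bucket can be stacked and processed in one batched call, and an index map records where each original image went so the output order can be restored.

```python
import torch
from collections import defaultdict

images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336), torch.rand(3, 224, 224)]

grouped = defaultdict(list)
index = {}  # original position -> (shape bucket, position inside the bucket)
for i, img in enumerate(images):
    shape = tuple(img.shape[-2:])
    index[i] = (shape, len(grouped[shape]))
    grouped[shape].append(img)

stacked = {shape: torch.stack(imgs) for shape, imgs in grouped.items()}
print({shape: batch.shape for shape, batch in stacked.items()})
# {(224, 224): torch.Size([2, 3, 224, 224]), (336, 336): torch.Size([1, 3, 336, 336])}

# Restore the original ordering from the index map, as reorder_images does
restored = [stacked[shape][j] for shape, j in (index[i] for i in range(len(images)))]
```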
|
||||
|
||||
|
||||
class NumpyToTensor:
|
||||
"""
|
||||
Convert a numpy array to a PyTorch tensor.
|
||||
"""
|
||||
|
||||
def __call__(self, image: np.ndarray):
|
||||
# Same as in PyTorch, we assume incoming numpy images are in HWC format
|
||||
# c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
|
||||
return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()
|
||||
|
||||
@ -11,7 +11,6 @@
|
||||
# specific language governing permissions and limitations under the License.
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -24,13 +23,7 @@ from ..cache_utils import (
|
||||
StaticCache,
|
||||
)
|
||||
from ..generation.configuration_utils import GenerationConfig
|
||||
from ..masking_utils import (
|
||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
||||
_ignore_causal_mask_sdpa,
|
||||
_is_torch_greater_or_equal_than_2_5,
|
||||
prepare_padding_mask,
|
||||
)
|
||||
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from ..pytorch_utils import (
|
||||
is_torch_greater_or_equal,
|
||||
is_torch_greater_or_equal_than_2_3,
|
||||
@ -229,10 +222,6 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
|
||||
)
|
||||
self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
self.model.model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -768,11 +757,6 @@ def convert_and_export_with_cache(
|
||||
|
||||
import torch.export._trace
|
||||
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
with torch.no_grad():
|
||||
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
|
||||
example_input_ids = (
|
||||
@ -1036,11 +1020,6 @@ def export_with_dynamic_cache(
|
||||
if not is_torch_greater_or_equal_than_2_3:
|
||||
raise ImportError("torch >= 2.3 is required.")
|
||||
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
register_dynamic_cache_export_support()
|
||||
|
||||
with torch.no_grad():
|
||||
@ -1109,92 +1088,3 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
|
||||
value = value_list[idx] if idx < len(value_list) else None
|
||||
cache.update(key, value, idx)
|
||||
return cache
|
||||
|
||||
|
||||
def sdpa_mask_without_vmap(
|
||||
batch_size: int,
|
||||
cache_position: torch.Tensor,
|
||||
kv_length: int,
|
||||
kv_offset: int = 0,
|
||||
mask_function: Optional[Callable] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
local_size: Optional[int] = None,
|
||||
allow_is_causal_skip: bool = True,
|
||||
allow_torch_fix: bool = True,
|
||||
**kwargs,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
|
||||
the element should take part in the attention computation, and False that it should not.
|
||||
|
||||
This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
The batch size of the input sequence.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
kv_length (`int`):
|
||||
The size that the key and value states will have during the attention computation.
|
||||
kv_offset (`int`, optional):
|
||||
An optional offset to indicate at which first position the key and values states will refer to.
|
||||
mask_function (`Callable`):
|
||||
The mask factory function describing the mask pattern.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
||||
local_size (`int`, optional):
|
||||
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
|
||||
to try to skip mask creation if possible.
|
||||
allow_is_causal_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
|
||||
`torch.sdpa` instead. Default to `True`.
|
||||
allow_torch_fix (`bool`, optional):
|
||||
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
|
||||
versions. We need an arg to skip it when using eager. By default `True`.
|
||||
|
||||
"""
|
||||
|
||||
q_length = cache_position.shape[0]
|
||||
# Potentially pad the 2D mask, and slice it correctly
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
||||
|
||||
# Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
|
||||
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
|
||||
return None
|
||||
|
||||
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device)
|
||||
kv_arange += kv_offset
|
||||
reshaped_cache_position = cache_position.view(-1, 1)
|
||||
|
||||
# This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
|
||||
# the config through kwargs anyway, so it allows to rely on it
|
||||
# Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
|
||||
# but this is more efficient
|
||||
sliding_window = getattr(kwargs["config"], "sliding_window", None)
|
||||
chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
|
||||
|
||||
if sliding_window is not None and chunk_size is not None:
|
||||
raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
|
||||
|
||||
# Simplest and most efficient way to obtain a causal mask
|
||||
causal_mask = kv_arange <= reshaped_cache_position
|
||||
# If using sliding window, add the sliding mask
|
||||
if sliding_window is not None:
|
||||
sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
|
||||
causal_mask *= sliding_mask_overlay
|
||||
# If using chunk attention, add the chunked mask
|
||||
elif chunk_size is not None:
|
||||
chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
|
||||
causal_mask *= chunked_mask_overlay
|
||||
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
|
||||
if padding_mask is not None:
|
||||
causal_mask = causal_mask * padding_mask[:, None, None, :]
|
||||
|
||||
# Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
|
||||
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
|
||||
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
|
||||
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
|
||||
return causal_mask
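A tiny worked example of the sliding-window overlay used above, assuming a window of 2 over a 4-token prefill (only the broadcasting arithmetic is reproduced here, not the full mask function):

```python
import torch

cache_position = torch.arange(4)   # query positions 0..3
kv_arange = torch.arange(4)        # key positions 0..3
q = cache_position.view(-1, 1)

causal = kv_arange <= q            # standard lower-triangular mask
sliding = kv_arange > q - 2        # keep only the last 2 key positions per query
mask = causal & sliding
print(mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 0, 1, 1]])
```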
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
@ -18,7 +19,7 @@ from types import ModuleType
|
||||
from typing import Optional, Union
|
||||
|
||||
from ..modeling_flash_attention_utils import lazy_import_flash_attention
|
||||
from ..utils import logging
|
||||
from ..utils import ENV_VARS_TRUE_VALUES, logging
|
||||
from ..utils.import_utils import is_kernels_available
|
||||
from .flash_attention import flash_attention_forward
|
||||
|
||||
@ -33,10 +34,22 @@ try:
|
||||
get_kernel,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
use_kernel_forward_from_hub,
|
||||
)
|
||||
|
||||
_TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
|
||||
_kernels_available = True
|
||||
_kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES
|
||||
|
||||
def use_kernel_forward_from_hub(layer_name: str):
|
||||
if _kernels_enabled:
|
||||
from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub
|
||||
|
||||
return _kernels_use_kernel_forward_from_hub(layer_name)
|
||||
else:
|
||||
logger.warning_once(
|
||||
f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
|
||||
)
|
||||
return lambda cls: cls
|
||||
|
||||
_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
|
||||
"MultiScaleDeformableAttention": {
|
||||
@ -71,6 +84,12 @@ try:
|
||||
layer_name="RMSNorm",
|
||||
)
|
||||
},
|
||||
"npu": {
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="kernels-community/liger_kernels",
|
||||
layer_name="LigerRMSNorm",
|
||||
)
|
||||
},
|
||||
},
|
||||
"MLP": {
|
||||
"cuda": LayerRepository(
|
||||
@ -161,6 +180,7 @@ try:
|
||||
|
||||
except ImportError:
|
||||
_kernels_available = False
|
||||
_kernels_enabled = False
|
||||
|
||||
# Stub to make decorators in transformers work when `kernels`
|
||||
# is not installed.
|
||||
|
||||
@ -38,7 +38,7 @@ from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
|
||||
if os.getenv("WANDB_MODE") == "offline":
|
||||
print("⚙️ Running in WANDB offline mode")
|
||||
print("[INFO] Running in WANDB offline mode")
|
||||
|
||||
from .. import PreTrainedModel, TrainingArguments
|
||||
from .. import __version__ as version
|
||||
|
||||
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -114,6 +114,9 @@ def convert_moe_packed_tensors(
|
||||
if not blocks.is_cuda and torch.cuda.is_available():
|
||||
blocks = blocks.cuda()
|
||||
scales = scales.cuda()
|
||||
elif (blocks.device.type != "xpu") and is_torch_xpu_available():
|
||||
blocks = blocks.to("xpu")
|
||||
scales = scales.to("xpu")
|
||||
|
||||
scales = scales.to(torch.int32) - 127 # TODO that's because 128=2**7
|
||||
|
||||
@ -351,6 +354,8 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
|
||||
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
|
||||
if target_device == "cpu" and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif target_device == "cpu" and is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
|
||||
delattr(module, blocks_attr)
|
||||
delattr(module, scales_attr)
|
||||
@ -395,7 +400,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito
|
||||
else:
|
||||
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
|
||||
if getattr(target_device, "type", target_device) == "cpu":
|
||||
target_device = "cuda"
|
||||
target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
||||
blocks = blocks.to(target_device).contiguous()
|
||||
scales = scales.to(target_device).contiguous()
|
||||
with on_device(target_device):
|
||||
|
||||
@ -63,9 +63,6 @@ def sdpa_attention_forward(
|
||||
else:
|
||||
sdpa_kwargs = {"enable_gqa": True}
|
||||
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
# Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
|
||||
is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
|
||||
|
||||
|
||||
@ -82,8 +82,10 @@ def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int)
|
||||
def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
"""
|
||||
This creates a full bidirectional mask.
|
||||
|
||||
NOTE: It is important to keep an index-based version for non-vmap expansion.
|
||||
"""
|
||||
return q_idx.new_ones((), dtype=torch.bool)
|
||||
return q_idx >= 0
|
||||
|
||||
|
||||
def sliding_window_overlay(sliding_window: int) -> Callable:
|
||||
@ -110,18 +112,6 @@ def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
|
||||
return inner_mask
|
||||
|
||||
|
||||
def _legacy_chunked_overlay(chunk_size: int) -> Callable:
|
||||
"""
|
||||
Same as the above function, but do not correctly account for left padding tokens.
|
||||
Only kept for compatibility with older torch versions (< 2.6).
|
||||
"""
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
return kv_idx // chunk_size == q_idx // chunk_size
|
||||
|
||||
return inner_mask
|
||||
|
||||
|
||||
def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
|
||||
"""
|
||||
This returns the mask_function used to create a sliding window mask.
|
||||
@ -133,8 +123,6 @@ def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) ->
|
||||
"""
|
||||
This returns the mask_function used to create a chunked attention mask.
|
||||
"""
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
return and_masks(_legacy_chunked_overlay(chunk_size), causal_mask_function)
|
||||
return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function)
|
||||
|
||||
|
||||
@ -175,55 +163,56 @@ def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offs
|
||||
return inner_mask
|
||||
|
||||
|
||||
def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
|
||||
"""
|
||||
Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
|
||||
the batch and head indices as well if `bh_indices=True`.
|
||||
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
|
||||
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
|
||||
|
||||
Args:
|
||||
mask_function (`Callable`):
|
||||
The mask_function to vmap.
|
||||
bh_indices (`bool`, optional):
|
||||
Whether to vmap over the batch and head indices as well, or only q and kv indices.
|
||||
|
||||
Returns:
|
||||
Callable: The vmapped function.
|
||||
"""
|
||||
# We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions
|
||||
dimensions = [(None, None, None, 0), (None, None, 0, None)]
|
||||
if bh_indices:
|
||||
# We extend broadcasting over the [batch_idx, head_idx] dimensions
|
||||
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
|
||||
|
||||
for dims in dimensions:
|
||||
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
|
||||
return mask_function
|
||||
|
||||
|
||||
def prepare_padding_mask(
|
||||
attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
|
||||
attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it, and slicing
|
||||
according to the `kv_offset` if `_slice` is `True`.
|
||||
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.
|
||||
"""
|
||||
local_padding_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
# Pad it if necessary
|
||||
if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
|
||||
local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
|
||||
# For flex, we should not slice them, only use an offset
|
||||
if _slice:
|
||||
# Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indices = torch.arange(kv_length, device=local_padding_mask.device)
|
||||
mask_indices += kv_offset
|
||||
local_padding_mask = local_padding_mask[:, mask_indices]
|
||||
return local_padding_mask
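A short numeric illustration of the padding step above (the only transformation left in the simplified function): when the key/value span `kv_length + kv_offset` is wider than the stored 2D mask, the mask is right-padded with zeros (masked positions).

```python
import torch

attention_mask = torch.ones(1, 5, dtype=torch.bool)  # 5 previously seen tokens, no padding
kv_length, kv_offset = 4, 3                           # keys cover positions 3..6, so width 7 is needed

padding_length = kv_length + kv_offset - attention_mask.shape[-1]  # 2 missing columns
if padding_length > 0:
    attention_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
print(attention_mask.shape)  # torch.Size([1, 7]); the new columns are False (masked)
```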
|
||||
|
||||
|
||||
def _can_skip_causal_mask_xpu(
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
query_length: int,
|
||||
kv_length: int,
|
||||
local_attention_size: Optional[int],
|
||||
) -> bool:
|
||||
"""
|
||||
XPU-specific logic for determining if we can skip causal mask creation.
|
||||
|
||||
For XPU devices, we have special handling:
|
||||
- Single query tokens (query_length == 1) use the same logic as CUDA
|
||||
- Multi-query tokens can skip if padding_mask is provided and correctly structured
|
||||
The mask must have all True values in the query window and all False after
|
||||
"""
|
||||
|
||||
if is_tracing(padding_mask):
|
||||
return False
|
||||
|
||||
# Check local attention constraint (same as CUDA)
|
||||
if local_attention_size is not None and kv_length >= local_attention_size:
|
||||
return False
|
||||
|
||||
if padding_mask is None:
|
||||
# Without padding mask, can skip if single query token or full causal attention
|
||||
return query_length == 1 or kv_length == query_length
|
||||
|
||||
# XPU allows skipping under additional conditions when padding_mask is provided
|
||||
if query_length == 1:
|
||||
# Single query token: skip only if no padding tokens present
|
||||
return padding_mask.all()
|
||||
|
||||
# XPU-specific: check if query window is all True and rest is all False
|
||||
# This allows XPU to optimize the 1st token in static cache
|
||||
return padding_mask[:, :query_length].all() and not padding_mask[:, query_length:].any()
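A quick check of the multi-query XPU skip condition described above: every key inside the query window is attendable and everything after it is masked, so the materialized mask would carry no information beyond `is_causal`.

```python
import torch

padding_mask = torch.tensor([[True, True, True, False, False]])
query_length = 3

# Same expression as the return statement above
can_skip = bool(padding_mask[:, :query_length].all() and not padding_mask[:, query_length:].any())
print(can_skip)  # True
```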
|
||||
|
||||
|
||||
def _ignore_causal_mask_sdpa(
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
query_length: int,
|
||||
@ -244,6 +233,12 @@ def _ignore_causal_mask_sdpa(
|
||||
mask_indices += kv_offset
|
||||
padding_mask = padding_mask[:, mask_indices]
|
||||
|
||||
if _is_torch_xpu_available:
|
||||
# XPU devices have special handling for mask skipping:
|
||||
# - Single query tokens use the same logic as CUDA
|
||||
# - Multi-query tokens can skip if padding_mask is provided and correctly structured
|
||||
# (all True in query window, all False after)
|
||||
return _can_skip_causal_mask_xpu(padding_mask, query_length, kv_length, local_attention_size)
|
||||
# When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
|
||||
# hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
|
||||
# which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
|
||||
@ -251,18 +246,11 @@ def _ignore_causal_mask_sdpa(
|
||||
if (
|
||||
not is_tracing(padding_mask)
|
||||
# only cases when lower and upper diags are the same, see https://github.com/pytorch/pytorch/issues/108108
|
||||
and (query_length == 1 or (kv_length == query_length or _is_torch_xpu_available))
|
||||
and (query_length == 1 or kv_length == query_length)
|
||||
# in this case we need to add special patterns to the mask so cannot be skipped otherwise
|
||||
and (local_attention_size is None or kv_length < local_attention_size)
|
||||
# In this case, we need to add padding to the mask, so cannot be skipped otherwise
|
||||
and (
|
||||
padding_mask is None
|
||||
or (
|
||||
padding_mask.all()
|
||||
if not _is_torch_xpu_available or query_length == 1
|
||||
else padding_mask[:, :query_length].all()
|
||||
)
|
||||
)
|
||||
and (padding_mask is None or padding_mask.all())
|
||||
):
|
||||
return True
|
||||
|
||||
@ -282,7 +270,39 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> boo
|
||||
return False
|
||||
|
||||
|
||||
def sdpa_mask_recent_torch(
|
||||
def _vmap_expansion_sdpa(mask_function: Callable) -> Callable:
|
||||
"""
|
||||
Used to vmap our mask_functions over the all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
|
||||
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
|
||||
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
|
||||
"""
|
||||
# We vmap the function over all 4 dimensions, broadcasting [b_idx, h_idx, q_idx, kv_idx]
|
||||
dimensions = [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]
|
||||
for dims in dimensions:
|
||||
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
|
||||
return mask_function
|
||||
|
||||
|
||||
def _non_vmap_expansion_sdpa(
|
||||
batch_indices: torch.Tensor, head_indices: torch.Tensor, q_indices: torch.Tensor, kv_indices: torch.Tensor
|
||||
):
|
||||
"""
|
||||
Used to broadcast our mask_functions over the all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
|
||||
Allows the usage of any index-based mask function without relying on vmap.
|
||||
|
||||
NOTE: This is limited to index based functions only and is not guaranteed to work otherwise.
|
||||
|
||||
Reference:
|
||||
- https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
|
||||
"""
|
||||
batch_indices = batch_indices[:, None, None, None]
|
||||
head_indices = head_indices[None, :, None, None]
|
||||
q_indices = q_indices[None, None, :, None]
|
||||
kv_indices = kv_indices[None, None, None, :]
|
||||
return batch_indices, head_indices, q_indices, kv_indices
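To make the broadcasting trick concrete, here is a small standalone sketch (re-declaring a causal rule locally rather than importing anything from the library) that builds a `(batch, heads, q_len, kv_len)` boolean mask purely through index broadcasting, followed by the same `expand` step the new `sdpa_mask` applies:

```python
import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    # same rule as causal_mask_function above: a query attends to keys at or before its position
    return kv_idx <= q_idx

batch, heads, q_len, kv_len = 2, 1, 3, 5
b, h = torch.arange(batch), torch.arange(heads)
q, kv = torch.arange(q_len), torch.arange(kv_len)

# reshape to (B,1,1,1), (1,H,1,1), (1,1,Q,1), (1,1,1,KV) as in _non_vmap_expansion_sdpa
mask = causal(b[:, None, None, None], h[None, :, None, None], q[None, None, :, None], kv[None, None, None, :])
# the index-only causal rule ignores batch/head, so expand to the full shape afterwards,
# mirroring the `.expand(batch_size, -1, q_length, kv_length)` call in the new sdpa_mask
mask = mask.expand(batch, -1, q_len, kv_len)
print(mask.shape)  # torch.Size([2, 1, 3, 5])
```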
|
||||
|
||||
|
||||
def sdpa_mask(
|
||||
batch_size: int,
|
||||
cache_position: torch.Tensor,
|
||||
kv_length: int,
|
||||
@ -292,6 +312,8 @@ def sdpa_mask_recent_torch(
|
||||
local_size: Optional[int] = None,
|
||||
allow_is_causal_skip: bool = True,
|
||||
allow_is_bidirectional_skip: bool = False,
|
||||
allow_torch_fix: bool = True,
|
||||
use_vmap: bool = False,
|
||||
**kwargs,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
@ -324,6 +346,12 @@ def sdpa_mask_recent_torch(
|
||||
allow_is_bidirectional_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
|
||||
i.e. full attention without any padding. Default to `False`.
|
||||
allow_torch_fix (`bool`, optional):
|
||||
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
|
||||
versions. We need an arg to skip it when using eager. By default `True`.
|
||||
use_vmap (`bool`, optional):
|
||||
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
|
||||
index-based (for the cost of speed performance). By default `False`.
|
||||
|
||||
|
||||
## Creating a simple causal mask:
|
||||
@ -391,97 +419,8 @@ def sdpa_mask_recent_torch(
|
||||
|
||||
"""
|
||||
q_length = cache_position.shape[0]
|
||||
# Potentially pad the 2D mask, and slice it correctly
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
|
||||
|
||||
# Under specific conditions, we can avoid materializing the mask
|
||||
# 1. Causal masks can rely on the `is_causal` argument
|
||||
# 2. Bidirectional do not need any further processing (no bias)
|
||||
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
|
||||
return None
|
||||
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
|
||||
return None
|
||||
|
||||
# vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
|
||||
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
|
||||
if mask_function is bidirectional_mask_function:
|
||||
if padding_mask is not None:
|
||||
# used for slicing without data-dependent slicing
|
||||
mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset
|
||||
return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
|
||||
else:
|
||||
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)
|
||||
|
||||
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device)
|
||||
kv_arange += kv_offset
|
||||
|
||||
# Potentially add the padding 2D mask
|
||||
if padding_mask is not None:
|
||||
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
|
||||
|
||||
batch_arange = torch.arange(batch_size, device=cache_position.device)
|
||||
head_arange = torch.arange(1, device=cache_position.device)
|
||||
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
# a scalar tensor (it internally calls `.item()`, which vmap does not allow, but this context works around it).
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for q and kv indices
|
||||
with TransformGetItemToIndex():
|
||||
causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
def sdpa_mask_older_torch(
|
||||
batch_size: int,
|
||||
cache_position: torch.Tensor,
|
||||
kv_length: int,
|
||||
kv_offset: int = 0,
|
||||
mask_function: Callable = causal_mask_function,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
local_size: Optional[int] = None,
|
||||
allow_is_causal_skip: bool = True,
|
||||
allow_torch_fix: bool = True,
|
||||
allow_is_bidirectional_skip: bool = False,
|
||||
**kwargs,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: This function is only used when torch version is torch<2.5 - see `sdpa_mask_recent_torch` otherwise.
|
||||
|
||||
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
|
||||
the element should take part in the attention computation, and False that it should not.
|
||||
If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend
|
||||
to any other tokens (due to padding) will be fully attended to instead, in order to avoid `nan` propagation (this does
|
||||
not change the final result).
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
The batch size of the input sequence.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
kv_length (`int`):
|
||||
The size that the key and value states will have during the attention computation.
|
||||
kv_offset (`int`, optional):
|
||||
An optional offset to indicate at which first position the key and values states will refer to.
|
||||
mask_function (`Callable`):
|
||||
The mask factory function describing the mask pattern.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
||||
local_size (`int`, optional):
|
||||
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
|
||||
to try to skip mask creation if possible.
|
||||
allow_is_causal_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
|
||||
`torch.sdpa` instead. Default to `True`.
|
||||
allow_torch_fix (`bool`, optional):
|
||||
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
|
||||
versions. We need an arg to skip it when using eager. By default `True`.
|
||||
allow_is_bidirectional_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
|
||||
i.e. full attention without any padding. Default to `False`.
|
||||
"""
|
||||
q_length = cache_position.shape[0]
|
||||
# Potentially pad the 2D mask, and slice it correctly
|
||||
# Potentially pad the 2D mask
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
||||
|
||||
# Under specific conditions, we can avoid materializing the mask
|
||||
@ -492,38 +431,45 @@ def sdpa_mask_older_torch(
|
||||
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
|
||||
return None
|
||||
|
||||
# vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
|
||||
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
|
||||
if mask_function is bidirectional_mask_function:
|
||||
if padding_mask is not None:
|
||||
return padding_mask[:, None, None, :].expand(-1, -1, q_length, -1)
|
||||
else:
|
||||
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)
|
||||
# Potentially add the padding 2D mask
|
||||
if padding_mask is not None:
|
||||
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
|
||||
|
||||
batch_arange = torch.arange(batch_size, device=cache_position.device)
|
||||
head_arange = torch.arange(1, device=cache_position.device)
|
||||
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device)
|
||||
kv_arange += kv_offset
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
|
||||
|
||||
# This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well,
# as vmap cannot handle slicing a tensor from a scalar tensor (it internally calls `.item()`, which vmap does not allow).
# However, in more recent versions of PyTorch, a trick was introduced to handle it - which is the reason we have
# `sdpa_mask_recent_torch`, as it allows a more general `mask_function`
|
||||
causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
|
||||
if padding_mask is not None:
|
||||
causal_mask = causal_mask * padding_mask[:, None, None, :]
|
||||
# Actual mask creation
|
||||
# Option 1: Fast non-vmap mask creation (default)
|
||||
if not use_vmap:
|
||||
# Apply mask function element-wise through broadcasting
|
||||
attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange))
|
||||
# Expand the mask to match batch size and query length if they weren't used in the mask function
|
||||
attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
|
||||
|
||||
# Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
|
||||
elif _is_torch_greater_or_equal_than_2_6:
|
||||
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
# a scalar tensor (it internally calls `.item()`, which vmap does not allow, but this context works around it).
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for q and kv indices
|
||||
with TransformGetItemToIndex():
|
||||
attention_mask = _vmap_expansion_sdpa(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
|
||||
|
||||
# Option 3: Error out since it indicates that the user did something custom, which they shouldn't have (torch<2.6)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The vmap functionality for mask creation is only supported from torch>=2.6. "
|
||||
"Please update your torch version or use `use_vmap=False` with index-based masks."
|
||||
)
|
||||
|
||||
# Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
|
||||
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
|
||||
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
|
||||
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
|
||||
return causal_mask
|
||||
attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
|
||||
|
||||
|
||||
# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
|
||||
# (especially mask_function indexing a tensor, such as the padding mask function)
|
||||
sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch
|
||||
return attention_mask
|
||||
|
||||
|
||||
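To make the "index-based" requirement above concrete, here is a minimal sketch (assumed, not the library code) of how a purely index-based mask function can be expanded through plain broadcasting instead of vmap; all names and shapes are illustrative.

```python
import torch

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    # Purely index-based predicate: safe to evaluate through broadcasting
    return kv_idx <= q_idx

batch_size, q_length, kv_length = 2, 4, 6
cache_position = torch.arange(q_length)
kv_arange = torch.arange(kv_length)

# Expand the index grids to (batch, heads, q_len, kv_len) via broadcasting, no vmap needed
mask = causal_mask_function(
    torch.arange(batch_size)[:, None, None, None],  # (B, 1, 1, 1)
    torch.arange(1)[None, :, None, None],            # (1, H, 1, 1)
    cache_position[None, None, :, None],             # (1, 1, Q, 1)
    kv_arange[None, None, None, :],                   # (1, 1, 1, K)
)
# The causal predicate only touches the q/kv indices, so the result is (1, 1, Q, K);
# expanding to the full batch is then a free view, as in the code above.
mask = mask.expand(batch_size, -1, q_length, kv_length)
print(mask.shape)  # torch.Size([2, 1, 4, 6])
```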
def eager_mask(
|
||||
@@ -534,6 +480,7 @@ def eager_mask(
|
||||
mask_function: Callable = causal_mask_function,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
use_vmap: bool = False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -556,10 +503,14 @@ def eager_mask(
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
||||
dtype (`torch.dtype`, optional):
|
||||
The dtype to use for the mask. By default, `torch.float32`.
|
||||
use_vmap (`bool`, optional):
|
||||
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
index-based (at the cost of speed). By default `False`.
|
||||
"""
|
||||
# The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
|
||||
_ = kwargs.pop("allow_is_causal_skip", None)
|
||||
_ = kwargs.pop("allow_is_bidirectional_skip", None)
|
||||
_ = kwargs.pop("allow_torch_fix", None)
|
||||
mask = sdpa_mask(
|
||||
batch_size=batch_size,
|
||||
cache_position=cache_position,
|
||||
@@ -570,6 +521,7 @@ def eager_mask(
|
||||
allow_is_causal_skip=False,
|
||||
allow_is_bidirectional_skip=False,
|
||||
allow_torch_fix=False,
|
||||
use_vmap=use_vmap,
|
||||
**kwargs,
|
||||
)
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
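As a small hedged illustration of what the eager mask described above amounts to: the boolean SDPA-style mask becomes an additive bias where allowed positions contribute 0 and masked positions the minimum value of the requested dtype.

```python
import torch

bool_mask = torch.tensor([[[[True, True, False]]]])  # True = attend, False = masked
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Allowed positions add 0.0 to the attention scores, masked ones a large negative bias
eager_bias = torch.where(bool_mask, torch.tensor(0.0, dtype=dtype), torch.tensor(min_dtype, dtype=dtype))
print(eager_bias)
```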
@@ -655,7 +607,7 @@ def flex_attention_mask(
|
||||
if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
|
||||
attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
|
||||
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
||||
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
|
||||
|
||||
# Add the offsets on top (because flex interface only allows length, not start and end indices)
|
||||
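A hedged sketch of the offset trick referred to in the comment above (the wrapper name is illustrative): since the flex interface only deals in lengths starting at 0, the q/kv offsets are folded into the mask function itself.

```python
def with_offsets(mask_function, q_offset, kv_offset):
    # Shift the indices seen by the wrapped mask so flex can keep 0-based index ranges
    def offset_mask(batch_idx, head_idx, q_idx, kv_idx):
        return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset)
    return offset_mask
```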
@@ -781,9 +733,19 @@ def _preprocess_mask_arguments(
|
||||
# If using a cache, it can give all information about mask sizes based on seen tokens
|
||||
if past_key_values is not None:
|
||||
kv_length, kv_offset = past_key_values.get_mask_sizes(cache_position, layer_idx)
|
||||
# Otherwise, the sizes are simply the input sizes
|
||||
# Otherwise, we infer based on our input
|
||||
else:
|
||||
kv_length, kv_offset = input_embeds.shape[1], 0
|
||||
# 1. Rely on input directly
|
||||
if attention_mask is None:
|
||||
kv_length, kv_offset = input_embeds.shape[1], 0
|
||||
# 2. Rely on the mask instead - needed for special cases like prefix tuning in PEFT
|
||||
#
|
||||
# This is a very unique and special case where an encoder utilizes a cache and expects its length
|
||||
# to be accounted for (usually, they should never use a cache). In general, the mask should always
|
||||
# match with the input sizes nonetheless (i.e. it does not affect others).
|
||||
# Conclusion: "prefix tuning is evil"
|
||||
else:
|
||||
kv_length, kv_offset = attention_mask.shape[-1], 0
|
||||
|
||||
# We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
|
||||
# and we don't have past_key_values, i.e. generally a training setup)
|
||||
@@ -851,6 +813,11 @@ def create_causal_mask(
|
||||
mask_factory_function = causal_mask_function
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Defaulting to using non-vmap based mask creations except when detecting
|
||||
# users passing custom mask functions (as we cannot guarantee that they
|
||||
# are properly index-based as required by our implementation).
|
||||
use_vmap = False
|
||||
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
||||
if _is_torch_xpu_available:
|
||||
@@ -867,14 +834,16 @@
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
|
||||
# If we detected packing format
|
||||
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
|
||||
if packed_sequence_mask is not None:
|
||||
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
||||
allow_is_causal_skip = False
|
||||
|
||||
@@ -889,6 +858,7 @@
|
||||
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
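Below is a hedged usage sketch of the `or_mask_function` path handled above; `model`, `inputs_embeds` and the other inputs are assumptions, and passing a custom function disables the skip and forces the vmap path (hence the torch>=2.6 requirement).

```python
from transformers.masking_utils import create_causal_mask

def sink_tokens_mask(batch_idx, head_idx, q_idx, kv_idx):
    # Illustrative pattern: always allow attending to the first 4 "sink" tokens
    return kv_idx < 4

causal_mask = create_causal_mask(
    config=model.config,                # assumed: config of a loaded decoder model
    input_embeds=inputs_embeds,         # assumed: (batch, seq_len, hidden) tensor
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=past_key_values,
    or_mask_function=sink_tokens_mask,  # OR-ed with the causal pattern
)
```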
@@ -942,6 +912,10 @@ def create_bidirectional_mask(
|
||||
|
||||
# Allow skipping the mask creation except we have additional masking operators (and/or masks)
|
||||
allow_is_bidirectional_skip = True
|
||||
# Defaulting to using non-vmap based mask creations except when detecting
|
||||
# users passing custom mask functions (as we cannot guarantee that they
|
||||
# are properly index-based as required by our implementation).
|
||||
use_vmap = False
|
||||
|
||||
# Allow slight deviations from the base mask
|
||||
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
|
||||
@@ -951,11 +925,13 @@
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_bidirectional_skip = False
|
||||
use_vmap = True
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_bidirectional_skip = False
|
||||
use_vmap = True
|
||||
|
||||
# We now create the mask
|
||||
attention_mask = mask_interface(
|
||||
@@ -970,6 +946,7 @@
|
||||
allow_is_bidirectional_skip=allow_is_bidirectional_skip,
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
||||
)
|
||||
return attention_mask
|
||||
|
||||
@@ -1032,6 +1009,10 @@ def create_sliding_window_causal_mask(
|
||||
mask_factory_function = sliding_window_causal_mask_function(sliding_window)
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Defaulting to using non-vmap based mask creations except when detecting
|
||||
# users passing custom mask functions (as we cannot guarantee that they
|
||||
# are properly index-based as required by our implementation).
|
||||
use_vmap = False
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
||||
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
|
||||
@@ -1044,14 +1025,16 @@
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
|
||||
# If we detected packing format
|
||||
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
|
||||
if packed_sequence_mask is not None:
|
||||
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
||||
allow_is_causal_skip = False
|
||||
|
||||
@@ -1067,6 +1050,7 @@
|
||||
local_size=sliding_window, # Additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
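For reference, a minimal hedged sketch of the predicate this sliding-window variant builds on: causal attention restricted to the last `sliding_window` positions.

```python
def sliding_window_causal(batch_idx, head_idx, q_idx, kv_idx, sliding_window=4):
    # Causal AND within the window of the last `sliding_window` positions
    return (kv_idx <= q_idx) & (q_idx - kv_idx < sliding_window)
```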
@@ -1140,20 +1124,13 @@ def create_chunked_causal_mask(
|
||||
left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
|
||||
else:
|
||||
left_padding_tokens = torch.zeros(batch_size, device=cache_position.device, dtype=int)
|
||||
# Raise a warning for older versions if the problematic left-padding situation arises
|
||||
if (
|
||||
not _is_torch_greater_or_equal_than_2_6
|
||||
and kv_length + kv_offset > chunk_size
|
||||
and (left_padding_tokens > 0).any()
|
||||
):
|
||||
logger.warning_once(
|
||||
"Due to limitations of your current torch version, we cannot correctly account for the left-padding "
|
||||
"when computing the chunked attention pattern. This will lead to a wrong attention mask for the padded "
|
||||
"sequences. Behavior will be undefined. Please upgrade to `torch>=2.6` to solve this issue."
|
||||
)
|
||||
mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Defaulting to using non-vmap based mask creations except when detecting
|
||||
# users passing custom mask functions (as we cannot guarantee that they
|
||||
# are properly index-based as required by our implementation).
|
||||
use_vmap = False
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
||||
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
|
||||
@@ -1166,14 +1143,16 @@
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
|
||||
# If we detected packing format
|
||||
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
|
||||
if packed_sequence_mask is not None:
|
||||
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
||||
allow_is_causal_skip = False
|
||||
|
||||
@@ -1189,6 +1168,7 @@
|
||||
local_size=chunk_size, # Additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
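And a similar hedged sketch for the chunked pattern above: queries and keys must fall into the same chunk (shifted by the per-sequence left padding, which is why `left_padding_tokens` is computed), on top of causality.

```python
def chunked_causal(batch_idx, head_idx, q_idx, kv_idx, chunk_size=4, left_pad=0):
    # Same chunk (after removing left padding) AND causal
    same_chunk = (q_idx - left_pad) // chunk_size == (kv_idx - left_pad) // chunk_size
    return same_chunk & (kv_idx <= q_idx)
```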
@@ -5113,203 +5113,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
be initialized correctly (i.e. weight initialization distribution).
|
||||
Also take care of setting the `_is_hf_initialized` flag for keys that are not missing.
|
||||
"""
|
||||
missing_keys_set = set(missing_keys)
|
||||
|
||||
model_state_dict_keys = set(self.state_dict().keys())
|
||||
|
||||
if missing_keys_set and missing_keys_set >= model_state_dict_keys:
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
params = list(self.state_dict(keep_vars=True).values())
|
||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||
self.initialize_weights()
|
||||
else:
|
||||
self.initialize_weights()
|
||||
return
|
||||
|
||||
for key in model_state_dict_keys:
|
||||
for key in self.state_dict():
|
||||
# If it's part of the keys that will be loaded, mark it as already initialized
|
||||
if key not in missing_keys_set:
|
||||
if key not in missing_keys:
|
||||
param_or_buffer = self.get_parameter_or_buffer(key)
|
||||
param_or_buffer._is_hf_initialized = True
|
||||
|
||||
handled_missing_keys: set[str] = set()
|
||||
|
||||
if missing_keys_set and not is_quantized:
|
||||
missing_params_by_module: defaultdict[str, set[str]] = defaultdict(set)
|
||||
missing_buffers_by_module: defaultdict[str, set[str]] = defaultdict(set)
|
||||
|
||||
for key in missing_keys_set:
|
||||
if "." not in key:
|
||||
continue
|
||||
module_path, name = key.rsplit(".", 1)
|
||||
if module_path == "":
|
||||
continue
|
||||
try:
|
||||
module = self.get_submodule(module_path)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
parameters = getattr(module, "_parameters", {})
|
||||
if name in parameters and parameters[name] is not None:
|
||||
missing_params_by_module[module_path].add(name)
|
||||
continue
|
||||
|
||||
buffers = getattr(module, "_buffers", {})
|
||||
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
|
||||
if name in buffers and buffers[name] is not None and name not in non_persistent:
|
||||
missing_buffers_by_module[module_path].add(name)
|
||||
|
||||
# Sort by depth (deepest first) so child modules are handled before their parents.
|
||||
module_paths = sorted(
|
||||
set(missing_params_by_module.keys()) | set(missing_buffers_by_module.keys()),
|
||||
key=lambda name: name.count("."),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
modules_info: list[tuple[str, nn.Module, set[str], set[str]]] = []
|
||||
for module_path in module_paths:
|
||||
try:
|
||||
module = self.get_submodule(module_path)
|
||||
except AttributeError:
|
||||
continue
|
||||
modules_info.append(
|
||||
(
|
||||
module_path,
|
||||
module,
|
||||
missing_params_by_module.get(module_path, set()),
|
||||
missing_buffers_by_module.get(module_path, set()),
|
||||
)
|
||||
)
|
||||
|
||||
if modules_info:
|
||||
|
||||
def _initialize_modules():
|
||||
with torch.no_grad():
|
||||
for module_path, module, module_missing_params, module_missing_buffers in modules_info:
|
||||
immediate_params = dict(module.named_parameters(recurse=False))
|
||||
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
|
||||
persistent_buffers = {
|
||||
name: buffer
|
||||
for name, buffer in module.named_buffers(recurse=False)
|
||||
if name not in non_persistent
|
||||
}
|
||||
|
||||
already_initialized_params = {
|
||||
name
|
||||
for name in module_missing_params
|
||||
if name in immediate_params
|
||||
and getattr(immediate_params[name], "_is_hf_initialized", False)
|
||||
}
|
||||
already_initialized_buffers = {
|
||||
name
|
||||
for name in module_missing_buffers
|
||||
if name in persistent_buffers
|
||||
and getattr(persistent_buffers[name], "_is_hf_initialized", False)
|
||||
}
|
||||
if already_initialized_params or already_initialized_buffers:
|
||||
handled_missing_keys.update(
|
||||
{f"{module_path}.{name}" for name in already_initialized_params}
|
||||
)
|
||||
handled_missing_keys.update(
|
||||
{f"{module_path}.{name}" for name in already_initialized_buffers}
|
||||
)
|
||||
|
||||
missing_params = {
|
||||
name
|
||||
for name in module_missing_params
|
||||
if name in immediate_params
|
||||
and not getattr(immediate_params[name], "_is_hf_initialized", False)
|
||||
}
|
||||
missing_buffers = {
|
||||
name
|
||||
for name in module_missing_buffers
|
||||
if name in persistent_buffers
|
||||
and not getattr(persistent_buffers[name], "_is_hf_initialized", False)
|
||||
}
|
||||
|
||||
if not missing_params and not missing_buffers:
|
||||
continue
|
||||
|
||||
all_param_names = set(immediate_params.keys())
|
||||
all_buffer_names = set(persistent_buffers.keys())
|
||||
# If every immediate parameter and buffer is absent, recreate the whole module.
|
||||
fully_missing = (not all_param_names or missing_params == all_param_names) and (
|
||||
not all_buffer_names or missing_buffers == all_buffer_names
|
||||
)
|
||||
|
||||
if fully_missing:
|
||||
self._initialize_weights(module)
|
||||
else:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
preserved_parameters = {}
|
||||
for name, param in immediate_params.items():
|
||||
if name in missing_params:
|
||||
continue
|
||||
with deepspeed.zero.GatheredParameters([param], modifier_rank=None):
|
||||
preserved_parameters[name] = param.detach().clone()
|
||||
|
||||
preserved_buffers = {}
|
||||
for name, buffer in persistent_buffers.items():
|
||||
if name in missing_buffers or buffer is None:
|
||||
continue
|
||||
with deepspeed.zero.GatheredParameters([buffer], modifier_rank=None):
|
||||
preserved_buffers[name] = buffer.detach().clone()
|
||||
else:
|
||||
preserved_parameters = {
|
||||
name: param.detach().clone()
|
||||
for name, param in immediate_params.items()
|
||||
if name not in missing_params
|
||||
}
|
||||
preserved_buffers = {
|
||||
name: buffer.detach().clone()
|
||||
for name, buffer in persistent_buffers.items()
|
||||
if name not in missing_buffers and buffer is not None
|
||||
}
|
||||
|
||||
self._initialize_weights(module)
|
||||
|
||||
for name, tensor in preserved_parameters.items():
|
||||
module._parameters[name].data.copy_(tensor)
|
||||
module._parameters[name]._is_hf_initialized = True
|
||||
|
||||
for name, tensor in preserved_buffers.items():
|
||||
buffer = module._buffers[name]
|
||||
buffer.data.copy_(tensor)
|
||||
buffer._is_hf_initialized = True
|
||||
|
||||
for name in missing_params:
|
||||
param = module._parameters.get(name)
|
||||
if param is not None:
|
||||
param._is_hf_initialized = True
|
||||
handled_missing_keys.add(f"{module_path}.{name}")
|
||||
|
||||
for name in missing_buffers:
|
||||
buffer = module._buffers.get(name)
|
||||
if buffer is not None:
|
||||
buffer._is_hf_initialized = True
|
||||
handled_missing_keys.add(f"{module_path}.{name}")
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
params_to_gather = []
|
||||
for _, module, _, _ in modules_info:
|
||||
params_to_gather.extend(list(module.parameters(recurse=False)))
|
||||
|
||||
if params_to_gather:
|
||||
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
|
||||
_initialize_modules()
|
||||
else:
|
||||
_initialize_modules()
|
||||
else:
|
||||
_initialize_modules()
|
||||
|
||||
missing_keys_set -= handled_missing_keys
|
||||
|
||||
def set_is_initialized_for_modules(module):
|
||||
# A module is already initialized if and only if all its children are also already initialized, and all
|
||||
# its immediate `nn.Parameter` and persistent buffers are also already initialized
|
||||
@@ -5332,17 +5141,14 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# each param)
|
||||
self.apply(set_is_initialized_for_modules)
|
||||
|
||||
not_initialized_parameters = list(
|
||||
{v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
|
||||
)
|
||||
|
||||
if not not_initialized_parameters:
|
||||
return
|
||||
|
||||
# This will only initialize submodules that are not marked as initialized by the logic above.
|
||||
# This will only initialize submodules that are not marked as initialized by the line above.
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
import deepspeed
|
||||
|
||||
# keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
|
||||
not_initialized_parameters = list(
|
||||
{v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
|
||||
)
|
||||
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
|
||||
self.initialize_weights()
|
||||
else:
|
||||
|
||||
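A short hedged illustration of the `_is_hf_initialized` convention this whole block relies on (`model` and `missing_keys` are assumptions): tensors loaded from the checkpoint are flagged so that the later initialization pass only touches weights that are genuinely missing.

```python
# Conceptual sketch, not the library implementation
for name, tensor in model.state_dict(keep_vars=True).items():
    if name not in missing_keys:          # this weight came from the checkpoint
        tensor._is_hf_initialized = True  # -> skipped by initialize_weights()
# initialize_weights() then only (re)initializes tensors that do not carry the flag
```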
@@ -25,6 +25,7 @@ if TYPE_CHECKING:
|
||||
from .arcee import *
|
||||
from .aria import *
|
||||
from .audio_spectrogram_transformer import *
|
||||
from .audioflamingo3 import *
|
||||
from .auto import *
|
||||
from .autoformer import *
|
||||
from .aya_vision import *
|
||||
|
||||
@@ -59,9 +59,6 @@ class AlignProcessor(ProcessorMixin):
|
||||
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "EfficientNetImageProcessor"
|
||||
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
|
||||
valid_processor_kwargs = AlignProcessorKwargs
|
||||
|
||||
def __init__(self, image_processor, tokenizer):
|
||||
|
||||
@@ -35,10 +35,6 @@ class AltCLIPProcessor(ProcessorMixin):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast")
|
||||
tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
|
||||
|
||||
@deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
|
||||
def __init__(self, image_processor=None, tokenizer=None):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
@@ -906,10 +906,6 @@ class AriaProcessor(ProcessorMixin):
|
||||
A dictionary indicating size conversions for images.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "AriaImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
|
||||
@@ -67,10 +67,6 @@ class AriaProcessor(ProcessorMixin):
|
||||
A dictionary indicating size conversions for images.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "AriaImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
|
||||
@@ -272,7 +272,9 @@ if __name__ == "__main__":
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether or not to push the converted model to the Hugging Face hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
src/transformers/models/audioflamingo3/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_audioflamingo3 import *
|
||||
from .modeling_audioflamingo3 import *
|
||||
from .processing_audioflamingo3 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||
@@ -0,0 +1,210 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AudioFlamingo3EncoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of an [`AudioFlamingo3Encoder`]. It is used to instantiate an
|
||||
AudioFlamingo3 audio encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the audio encoder of the AudioFlamingo3
|
||||
architecture.
|
||||
|
||||
e.g. [nvidia/audio-flamingo-3-hf](https://huggingface.co/nvidia/audio-flamingo-3-hf)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
num_mel_bins (`int`, *optional*, defaults to 128):
|
||||
Number of mel features used per input features. Should correspond to the value used in the
|
||||
`AudioFlamingo3Processor` class.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of encoder layers.
|
||||
num_attention_heads (`int`, *optional*, defaults to 20):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 5120):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
||||
layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556)
|
||||
for more details.
|
||||
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
hidden_size (`int`, *optional*, defaults to 1280):
|
||||
Dimensionality of the layers.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Scale embeddings by dividing by sqrt(hidden_size).
|
||||
max_source_positions (`int`, *optional*, defaults to 1500):
|
||||
The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AudioFlamingo3EncoderConfig, AudioFlamingo3Encoder
|
||||
|
||||
>>> # Initializing an AudioFlamingo3EncoderConfig
|
||||
>>> configuration = AudioFlamingo3EncoderConfig()
|
||||
|
||||
>>> # Initializing an AudioFlamingo3Encoder (with random weights)
|
||||
>>> model = AudioFlamingo3Encoder(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "audioflamingo3_encoder"
|
||||
|
||||
attribute_map = {
|
||||
"d_model": "hidden_size",
|
||||
"encoder_layers": "num_hidden_layers",
|
||||
"encoder_attention_heads": "num_attention_heads",
|
||||
"encoder_ffn_dim": "intermediate_size",
|
||||
"encoder_layerdrop": "layerdrop",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_mel_bins=128,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=20,
|
||||
intermediate_size=5120,
|
||||
layerdrop=0.0,
|
||||
activation_function="gelu",
|
||||
hidden_size=1280,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
scale_embedding=False,
|
||||
max_source_positions=1500,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.activation_function = activation_function
|
||||
self.initializer_range = initializer_range
|
||||
self.layerdrop = layerdrop
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.scale_embedding = scale_embedding
|
||||
self.max_source_positions = max_source_positions
|
||||
|
||||
|
||||
class AudioFlamingo3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of an [`AudioFlamingo3ForConditionalGeneration`]. It is used to instantiate an
|
||||
AudioFlamingo3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the AudioFlamingo3.
|
||||
|
||||
e.g. [nvidia/audio-flamingo-3-hf](https://huggingface.co/nvidia/audio-flamingo-3-hf)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
audio_config (`Union[AudioFlamingo3EncoderConfig, dict]`, *optional*, defaults to `AudioFlamingo3EncoderConfig`):
|
||||
The config object or dictionary of the audio backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
|
||||
The config object or dictionary of the text backbone.
|
||||
audio_token_id (`int`, *optional*, defaults to 151669):
|
||||
The audio token index to encode the audio prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
Activation function used in the projector.
|
||||
projector_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to include bias terms in the projector.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AudioFlamingo3ForConditionalGeneration, AudioFlamingo3Config, AudioFlamingo3EncoderConfig, Qwen2Config
|
||||
|
||||
>>> # Initializing an AudioFlamingo3Encoder config
|
||||
>>> audio_config = AudioFlamingo3EncoderConfig()
|
||||
|
||||
>>> # Initializing a Qwen2 config
|
||||
>>> text_config = Qwen2Config()
|
||||
|
||||
>>> # Initializing an AudioFlamingo3 configuration
|
||||
>>> configuration = AudioFlamingo3Config(audio_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the audioflamingo3 style configuration
|
||||
>>> model = AudioFlamingo3ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "audioflamingo3"
|
||||
sub_configs = {
|
||||
"audio_config": AudioFlamingo3EncoderConfig,
|
||||
"text_config": AutoConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_config=None,
|
||||
text_config=None,
|
||||
audio_token_id=151669,
|
||||
projector_hidden_act="gelu",
|
||||
projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.audio_token_id = audio_token_id
|
||||
|
||||
if isinstance(audio_config, dict):
|
||||
audio_config["model_type"] = audio_config.get("model_type", "audioflamingo3_encoder")
|
||||
audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config)
|
||||
elif audio_config is None:
|
||||
audio_config = CONFIG_MAPPING["audioflamingo3_encoder"]()
|
||||
|
||||
self.audio_config = audio_config
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
text_config["model_type"] = text_config.get("model_type", "qwen2")
|
||||
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
elif text_config is None:
|
||||
text_config = CONFIG_MAPPING["qwen2"]()
|
||||
|
||||
self.text_config = text_config
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.projector_bias = projector_bias
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["AudioFlamingo3Config", "AudioFlamingo3EncoderConfig"]
|
||||
@@ -0,0 +1,286 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert AudioFlamingo3 checkpoints into a Hugging Face repository layout."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from transformers import (
|
||||
AudioFlamingo3Config,
|
||||
AudioFlamingo3ForConditionalGeneration,
|
||||
AudioFlamingo3Processor,
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
Qwen2Config,
|
||||
WhisperFeatureExtractor,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
|
||||
def _load_json(p: Path):
|
||||
if not p.is_file():
|
||||
raise FileNotFoundError(f"Missing JSON: {p}")
|
||||
with p.open("r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_processor(src_root: Path, dst_root: Path):
|
||||
llm_dir = src_root / "llm"
|
||||
|
||||
# fmt: off
|
||||
tokenizer_chat_template = (
|
||||
"{% if messages[0]['role'] != 'system' %}"
|
||||
"{{ '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}"
|
||||
"{% endif %}"
|
||||
"{% for message in messages if message['content'] is not none %}"
|
||||
"{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"{{ '<|im_start|>assistant\\n' }}"
|
||||
"{% endif %}"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# fmt: off
|
||||
processor_chat_template = (
|
||||
"{% if messages[0]['role'] != 'system' %}"
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
"{% endif %}"
|
||||
"{% for m in messages if m['content'] is not none %}"
|
||||
"<|im_start|>{{ m['role'] }}\n"
|
||||
"{% if m['content'] is string %}"
|
||||
"{{ m['content'] }}"
|
||||
"{% else %}"
|
||||
"{% set audio = namespace(found=False) %}"
|
||||
"{% set text_buf = namespace(v='') %}"
|
||||
"{% for c in m['content'] %}"
|
||||
"{% if c.get('type') == 'audio' or 'audio' in c %}"
|
||||
"{% set audio.found = True %}"
|
||||
"{% elif c.get('type') == 'text' or 'text' in c %}"
|
||||
"{% set text_buf.v = text_buf.v + c['text'] %}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if audio.found %}{{ '<sound>' }}{% endif %}{{ text_buf.v }}"
|
||||
"{% endif %}"
|
||||
"<|im_end|>\n"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}"
|
||||
"<|im_start|>assistant\n"
|
||||
"{% endif %}"
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
processor = AudioFlamingo3Processor(
|
||||
feature_extractor=WhisperFeatureExtractor(feature_size=128, return_attention_mask=True),
|
||||
tokenizer=AutoTokenizer.from_pretrained(str(llm_dir), chat_template=tokenizer_chat_template, use_fast=True),
|
||||
chat_template=processor_chat_template,
|
||||
)
|
||||
processor.save_pretrained(str(dst_root))
|
||||
|
||||
logger.info("processor (tokenizer + preprocessor)")
|
||||
return processor
|
||||
|
||||
|
||||
PREFIX_MAP = {
|
||||
"llm": "language_model",
|
||||
"sound_tower": "audio_tower",
|
||||
"sound_mm_projector": "multi_modal_projector",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_component_dir(dirpath: Path):
|
||||
if not dirpath.is_dir():
|
||||
return None
|
||||
idx = dirpath / "model.safetensors.index.json"
|
||||
mono = dirpath / "model.safetensors"
|
||||
if idx.exists():
|
||||
wm = _load_json(idx).get("weight_map") or {}
|
||||
by_shard: dict[str, list[str]] = defaultdict(list)
|
||||
for k, shard in wm.items():
|
||||
by_shard[shard].append(k)
|
||||
return ("sharded", dirpath, {k: sorted(v) for k, v in sorted(by_shard.items())})
|
||||
if mono.exists():
|
||||
return ("file", mono)
|
||||
cands = sorted([x for x in dirpath.iterdir() if x.suffix == ".safetensors"])
|
||||
return ("file", cands[0]) if len(cands) == 1 else None
|
||||
|
||||
|
||||
def merge_and_shard_weights(src_root: Path, dst_root: Path, processor: AudioFlamingo3Processor):
|
||||
state: dict[str, Any] = {}
|
||||
for tag in PREFIX_MAP.keys():
|
||||
comp = _resolve_component_dir(src_root / tag)
|
||||
if not comp:
|
||||
continue
|
||||
|
||||
out_prefix = PREFIX_MAP.get(tag, tag)
|
||||
|
||||
if comp[0] == "file":
|
||||
fp: Path = comp[1]
|
||||
with safe_open(str(fp), framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
if k == "__metadata__":
|
||||
continue
|
||||
state[f"{out_prefix}.{k}"] = f.get_tensor(k)
|
||||
else:
|
||||
base: Path = comp[1]
|
||||
shard_map: dict[str, list[str]] = comp[2]
|
||||
for shard, keys in shard_map.items():
|
||||
sp = base / shard
|
||||
with safe_open(str(sp), framework="pt", device="cpu") as f:
|
||||
for k in keys:
|
||||
state[f"{out_prefix}.{k}"] = f.get_tensor(k)
|
||||
|
||||
if not state:
|
||||
raise FileNotFoundError("No tensors found in llm/, sound_tower/, or sound_mm_projector/.")
|
||||
|
||||
tok = processor.tokenizer
|
||||
|
||||
text_config = Qwen2Config(
|
||||
bos_token_id=tok.bos_token_id,
|
||||
eos_token_id=tok.eos_token_id,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
vocab_size=len(tok),
|
||||
hidden_size=3584,
|
||||
intermediate_size=18944,
|
||||
model_max_length=8192,
|
||||
num_attention_heads=28,
|
||||
num_hidden_layers=28,
|
||||
num_key_value_heads=4,
|
||||
rope_theta=1000000.0,
|
||||
use_cache=False,
|
||||
)
|
||||
config = AudioFlamingo3Config(text_config=text_config, audio_token_id=tok.get_vocab()["<sound>"])
|
||||
model = AudioFlamingo3ForConditionalGeneration(config).to(dtype=torch.bfloat16)
|
||||
|
||||
# Update state dict to new key names if necessary
|
||||
projector_key_mapping = {
|
||||
"multi_modal_projector.layers.0.weight": "multi_modal_projector.linear_1.weight",
|
||||
"multi_modal_projector.layers.0.bias": "multi_modal_projector.linear_1.bias",
|
||||
"multi_modal_projector.layers.2.weight": "multi_modal_projector.linear_2.weight",
|
||||
"multi_modal_projector.layers.2.bias": "multi_modal_projector.linear_2.bias",
|
||||
}
|
||||
for old_key, new_key in projector_key_mapping.items():
|
||||
if old_key in state:
|
||||
state[new_key] = state.pop(old_key)
|
||||
|
||||
# Load weights into the instantiated model so we can push via `push_to_hub` later.
|
||||
load_res = model.load_state_dict(state, strict=True)
|
||||
# Enforce a clean load
|
||||
if getattr(load_res, "missing_keys", None) and load_res.missing_keys:
|
||||
mk = load_res.missing_keys
|
||||
raise ValueError(f"Missing keys when loading: {mk[:10]}{' ...' if len(mk) > 10 else ''}")
|
||||
if getattr(load_res, "unexpected_keys", None) and load_res.unexpected_keys:
|
||||
uk = load_res.unexpected_keys
|
||||
raise ValueError(f"Unexpected keys when loading: {uk[:10]}{' ...' if len(uk) > 10 else ''}")
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
bos_token_id=tok.bos_token_id,
|
||||
eos_token_id=tok.eos_token_id,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
max_new_tokens=2048,
|
||||
)
|
||||
model.generation_config = generation_config
|
||||
|
||||
model.save_pretrained(save_directory=str(dst_root))
|
||||
logger.info("model.safetensors index and shards")
|
||||
return model
|
||||
|
||||
|
||||
"""
|
||||
Reproducible Usage
|
||||
==================
|
||||
|
||||
1) Download the original AudioFlamingo-3 weights from NVIDIA (requires Git LFS):
|
||||
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/nvidia/audio-flamingo-3
|
||||
```
|
||||
|
||||
This will create a folder `audio-flamingo-3/` containing the original components:
|
||||
`llm/`, `sound_tower/`, and `sound_mm_projector/`.
|
||||
|
||||
2) Convert to the Hugging Face Transformers format (locally):
|
||||
|
||||
```
|
||||
python src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py \
|
||||
--src_dir audio-flamingo-3 \
|
||||
--dst_dir audio-flamingo-3-hf
|
||||
```
|
||||
|
||||
3) Convert and push directly to the Hub (requires `huggingface-cli login` or `HF_TOKEN`):
|
||||
|
||||
```
|
||||
python src/transformers/models/audioflamingo3/convert_audioflamingo3_to_hf.py \
|
||||
--src_dir audio-flamingo-3 \
|
||||
--dst_dir audio-flamingo-3-hf \
|
||||
--push_to_hub <username-or-org>/audio-flamingo-3
|
||||
```
|
||||
|
||||
This command uploads both the processor (tokenizer + feature extractor) and the converted
|
||||
model (sharded safetensors + configs) to the specified Hub repository.
|
||||
"""
|
||||
|
||||
|
||||
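Once converted, the checkpoint can be used like other Transformers audio-text models. The snippet below is a hedged sketch: the local path and audio file are placeholders, and the exact processor call may differ.

```python
from transformers import AudioFlamingo3ForConditionalGeneration, AudioFlamingo3Processor

processor = AudioFlamingo3Processor.from_pretrained("audio-flamingo-3-hf")
model = AudioFlamingo3ForConditionalGeneration.from_pretrained("audio-flamingo-3-hf")

messages = [
    {"role": "user", "content": [
        {"type": "audio", "audio": "sample.wav"},
        {"type": "text", "text": "Describe this sound."},
    ]},
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
)
output_ids = model.generate(**inputs)
print(processor.batch_decode(output_ids, skip_special_tokens=True))
```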
def main() -> None:
|
||||
ap = argparse.ArgumentParser(description="Convert AudioFlamingo3 to Hugging Face format.")
|
||||
ap.add_argument("--src_dir", required=True, help="Source model root directory")
|
||||
ap.add_argument("--dst_dir", required=True, help="Destination directory for converted model")
|
||||
ap.add_argument(
|
||||
"--push_to_hub",
|
||||
default=None,
|
||||
type=str,
|
||||
help=(
|
||||
"Optional repository ID to push the converted assets to the Hugging Face Hub, "
|
||||
"e.g. 'username/audio-flamingo-3'."
|
||||
),
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
src_root = Path(args.src_dir).resolve()
|
||||
if not src_root.is_dir():
|
||||
raise FileNotFoundError(f"Source directory not found: {src_root}")
|
||||
|
||||
dst_root = Path(args.dst_dir).resolve()
|
||||
if dst_root.exists():
|
||||
raise FileExistsError(f"Destination already exists: {dst_root}")
|
||||
|
||||
processor = write_processor(src_root, dst_root)
|
||||
model = merge_and_shard_weights(src_root, dst_root, processor)
|
||||
|
||||
# Optionally push converted assets using native push_to_hub only
|
||||
if args.push_to_hub:
|
||||
logger.info("Pushing processor to the Hub ...")
|
||||
processor.push_to_hub(args.push_to_hub)
|
||||
logger.info("Pushing model to the Hub ...")
|
||||
model.push_to_hub(args.push_to_hub)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,628 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/audioflamingo3/modular_audioflamingo3.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_audioflamingo3.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import eager_mask, padding_mask_function
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: Optional[float] = None,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
if scaling is None:
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AudioFlamingo3Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.0,
|
||||
is_decoder: bool = False,
|
||||
bias: bool = True,
|
||||
is_causal: bool = False,
|
||||
layer_idx: Optional[int] = None,
|
||||
config: Optional[AudioFlamingo3Config] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.config = config
|
||||
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
self.is_causal = is_causal
|
||||
|
||||
if layer_idx is None and is_decoder:
|
||||
logger.warning_once(
|
||||
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
||||
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
# determine input shapes
|
||||
bsz, tgt_len = hidden_states.shape[:-1]
|
||||
q_input_shape = (bsz, tgt_len, -1, self.head_dim)
|
||||
|
||||
# Scaling is susceptible to floating point arithmetic imprecision,
# which can lead to different results (this is dependent on the model,
# e.g. audioflamingo3 is one such case). We therefore keep the
# original order of scaling to follow the original implementation
# and enforce no scaling (1.0) in the attention call below.
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
query_states = query_states.view(*q_input_shape)
|
||||
query_states = query_states.transpose(1, 2).contiguous()
|
||||
|
||||
# Check if an encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
|
||||
if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
|
||||
is_updated = past_key_values.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
past_key_values.is_updated[self.layer_idx] = True
|
||||
past_key_values = past_key_values.cross_attention_cache
|
||||
else:
|
||||
past_key_values = past_key_values.self_attention_cache
|
||||
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_values and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_values.layers[self.layer_idx].keys
|
||||
value_states = past_key_values.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
|
||||
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.transpose(1, 2).contiguous()
|
||||
value_states = value_states.transpose(1, 2).contiguous()
|
||||
if past_key_values is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = past_key_values.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=1.0,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AudioFlamingo3EncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: AudioFlamingo3Config):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
self.self_attn = AudioFlamingo3Attention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
config=config,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
||||
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
if hidden_states.dtype == torch.float16:
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||
|
||||
return hidden_states, attn_weights
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class AudioFlamingo3PreTrainedModel(PreTrainedModel):
|
||||
config: AudioFlamingo3Config
|
||||
base_model_prefix = "model"
|
||||
input_modalities = ["audio", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["AudioFlamingo3Attention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of AudioFlamingo3 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else self.config.audio_config.initializer_range
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The audio model from AudioFlamingo3 without any head or projection on top.
|
||||
"""
|
||||
)
|
||||
class AudioFlamingo3Encoder(AudioFlamingo3PreTrainedModel):
|
||||
"""
|
||||
AudioFlamingo3 encoder: Whisper encoder, average pool (time/2), then LayerNorm.
|
||||
"""
|
||||
|
||||
# Ignore copy
|
||||
config: AudioFlamingo3EncoderConfig
|
||||
main_input_name = "input_features"
|
||||
input_modalities = "audio"
|
||||
_no_split_modules = ["AudioFlamingo3EncoderLayer"]
|
||||
|
||||
def __init__(self, config: AudioFlamingo3EncoderConfig):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.encoder_layerdrop
|
||||
|
||||
embed_dim = config.d_model
|
||||
self.num_mel_bins = config.num_mel_bins
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_source_positions
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
||||
self.embed_positions.requires_grad_(False)
|
||||
|
||||
self.layers = nn.ModuleList([AudioFlamingo3EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
# Ignore copy
|
||||
self.avg_pooler = nn.AvgPool1d(2, stride=2)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _freeze_parameters(self):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self._requires_grad = False
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.conv1
|
||||
|
||||
def set_input_embeddings(self, value: nn.Module):
|
||||
self.conv1 = value
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_features: torch.Tensor,
|
||||
input_features_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
|
||||
Log-Mel features extracted from raw audio. Use the processor/feature extractor to compute and pad
|
||||
these features from waveform input.
|
||||
input_features_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
"""
|
||||
|
||||
# Prepare attention mask for transformer layers
|
||||
batch_size = input_features.shape[0]
|
||||
seq_len = (input_features.shape[-1] - 1) // 2 + 1 # After conv2 downsampling
|
||||
|
||||
input_features_lengths = input_features_mask.sum(-1)
|
||||
input_features_lengths = (input_features_lengths - 1) // 2 + 1 # conv2 downsampling
|
||||
input_features_mask = torch.arange(seq_len, device=input_features.device) < input_features_lengths[:, None]
|
||||
attention_mask = eager_mask(
|
||||
batch_size=batch_size,
|
||||
cache_position=torch.arange(seq_len, device=input_features.device),
|
||||
kv_length=seq_len,
|
||||
mask_function=padding_mask_function(input_features_mask),
|
||||
dtype=self.conv1.weight.dtype,
|
||||
)
|
||||
|
||||
# Conv front-end
|
||||
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
|
||||
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
|
||||
inputs_embeds = inputs_embeds.permute(0, 2, 1)
|
||||
|
||||
# Add positions, dropout
|
||||
hidden_states = inputs_embeds + self.embed_positions.weight
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
# Transformer stack
|
||||
for layer in self.layers:
|
||||
drop = self.training and torch.rand([]) < self.layerdrop
|
||||
if not drop:
|
||||
hidden_states = layer(hidden_states, attention_mask)[0]
|
||||
|
||||
# AvgPool (time/2) + LayerNorm
|
||||
hidden_states = hidden_states.permute(0, 2, 1)
|
||||
hidden_states = self.avg_pooler(hidden_states).permute(0, 2, 1)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
)
|
||||
|
||||
# Ignore copy
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
||||
"""
|
||||
Computes the output length of the convolutional layers and the output length of the audio encoder
|
||||
"""
|
||||
input_lengths = (input_lengths - 1) // 2 + 1
|
||||
output_lengths = (input_lengths - 2) // 2 + 1
|
||||
return input_lengths, output_lengths
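# Illustrative sketch of the length arithmetic above (the 100 log-mel frames per second
# rate is an assumption based on Whisper feature-extractor defaults, not stated in this file):
#   >>> lengths = torch.tensor([3000])  # one full 30 s window
#   >>> self._get_feat_extract_output_lengths(lengths)
#   (tensor([1500]), tensor([750]))
# i.e. each full window contributes roughly 750 post-pool audio embeddings.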
|
||||
|
||||
|
||||
class AudioFlamingo3MultiModalProjector(nn.Module):
|
||||
"""
|
||||
Audio adaptor (small MLP) that projects AudioFlamingo3Encoder features
|
||||
to the LLM embedding space so they can replace `<sound>` tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, config: AudioFlamingo3Config):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
config.audio_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
|
||||
)
|
||||
|
||||
def forward(self, audio_features):
|
||||
hidden_states = self.linear_1(audio_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
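# Illustrative sketch (the hidden sizes below are assumptions for a Whisper-large encoder
# and a Qwen2.5-7B text backbone, not values read from this file):
#   >>> projector = AudioFlamingo3MultiModalProjector(config)  # hypothetical config instance
#   >>> feats = torch.randn(1, 750, 1280)  # post-pool encoder features
#   >>> projector(feats).shape
#   torch.Size([1, 750, 3584])
# The adaptor only maps the feature dimension into the LLM embedding space; the time axis is untouched.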
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
|
||||
"""
|
||||
)
|
||||
class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = None
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
_keep_in_fp32_modules_strict = None
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.audio_tower = AutoModel.from_config(config.audio_config)
|
||||
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
|
||||
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
|
||||
# Similar to Qwen2Audio
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def get_audio_features(
|
||||
self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector.
|
||||
Args:
|
||||
input_features (`torch.FloatTensor`):
|
||||
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
|
||||
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
|
||||
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
|
||||
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
|
||||
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
|
||||
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
|
||||
Mask to avoid performing attention on padded feature indices.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
The audio embeddings.
|
||||
"""
|
||||
|
||||
# Encode audio
|
||||
encoder_output = self.audio_tower(input_features, input_features_mask=input_features_mask)
|
||||
audio_embeds = self.multi_modal_projector(encoder_output.last_hidden_state)
|
||||
|
||||
# Mask according to avg pooling (which is after attention blocks)
|
||||
post_lengths = (input_features_mask.sum(-1) - 2) // 2 + 1
|
||||
valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
|
||||
audio_embeds = audio_embeds[valid_mask.to(audio_embeds.device)]
|
||||
return audio_embeds
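# Illustrative sketch (lengths and hidden size are assumptions, not taken from this file):
# for a batch of two clips whose post-pool lengths are 750 and 375 frames,
#   >>> audio_embeds = self.get_audio_features(input_features, input_features_mask)
#   >>> audio_embeds.shape
#   torch.Size([1125, 3584])  # 750 + 375 valid frames, flattened across the batch
# so the embeddings can be scattered one-to-one onto the `<sound>` placeholder positions in `forward`.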
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
input_features: Optional[torch.FloatTensor] = None,
|
||||
input_features_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
|
||||
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
>>> model_id = "nvidia/audio-flamingo-3-hf"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_id)
|
||||
>>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
>>> conversations = [
|
||||
>>> [
|
||||
>>> {
|
||||
>>> "role": "user",
|
||||
>>> "content": [
|
||||
>>> {"type": "text", "text": "Transcribe the input speech."},
|
||||
>>> {
|
||||
>>> "type": "audio",
|
||||
>>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
|
||||
>>> },
|
||||
>>> ],
|
||||
>>> }
|
||||
>>> ],
|
||||
>>> [
|
||||
>>> {
|
||||
>>> "role": "user",
|
||||
>>> "content": [
|
||||
>>> {
|
||||
>>> "type": "text",
|
||||
>>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
|
||||
>>> },
|
||||
>>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
|
||||
>>> ],
|
||||
>>> }
|
||||
>>> ],
|
||||
>>> ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(
|
||||
>>> conversations,
|
||||
>>> tokenize=True,
|
||||
>>> add_generation_prompt=True,
|
||||
>>> return_dict=True,
|
||||
>>> ).to(model.device)
|
||||
|
||||
>>> outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(
|
||||
>>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
|
||||
>>> )
|
||||
>>> print(decoded_outputs)
|
||||
["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."]
|
||||
```"""
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if input_features is not None and input_ids is not None:
|
||||
audio_embeds = self.get_audio_features(input_features, input_features_mask)
|
||||
|
||||
# replace text-audio token placeholders with audio embeddings
|
||||
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
|
||||
)
|
||||
|
||||
outputs: CausalLMOutputWithPast = self.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
# Overwritten -- we should not pass input_features when we are in cached decoding stage
|
||||
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
input_features_mask = kwargs.pop("input_features_mask", None)
|
||||
cache_position = kwargs.get("cache_position")
|
||||
|
||||
model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
if cache_position is not None and cache_position[0] == 0:
|
||||
# input_features should only be passed when we are not in cached decoding stage
|
||||
if input_features is not None:
|
||||
model_inputs["input_features"] = input_features
|
||||
if input_features_mask is not None:
|
||||
model_inputs["input_features_mask"] = input_features_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"]
|
||||
src/transformers/models/audioflamingo3/modular_audioflamingo3.py (new file, 307 lines)
@ -0,0 +1,307 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...masking_utils import eager_mask, padding_mask_function
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
||||
from ..qwen2_audio.modeling_qwen2_audio import (
|
||||
Qwen2AudioEncoder,
|
||||
Qwen2AudioPreTrainedModel,
|
||||
)
|
||||
from ..voxtral.modeling_voxtral import VoxtralForConditionalGeneration, VoxtralMultiModalProjector
|
||||
from ..whisper.modeling_whisper import WhisperEncoderLayer
|
||||
from .configuration_audioflamingo3 import AudioFlamingo3Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AudioFlamingo3EncoderLayer(WhisperEncoderLayer):
|
||||
pass
|
||||
|
||||
|
||||
class AudioFlamingo3PreTrainedModel(Qwen2AudioPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The audio model from AudioFlamingo3 without any head or projection on top.
|
||||
"""
|
||||
)
|
||||
class AudioFlamingo3Encoder(Qwen2AudioEncoder):
|
||||
"""
|
||||
AudioFlamingo3 encoder: Whisper encoder, average pool (time/2), then LayerNorm.
|
||||
"""
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_features: torch.Tensor,
|
||||
input_features_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
|
||||
Log-Mel features extracted from raw audio. Use the processor/feature extractor to compute and pad
|
||||
these features from waveform input.
|
||||
input_features_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
"""
|
||||
|
||||
# Prepare attention mask for transformer layers
|
||||
batch_size = input_features.shape[0]
|
||||
seq_len = (input_features.shape[-1] - 1) // 2 + 1 # After conv2 downsampling
|
||||
|
||||
input_features_lengths = input_features_mask.sum(-1)
|
||||
input_features_lengths = (input_features_lengths - 1) // 2 + 1 # conv2 downsampling
|
||||
input_features_mask = torch.arange(seq_len, device=input_features.device) < input_features_lengths[:, None]
|
||||
attention_mask = eager_mask(
|
||||
batch_size=batch_size,
|
||||
cache_position=torch.arange(seq_len, device=input_features.device),
|
||||
kv_length=seq_len,
|
||||
mask_function=padding_mask_function(input_features_mask),
|
||||
dtype=self.conv1.weight.dtype,
|
||||
)
|
||||
|
||||
# Conv front-end
|
||||
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
|
||||
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
|
||||
inputs_embeds = inputs_embeds.permute(0, 2, 1)
|
||||
|
||||
# Add positions, dropout
|
||||
hidden_states = inputs_embeds + self.embed_positions.weight
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
# Transformer stack
|
||||
for layer in self.layers:
|
||||
drop = self.training and torch.rand([]) < self.layerdrop
|
||||
if not drop:
|
||||
hidden_states = layer(hidden_states, attention_mask)[0]
|
||||
|
||||
# AvgPool (time/2) + LayerNorm
|
||||
hidden_states = hidden_states.permute(0, 2, 1)
|
||||
hidden_states = self.avg_pooler(hidden_states).permute(0, 2, 1)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
)
|
||||
|
||||
|
||||
class AudioFlamingo3MultiModalProjector(VoxtralMultiModalProjector):
|
||||
"""
|
||||
Audio adaptor (small MLP) that projects AudioFlamingo3Encoder features
|
||||
to the LLM embedding space so they can replace `<sound>` tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, config: AudioFlamingo3Config):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
config.audio_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.projector_bias
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
|
||||
"""
|
||||
)
|
||||
class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration):
|
||||
_tied_weights_keys = None
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
_keep_in_fp32_modules_strict = None
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# Similar to Qwen2Audio
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
def get_audio_features(
|
||||
self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector.
|
||||
Args:
|
||||
input_features (`torch.FloatTensor`):
|
||||
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
|
||||
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
|
||||
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
|
||||
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
|
||||
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
|
||||
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
|
||||
Mask to avoid performing attention on padded feature indices.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
The audio embeddings.
|
||||
"""
|
||||
|
||||
# Encode audio
|
||||
encoder_output = self.audio_tower(input_features, input_features_mask=input_features_mask)
|
||||
audio_embeds = self.multi_modal_projector(encoder_output.last_hidden_state)
|
||||
|
||||
# Mask according to avg pooling (which is after attention blocks)
|
||||
post_lengths = (input_features_mask.sum(-1) - 2) // 2 + 1
|
||||
valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None]
|
||||
audio_embeds = audio_embeds[valid_mask.to(audio_embeds.device)]
|
||||
return audio_embeds
|
||||
|
||||
def get_audio_embeds(self):
|
||||
raise NotImplementedError("This method is not supported for AudioFlamingo3ForConditionalGeneration.")
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
input_features: Optional[torch.FloatTensor] = None,
|
||||
input_features_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
|
||||
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor
|
||||
|
||||
>>> model_id = "nvidia/audio-flamingo-3-hf"
|
||||
>>> processor = AutoProcessor.from_pretrained(model_id)
|
||||
>>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
|
||||
>>> conversations = [
|
||||
>>> [
|
||||
>>> {
|
||||
>>> "role": "user",
|
||||
>>> "content": [
|
||||
>>> {"type": "text", "text": "Transcribe the input speech."},
|
||||
>>> {
|
||||
>>> "type": "audio",
|
||||
>>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
|
||||
>>> },
|
||||
>>> ],
|
||||
>>> }
|
||||
>>> ],
|
||||
>>> [
|
||||
>>> {
|
||||
>>> "role": "user",
|
||||
>>> "content": [
|
||||
>>> {
|
||||
>>> "type": "text",
|
||||
>>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
|
||||
>>> },
|
||||
>>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
|
||||
>>> ],
|
||||
>>> }
|
||||
>>> ],
|
||||
>>> ]
|
||||
|
||||
>>> inputs = processor.apply_chat_template(
|
||||
>>> conversations,
|
||||
>>> tokenize=True,
|
||||
>>> add_generation_prompt=True,
|
||||
>>> return_dict=True,
|
||||
>>> ).to(model.device)
|
||||
|
||||
>>> outputs = model.generate(**inputs, max_new_tokens=500)
|
||||
|
||||
>>> decoded_outputs = processor.batch_decode(
|
||||
>>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
|
||||
>>> )
|
||||
>>> print(decoded_outputs)
|
||||
["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."]
|
||||
```"""
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if input_features is not None and input_ids is not None:
|
||||
audio_embeds = self.get_audio_features(input_features, input_features_mask)
|
||||
|
||||
# replace text-audio token placeholders with audio embeddings
|
||||
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
|
||||
)
|
||||
|
||||
outputs: CausalLMOutputWithPast = self.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
# Overwritten -- we should not pass input_features when we are in cached decoding stage
|
||||
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
input_features_mask = kwargs.pop("input_features_mask", None)
|
||||
cache_position = kwargs.get("cache_position")
|
||||
|
||||
model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
if cache_position is not None and cache_position[0] == 0:
|
||||
# input_features should only be passed when we are not in cached decoding stage
|
||||
if input_features is not None:
|
||||
model_inputs["input_features"] = input_features
|
||||
if input_features_mask is not None:
|
||||
model_inputs["input_features_mask"] = input_features_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"]
|
||||
@ -0,0 +1,318 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...audio_utils import AudioInput, make_list_of_audio
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import TextInput
|
||||
from ...utils import is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAX_AUDIO_LEN = 10 * 60 # 10 minutes
|
||||
DEFAULT_TRANSCRIPTION_PROMPT = "Transcribe the input speech."
|
||||
|
||||
|
||||
class AudioFlamingo3ProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": True,
|
||||
},
|
||||
"audio_kwargs": {
|
||||
"sampling_rate": 16000,
|
||||
"chunk_length": 30.0,
|
||||
"return_attention_mask": True,
|
||||
"padding": "max_length",
|
||||
},
|
||||
"common_kwargs": {
|
||||
"return_tensors": "pt",
|
||||
"padding_side": "left",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AudioFlamingo3Processor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs an AudioFlamingo3 processor which wraps an AudioFlamingo3 feature extractor and an AudioFlamingo3
|
||||
tokenizer into a single processor.
|
||||
|
||||
[`AudioFlamingo3Processor`] offers all the functionalities of [`WhisperFeatureExtractor`] and
|
||||
[`Qwen2TokenizerFast`]. See the [`~AudioFlamingo3Processor.__call__`] for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor ([`WhisperFeatureExtractor`]):
|
||||
The feature extractor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`]):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`Optional[str]`, *optional*):
|
||||
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's default chat
|
||||
template will be used.
|
||||
audio_token (`Optional[str]`, *optional*, defaults to `"<sound>"`):
|
||||
Special token used to represent audio inputs in the chat template.
|
||||
"""
|
||||
|
||||
attributes = ["feature_extractor", "tokenizer"]
|
||||
feature_extractor_class = "WhisperFeatureExtractor"
|
||||
tokenizer_class = "Qwen2TokenizerFast"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_extractor,
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
audio_token="<sound>",
|
||||
):
|
||||
self.audio_token = audio_token
|
||||
self.audio_token_id = tokenizer.convert_tokens_to_ids(audio_token)
|
||||
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[TextInput, list[TextInput]],
|
||||
audio: Optional[AudioInput] = None,
|
||||
output_labels: Optional[bool] = False,
|
||||
**kwargs: Unpack[AudioFlamingo3ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
r"""
|
||||
Main method to prepare one or several text sequence(s) and audio waveform(s) for the model. This
|
||||
method expands `<sound>` placeholders in the text based on the post-pool frame counts of the
|
||||
audio windows, then tokenizes the provided strings as-is, and extracts log-mel features
|
||||
with [`WhisperFeatureExtractor`]. If `audio` is `None`, no audio processing is performed and
|
||||
the text is tokenized as-is (LM-only behavior).
|
||||
|
||||
Args:
|
||||
text (`str` or `list[str]`):
|
||||
Input sequence or batch of sequences.
|
||||
audio (`np.ndarray` or `list[np.ndarray]`):
|
||||
Input audio or batch of audios as NumPy arrays. If provided, there must be as many `text` inputs as
|
||||
`audio` inputs.
|
||||
output_labels (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return labels for training.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A dictionary with tokenized text (`input_ids`, `attention_mask`) and
|
||||
audio features (`input_features`, `input_features_mask`).
|
||||
"""
|
||||
|
||||
# Merge defaults with user kwargs
|
||||
call_kwargs = self._merge_kwargs(
|
||||
AudioFlamingo3ProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
text_kwargs = call_kwargs["text_kwargs"]
|
||||
audio_kwargs = call_kwargs["audio_kwargs"]
|
||||
return_tensors = text_kwargs.get("return_tensors")
|
||||
if return_tensors != "pt":
|
||||
raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
audio_inputs = {}
|
||||
if audio is not None:
|
||||
audio = make_list_of_audio(audio)
|
||||
if len(text) != len(audio):
|
||||
raise ValueError(f"Got {len(text)} text inputs but {len(audio)} audio inputs; they must match 1:1.")
|
||||
|
||||
# Determine number of chunks per sample, and flatten
|
||||
window_size = int(audio_kwargs["sampling_rate"] * audio_kwargs["chunk_length"])
|
||||
max_windows = int(MAX_AUDIO_LEN // audio_kwargs["chunk_length"])
|
||||
|
||||
per_sample_windows: list[int] = []
|
||||
flat_chunks: list[np.ndarray] = []
|
||||
|
||||
for audio_el in audio:
|
||||
n_samples = int(audio_el.shape[0])
|
||||
n_win = max(1, (n_samples + window_size - 1) // window_size)
|
||||
if n_win > max_windows:
|
||||
logger.warning(
|
||||
f"Audio duration ({n_samples / audio_kwargs['sampling_rate']:.1f}s) exceeds {MAX_AUDIO_LEN}s; truncating to first {MAX_AUDIO_LEN}s."
|
||||
)
|
||||
n_win = max_windows
|
||||
per_sample_windows.append(n_win)
|
||||
|
||||
time_cap = min(n_samples, n_win * window_size)
|
||||
for i in range(n_win):
|
||||
start = i * window_size
|
||||
end = min((i + 1) * window_size, time_cap)
|
||||
flat_chunks.append(audio_el[start:end])
|
||||
|
||||
# Feature extraction
|
||||
audio_inputs = self.feature_extractor(flat_chunks, **audio_kwargs)
|
||||
padding_mask = audio_inputs.pop("attention_mask")
|
||||
audio_inputs["input_features_mask"] = padding_mask
|
||||
|
||||
# Compute per-sample sequence lengths for audio token counting
audio_lengths = torch.stack([s.sum() for s in torch.split(padding_mask.sum(-1), per_sample_windows)])
conv_output_lengths = (audio_lengths - 1) // 2 + 1  # After conv2 downsampling
audio_tokens_lengths = (conv_output_lengths - 2) // 2 + 1  # After avg pooling
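# Illustrative sketch of the window/token arithmetic above (assumes 16 kHz audio, 30 s windows
# and 100 log-mel frames per second, i.e. window_size = 480000 samples; these rates are
# assumptions based on Whisper defaults, not stated in this file):
#   a 45 s clip -> n_win = 2 windows (30 s + 15 s), so padding_mask.sum(-1) is roughly
#   [3000, 1500] frames, audio_lengths = 4500, conv_output_lengths = (4500 - 1) // 2 + 1 = 2250
#   and audio_tokens_lengths = (2250 - 2) // 2 + 1 = 1125 `<sound>` tokens for that sample.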
|
||||
|
||||
# expand audio tokens in text
|
||||
for i, audio_length in enumerate(audio_tokens_lengths):
|
||||
expanded = re.sub(re.escape(self.audio_token), self.audio_token * audio_length, text[i])
|
||||
text[i] = expanded
|
||||
|
||||
# Tokenize
|
||||
text_inputs = self.tokenizer(text, **text_kwargs)
|
||||
|
||||
data = {**text_inputs, **audio_inputs}
|
||||
if output_labels:
|
||||
labels = data["input_ids"].clone()
|
||||
labels[labels == self.audio_token_id] = -100
|
||||
labels[labels == self.tokenizer.pad_token_id] = -100
|
||||
data["labels"] = labels
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
@property
|
||||
def model_input_names(self) -> list[str]:
|
||||
tok_names = self.tokenizer.model_input_names
|
||||
fea_names = self.feature_extractor.model_input_names
|
||||
return list(dict.fromkeys(tok_names + fea_names + ["input_features_mask"]))
|
||||
|
||||
def apply_transcription_request(
|
||||
self,
|
||||
audio: Union[str, list[str], AudioInput],
|
||||
prompt: Optional[Union[str, list[str]]] = None,
|
||||
**kwargs: Unpack[AudioFlamingo3ProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Prepare inputs for automatic speech recognition without manually writing the default transcription prompt.
|
||||
|
||||
Args:
|
||||
audio (`str`, `list[str]`, `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
|
||||
Audio to transcribe. Strings are interpreted as local paths or URLs and will be loaded automatically by
|
||||
the chat template loader; NumPy arrays and PyTorch tensors are forwarded directly.
|
||||
prompt (`str` or `list[str]`, *optional*):
|
||||
Custom prompt(s) to include in the user turn. A list must be the same length as the batch. When `None`,
|
||||
each sample uses `"Transcribe the input speech."`.
|
||||
**kwargs:
|
||||
Additional keyword arguments forwarded to [`~AudioFlamingo3Processor.apply_chat_template`] (for example
|
||||
`text_kwargs`, `audio_kwargs`, ...).
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: Processor outputs ready to be passed to [`AudioFlamingo3ForConditionalGeneration.generate`].
|
||||
|
||||
"""
|
||||
|
||||
if isinstance(audio, str):
|
||||
audio_items: list[Union[str, np.ndarray]] = [audio]
|
||||
elif isinstance(audio, (list, tuple)) and audio and all(isinstance(el, str) for el in audio):
|
||||
audio_items = list(audio)
|
||||
else:
|
||||
audio_items = list(make_list_of_audio(audio))
|
||||
if is_torch_available():
|
||||
audio_items = [el.detach().cpu().numpy() if isinstance(el, torch.Tensor) else el for el in audio_items]
|
||||
|
||||
batch_size = len(audio_items)
|
||||
if batch_size == 0:
|
||||
raise ValueError("`audio` must contain at least one sample.")
|
||||
|
||||
if prompt is None:
|
||||
prompts = [DEFAULT_TRANSCRIPTION_PROMPT] * batch_size
|
||||
elif isinstance(prompt, str):
|
||||
prompts = [prompt] * batch_size
|
||||
elif isinstance(prompt, (list, tuple)):
|
||||
if len(prompt) != batch_size:
|
||||
raise ValueError(
|
||||
f"Received {len(prompt)} prompt(s) for {batch_size} audio sample(s); counts must match."
|
||||
)
|
||||
prompts = []
|
||||
for item in prompt:
|
||||
if item is None:
|
||||
prompts.append(DEFAULT_TRANSCRIPTION_PROMPT)
|
||||
elif isinstance(item, str):
|
||||
prompts.append(item)
|
||||
else:
|
||||
raise TypeError("Each prompt must be a string or `None`.")
|
||||
else:
|
||||
raise TypeError("`prompt` must be a string, a sequence of strings, or `None`.")
|
||||
|
||||
conversations = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt_text},
|
||||
{"type": "audio", "path": audio_item}
|
||||
if isinstance(audio_item, str)
|
||||
else {"type": "audio", "audio": audio_item},
|
||||
],
|
||||
}
|
||||
]
|
||||
for prompt_text, audio_item in zip(prompts, audio_items)
|
||||
]
|
||||
|
||||
return self.apply_chat_template(
|
||||
conversations,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
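# Illustrative usage sketch (the checkpoint id comes from the docstrings elsewhere in this PR;
# the audio path and decoded output are placeholders):
#   >>> processor = AutoProcessor.from_pretrained("nvidia/audio-flamingo-3-hf")
#   >>> inputs = processor.apply_transcription_request(audio="sample.wav").to(model.device)
#   >>> outputs = model.generate(**inputs, max_new_tokens=256)
#   >>> processor.batch_decode(
#   ...     outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True, strip_prefix=True
#   ... )
#   ['<transcription of sample.wav>']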
|
||||
|
||||
def batch_decode(self, *args, strip_prefix=False, **kwargs):
|
||||
"""
|
||||
Forward arguments to [`~PreTrainedTokenizer.batch_decode`] and optionally remove the assistant framing the model
|
||||
was trained to produce.
|
||||
|
||||
AF3 transcription requests respond with sentences such as `"The spoken content of the audio is \"...\"."`.
|
||||
Setting `strip_prefix=True` trims that fixed prefix so that only the transcription text remains.
|
||||
"""
|
||||
decoded = self.tokenizer.batch_decode(*args, **kwargs)
|
||||
if strip_prefix:
|
||||
decoded = [self._strip_assistant_prefix_and_quotes(text) for text in decoded]
|
||||
return decoded
|
||||
|
||||
def _strip_assistant_prefix_and_quotes(self, text: str) -> str:
|
||||
"""
|
||||
Remove the assistant prefix and surrounding quotes from a decoded transcription string.
|
||||
"""
|
||||
|
||||
stripped = text.strip()
|
||||
|
||||
for prefix in (
|
||||
"The spoken content of the audio is",
|
||||
"The transcription of the audio is",
|
||||
):
|
||||
if stripped.startswith(prefix):
|
||||
stripped = stripped[len(prefix) :].strip()
|
||||
break
|
||||
|
||||
if stripped.endswith("."):
|
||||
stripped = stripped[:-1].strip()
|
||||
|
||||
if len(stripped) >= 2 and stripped[0] == stripped[-1] and stripped[0] in {"'", '"'}:
|
||||
stripped = stripped[1:-1].strip()
|
||||
|
||||
return stripped
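# Illustrative sketch of the stripping behaviour implemented above:
#   >>> self._strip_assistant_prefix_and_quotes('The spoken content of the audio is "hello world".')
#   'hello world'
#   >>> self._strip_assistant_prefix_and_quotes("An unrelated answer.")
#   'An unrelated answer'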
|
||||
|
||||
|
||||
__all__ = ["AudioFlamingo3Processor"]
|
||||
@ -45,6 +45,8 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("aria", "AriaConfig"),
|
||||
("aria_text", "AriaTextConfig"),
|
||||
("audio-spectrogram-transformer", "ASTConfig"),
|
||||
("audioflamingo3", "AudioFlamingo3Config"),
|
||||
("audioflamingo3_encoder", "AudioFlamingo3EncoderConfig"),
|
||||
("autoformer", "AutoformerConfig"),
|
||||
("aya_vision", "AyaVisionConfig"),
|
||||
("bamba", "BambaConfig"),
|
||||
@ -223,6 +225,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("layoutlm", "LayoutLMConfig"),
|
||||
("layoutlmv2", "LayoutLMv2Config"),
|
||||
("layoutlmv3", "LayoutLMv3Config"),
|
||||
("layoutxlm", "LayoutLMv2Config"),
|
||||
("led", "LEDConfig"),
|
||||
("levit", "LevitConfig"),
|
||||
("lfm2", "Lfm2Config"),
|
||||
@ -476,6 +479,8 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("aria", "Aria"),
|
||||
("aria_text", "AriaText"),
|
||||
("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
|
||||
("audioflamingo3", "AudioFlamingo3"),
|
||||
("audioflamingo3_encoder", "AudioFlamingo3Encoder"),
|
||||
("autoformer", "Autoformer"),
|
||||
("aya_vision", "AyaVision"),
|
||||
("bamba", "Bamba"),
|
||||
@ -959,6 +964,7 @@ DEPRECATED_MODELS = [
|
||||
|
||||
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
|
||||
[
|
||||
("audioflamingo3_encoder", "audioflamingo3"),
|
||||
("openai-gpt", "openai"),
|
||||
("data2vec-audio", "data2vec"),
|
||||
("data2vec-text", "data2vec"),
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
"""AutoFeatureExtractor class."""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Union
|
||||
@ -24,7 +23,7 @@ from typing import Optional, Union
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
|
||||
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging, safe_load_json_file
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
@ -41,6 +40,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("audio-spectrogram-transformer", "ASTFeatureExtractor"),
|
||||
("clap", "ClapFeatureExtractor"),
|
||||
("clvp", "ClvpFeatureExtractor"),
|
||||
("csm", "EncodecFeatureExtractor"),
|
||||
("dac", "DacFeatureExtractor"),
|
||||
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
|
||||
("dia", "DiaFeatureExtractor"),
|
||||
@ -49,14 +49,20 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("granite_speech", "GraniteSpeechFeatureExtractor"),
|
||||
("hubert", "Wav2Vec2FeatureExtractor"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
|
||||
("markuplm", "MarkupLMFeatureExtractor"),
|
||||
("mctct", "MCTCTFeatureExtractor"),
|
||||
("mimi", "EncodecFeatureExtractor"),
|
||||
("moonshine", "Wav2Vec2FeatureExtractor"),
|
||||
("moshi", "EncodecFeatureExtractor"),
|
||||
("musicgen", "EncodecFeatureExtractor"),
|
||||
("musicgen_melody", "MusicgenMelodyFeatureExtractor"),
|
||||
("parakeet_ctc", "ParakeetFeatureExtractor"),
|
||||
("parakeet_encoder", "ParakeetFeatureExtractor"),
|
||||
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
|
||||
("pop2piano", "Pop2PianoFeatureExtractor"),
|
||||
("qwen2_5_omni", "WhisperFeatureExtractor"),
|
||||
("qwen2_audio", "WhisperFeatureExtractor"),
|
||||
("qwen3_omni_moe", "WhisperFeatureExtractor"),
|
||||
("seamless_m4t", "SeamlessM4TFeatureExtractor"),
|
||||
("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
|
||||
("sew", "Wav2Vec2FeatureExtractor"),
|
||||
@ -66,6 +72,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("unispeech", "Wav2Vec2FeatureExtractor"),
|
||||
("unispeech-sat", "Wav2Vec2FeatureExtractor"),
|
||||
("univnet", "UnivNetFeatureExtractor"),
|
||||
("voxtral", "WhisperFeatureExtractor"),
|
||||
("wav2vec2", "Wav2Vec2FeatureExtractor"),
|
||||
("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
|
||||
("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
|
||||
@ -167,9 +174,10 @@ def get_feature_extractor_config(
|
||||
feature_extractor.save_pretrained("feature-extractor-test")
|
||||
feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
|
||||
```"""
|
||||
resolved_config_file = cached_file(
|
||||
# Load with priority given to the nested processor config, if available in the repo
|
||||
resolved_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
FEATURE_EXTRACTOR_NAME,
|
||||
filename=PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
@ -178,16 +186,37 @@ def get_feature_extractor_config(
|
||||
local_files_only=local_files_only,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
if resolved_config_file is None:
|
||||
logger.info(
|
||||
"Could not locate the feature extractor configuration file, will try to use the model config instead."
|
||||
)
|
||||
resolved_feature_extractor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=FEATURE_EXTRACTOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
|
||||
# Return an empty dict if none of the possible files is found in the repo
|
||||
if not resolved_feature_extractor_file and not resolved_processor_file:
|
||||
logger.info("Could not locate the feature extractor configuration file.")
|
||||
return {}
|
||||
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
return json.load(reader)
|
||||
# Load feature_extractor dict. Priority goes as (nested config if found -> feature extractor config)
|
||||
# We are downloading both configs because almost all models have a `processor_config.json` but
|
||||
# not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
|
||||
feature_extractor_dict = {}
|
||||
if resolved_processor_file is not None:
|
||||
processor_dict = safe_load_json_file(resolved_processor_file)
|
||||
if "feature_extractor" in processor_dict:
|
||||
feature_extractor_dict = processor_dict["feature_extractor"]
|
||||
|
||||
if resolved_feature_extractor_file is not None and not feature_extractor_dict:
|
||||
feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file)
|
||||
return feature_extractor_dict
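# Illustrative sketch of the two layouts handled above (file contents are assumptions, not
# copied from a real checkpoint):
#   processor_config.json (nested style):    {"feature_extractor": {"feature_size": 128, ...}, ...}
#   preprocessor_config.json (legacy style): {"feature_size": 128, ...}
# When both files are present, the nested "feature_extractor" entry takes priority.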
|
||||
|
||||
|
||||
class AutoFeatureExtractor:
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
"""AutoImageProcessor class."""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
@ -29,12 +28,14 @@ from ...image_processing_utils_fast import BaseImageProcessorFast
|
||||
from ...utils import (
|
||||
CONFIG_NAME,
|
||||
IMAGE_PROCESSOR_NAME,
|
||||
PROCESSOR_NAME,
|
||||
cached_file,
|
||||
is_timm_config_dict,
|
||||
is_timm_local_checkpoint,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
safe_load_json_file,
|
||||
)
|
||||
from ...utils.import_utils import requires
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
@ -62,7 +63,9 @@ else:
|
||||
("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("altclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("aria", ("AriaImageProcessor", None)),
|
||||
("aya_vision", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
|
||||
("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
|
||||
("bit", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||
("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
|
||||
@ -73,6 +76,8 @@ else:
|
||||
("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")),
|
||||
("colpali", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
|
||||
("colqwen2", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
|
||||
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
@ -95,8 +100,10 @@ else:
|
||||
("efficientformer", ("EfficientFormerImageProcessor", None)),
|
||||
("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
|
||||
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("emu3", ("Emu3ImageProcessor", None)),
|
||||
("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
|
||||
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
|
||||
("florence2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||
("fuyu", ("FuyuImageProcessor", "FuyuImageProcessorFast")),
|
||||
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
|
||||
@ -114,11 +121,13 @@ else:
|
||||
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
|
||||
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
|
||||
("internvl", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
|
||||
("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
|
||||
("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")),
|
||||
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
|
||||
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
|
||||
("layoutxlm", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessor")),
|
||||
("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
|
||||
("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
|
||||
("lightglue", ("LightGlueImageProcessor", "LightGlueImageProcessorFast")),
|
||||
@ -141,6 +150,7 @@ else:
|
||||
("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
|
||||
("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
|
||||
("omdet-turbo", ("DetrImageProcessor", "DetrImageProcessorFast")),
|
||||
("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
|
||||
("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
|
||||
("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
|
||||
@ -155,14 +165,17 @@ else:
|
||||
("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
|
||||
("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")),
|
||||
("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
|
||||
("qwen2_5_omni", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("qwen3_omni_moe", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
|
||||
("sam", ("SamImageProcessor", "SamImageProcessorFast")),
|
||||
("sam2", (None, "Sam2ImageProcessorFast")),
|
||||
("sam2_video", (None, "Sam2ImageProcessorFast")),
|
||||
("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
|
||||
("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
|
||||
("seggpt", ("SegGptImageProcessor", None)),
|
||||
@ -180,12 +193,14 @@ else:
|
||||
("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
|
||||
("timesformer", ("VideoMAEImageProcessor", None)),
|
||||
("timm_wrapper", ("TimmWrapperImageProcessor", None)),
|
||||
("trocr", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("tvlt", ("TvltImageProcessor", None)),
|
||||
("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
|
||||
("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
|
||||
("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
|
||||
("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("video_llama_3", ("VideoLlama3ImageProcessor", "VideoLlama3ImageProcessorFast")),
|
||||
("video_llava", ("VideoLlavaImageProcessor", None)),
|
||||
("videomae", ("VideoMAEImageProcessor", None)),
|
||||
("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
|
||||
("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
@ -305,9 +320,10 @@ def get_image_processor_config(
|
||||
image_processor.save_pretrained("image-processor-test")
|
||||
image_processor_config = get_image_processor_config("image-processor-test")
|
||||
```"""
|
||||
resolved_config_file = cached_file(
|
||||
# Load with priority given to the nested processor config, if available in the repo
|
||||
resolved_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
IMAGE_PROCESSOR_NAME,
|
||||
filename=PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
@ -316,16 +332,38 @@ def get_image_processor_config(
|
||||
local_files_only=local_files_only,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
if resolved_config_file is None:
|
||||
logger.info(
|
||||
"Could not locate the image processor configuration file, will try to use the model config instead."
|
||||
)
|
||||
resolved_image_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=IMAGE_PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
|
||||
# An empty list if none of the possible files is found in the repo
|
||||
if not resolved_image_processor_file and not resolved_processor_file:
|
||||
logger.info("Could not locate the image processor configuration file.")
|
||||
return {}
|
||||
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
return json.load(reader)
|
||||
# Load image_processor dict. Priority goes as (nested config if found -> image processor config)
|
||||
# We are downloading both configs because almost all models have a `processor_config.json` but
|
||||
# not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
|
||||
image_processor_dict = {}
|
||||
if resolved_processor_file is not None:
|
||||
processor_dict = safe_load_json_file(resolved_processor_file)
|
||||
if "image_processor" in processor_dict:
|
||||
image_processor_dict = processor_dict["image_processor"]
|
||||

if resolved_image_processor_file is not None and not image_processor_dict:
image_processor_dict = safe_load_json_file(resolved_image_processor_file)

return image_processor_dict

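The rewritten loader above prefers a nested "image_processor" section inside the combined processor config over a standalone image processor file. A minimal sketch of that precedence with a hypothetical helper and made-up contents, not library code:

```python
# Hypothetical helper mirroring the precedence implemented above.
def resolve_image_processor_dict(processor_dict, standalone_dict):
    # 1) a nested "image_processor" section of the combined processor config wins when present
    if processor_dict and "image_processor" in processor_dict:
        return processor_dict["image_processor"]
    # 2) otherwise fall back to the standalone image processor config, if any
    return standalone_dict or {}

nested = {"image_processor": {"size": {"height": 224, "width": 224}}}
standalone = {"size": {"height": 384, "width": 384}}
assert resolve_image_processor_dict(nested, standalone)["size"]["height"] == 224
assert resolve_image_processor_dict({}, standalone)["size"]["height"] == 384
```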
def _warning_fast_image_processor_available(fast_class):
@@ -524,10 +562,9 @@ class AutoImageProcessor:
)
use_fast = False
if use_fast:
for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
if image_processor_type in image_processors:
break
else:
# Check if the fast image processor class exists
image_processor_class_fast = get_image_processor_class_from_name(image_processor_type)
if image_processor_class_fast is None:
image_processor_type = image_processor_type[:-4]
use_fast = False
logger.warning_once(
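For context, the hunk above handles the case where a fast image processor type is requested but no matching fast class can be resolved: the trailing "Fast" is stripped and the slow class is used with a warning. A hypothetical, standalone sketch of that naming fallback; class names are examples only:

```python
# Hypothetical stand-in for the suffix-based fallback shown above.
def pick_image_processor_type(image_processor_type: str, resolvable: set) -> str:
    if image_processor_type in resolvable:
        return image_processor_type  # the requested fast class exists
    if image_processor_type.endswith("Fast"):
        return image_processor_type[:-4]  # fall back to the slow class name
    return image_processor_type

assert pick_image_processor_type("BeitImageProcessorFast", {"BeitImageProcessorFast"}) == "BeitImageProcessorFast"
assert pick_image_processor_type("SegGptImageProcessorFast", {"SegGptImageProcessor"}) == "SegGptImageProcessor"
```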
@@ -53,6 +53,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
("aria", "AriaModel"),
("aria_text", "AriaTextModel"),
("audio-spectrogram-transformer", "ASTModel"),
("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
("audioflamingo3_encoder", "AudioFlamingo3Encoder"),
("autoformer", "AutoformerModel"),
("aya_vision", "AyaVisionModel"),
("bamba", "BambaModel"),
@@ -445,6 +447,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("albert", "AlbertForPreTraining"),
("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("bert", "BertForPreTraining"),
("big_bird", "BigBirdForPreTraining"),
@@ -1159,6 +1162,7 @@ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("blenderbot", "BlenderbotForConditionalGeneration"),
@@ -1700,6 +1704,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("dinov2", "Dinov2Backbone"),
("dinov2_with_registers", "Dinov2WithRegistersBackbone"),
("dinov3_convnext", "DINOv3ConvNextBackbone"),
("dinov3_vit", "DINOv3ViTBackbone"),
("focalnet", "FocalNetBackbone"),
("hgnet_v2", "HGNetV2Backbone"),
("hiera", "HieraBackbone"),
@@ -48,6 +48,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("align", "AlignProcessor"),
("altclip", "AltCLIPProcessor"),
("aria", "AriaProcessor"),
("audioflamingo3", "AudioFlamingo3Processor"),
("aya_vision", "AyaVisionProcessor"),
("bark", "BarkProcessor"),
("blip", "BlipProcessor"),
@@ -107,6 +108,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("mllama", "MllamaProcessor"),
("mm-grounding-dino", "GroundingDinoProcessor"),
("moonshine", "Wav2Vec2Processor"),
("omdet-turbo", "OmDetTurboProcessor"),
("oneformer", "OneFormerProcessor"),
("ovis2", "Ovis2Processor"),
("owlv2", "Owlv2Processor"),
@@ -72,6 +72,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("altclip", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
@@ -156,6 +157,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
@@ -224,6 +226,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
("donut", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
(
@@ -238,6 +241,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
("esm", ("EsmTokenizer", None)),
("evolla", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"exaone4",
(
@@ -252,10 +256,13 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None),
),
("flaubert", ("FlaubertTokenizer", None)),
("flava", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("florence2", ("BartTokenizer", "BartTokenizerFast" if is_tokenizers_available() else None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("fuyu", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"gemma",
(
@@ -304,6 +311,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("got_ocr2", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
@@ -314,6 +322,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
("granite", ("GPT2Tokenizer", None)),
("granite_speech", ("GPT2Tokenizer", None)),
("granitemoe", ("GPT2Tokenizer", None)),
("granitemoehybrid", ("GPT2Tokenizer", None)),
("granitemoeshared", ("GPT2Tokenizer", None)),
@@ -353,11 +362,14 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("kyutai_speech_to_text", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("lfm2", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("lfm2_vl", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
(
"llama",
@@ -398,6 +410,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("markuplm", ("MarkupLMTokenizer", "MarkupLMTokenizerFast" if is_tokenizers_available() else None)),
(
"mbart",
(
@@ -484,6 +497,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
"NllbTokenizerFast" if is_tokenizers_available() else None,
),
),
("nougat", (None, "NougatTokenizerFast" if is_tokenizers_available() else None)),
(
"nystromformer",
(
@@ -505,6 +519,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None),
),
("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("ovis2", (None, "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
@@ -530,6 +545,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
None,
),
),
("perception_lm", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"persimmon",
(
@@ -539,6 +555,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phi4_multimodal", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
@@ -552,6 +569,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("pop2piano", ("Pop2PianoTokenizer", None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
(
@@ -658,6 +676,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("smolvlm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
@@ -692,6 +711,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("tapas", ("TapasTokenizer", None)),
("tapex", ("TapexTokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
("trocr", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
(
"udop",
@@ -707,9 +727,14 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
("video_llama_3", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"vision_text_dual_encoder",
("PreTrainedTokenizer", "PreTrainedTokenizerFast" if is_tokenizers_available() else None),
),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("vits", ("VitsTokenizer", None)),
(
@@ -725,6 +750,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
("wav2vec2_with_lm", ("Wav2Vec2CTCTokenizer", None)),
("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
(
@@ -1160,7 +1186,7 @@ class AutoTokenizer:
The configuration corresponding to the model to register.
slow_tokenizer_class ([`PretrainedTokenizer`], *optional*):
The slow tokenizer to register.
fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*):
fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
The fast tokenizer to register.
"""
if slow_tokenizer_class is None and fast_tokenizer_class is None:
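The docstring fix above is for `AutoTokenizer.register`, which associates a config class with slow and/or fast tokenizer classes. A hedged usage sketch with made-up classes; the names are illustrative and not part of this diff:

```python
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast

# Illustrative classes only; a real integration would define a proper tokenizer.
class MyConfig(PretrainedConfig):
    model_type = "my-model"

class MyTokenizerFast(PreTrainedTokenizerFast):
    pass

# Register the config type and the tokenizer so AutoTokenizer can resolve it later.
AutoConfig.register("my-model", MyConfig)
AutoTokenizer.register(MyConfig, fast_tokenizer_class=MyTokenizerFast)
```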
@@ -15,7 +15,6 @@
"""AutoVideoProcessor class."""

import importlib
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional, Union
@@ -23,7 +22,16 @@ from typing import TYPE_CHECKING, Optional, Union
# Build the list of all video processors
from ...configuration_utils import PreTrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
from ...utils import (
CONFIG_NAME,
IMAGE_PROCESSOR_NAME,
PROCESSOR_NAME,
VIDEO_PROCESSOR_NAME,
cached_file,
is_torchvision_available,
logging,
safe_load_json_file,
)
from ...utils.import_utils import requires
from ...video_processing_utils import BaseVideoProcessor
from .auto_factory import _LazyAutoMapping
@@ -60,6 +68,7 @@ else:
("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
("sam2_video", "Sam2VideoVideoProcessor"),
("smolvlm", "SmolVLMVideoProcessor"),
("video_llama_3", "VideoLlama3VideoProcessor"),
("video_llava", "VideoLlavaVideoProcessor"),
("videomae", "VideoMAEVideoProcessor"),
("vjepa2", "VJEPA2VideoProcessor"),
@@ -167,24 +176,59 @@ def get_video_processor_config(
video_processor.save_pretrained("video-processor-test")
video_processor = get_video_processor_config("video-processor-test")
```"""
resolved_config_file = cached_file(
# Load with a priority given to the nested processor config, if available in repo
resolved_processor_file = cached_file(
pretrained_model_name_or_path,
VIDEO_PROCESSOR_NAME,
filename=PROCESSOR_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
)
if resolved_config_file is None:
logger.info(
"Could not locate the video processor configuration file, will try to use the model config instead."
resolved_video_processor_files = [
resolved_file
for filename in [VIDEO_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME]
if (
resolved_file := cached_file(
pretrained_model_name_or_path,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
is not None
]
resolved_video_processor_file = resolved_video_processor_files[0] if resolved_video_processor_files else None

# An empty list if none of the possible files is found in the repo
if not resolved_video_processor_file and not resolved_processor_file:
logger.info("Could not locate the video processor configuration file.")
return {}

with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)
# Load video_processor dict. Priority goes as (nested config if found -> video processor config -> image processor config)
# We are downloading both configs because almost all models have a `processor_config.json` but
# not all of these are nested. We need to check if it was saved recently as nested or if it is legacy style
video_processor_dict = {}
if resolved_processor_file is not None:
processor_dict = safe_load_json_file(resolved_processor_file)
if "video_processor" in processor_dict:
video_processor_dict = processor_dict["video_processor"]

if resolved_video_processor_file is not None and not video_processor_dict:
video_processor_dict = safe_load_json_file(resolved_video_processor_file)

return video_processor_dict

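The list comprehension above walks candidate filenames in priority order and keeps the first one that actually resolves. A minimal standalone sketch of the same first-hit pattern; the lookup function and file names below are stand-ins, not the real `cached_file` behavior:

```python
# Stand-in lookup: pretend only the second candidate exists in the repo.
def fake_cached_file(filename):
    return f"/cache/{filename}" if filename == "image_config.json" else None

candidates = ["video_config.json", "image_config.json"]
resolved = [path for name in candidates if (path := fake_cached_file(name)) is not None]
first_hit = resolved[0] if resolved else None
print(first_hit)  # /cache/image_config.json
```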
@requires(backends=("vision", "torchvision"))
@@ -291,7 +335,7 @@ class AutoVideoProcessor:

# Some models have different image processors, e.g. InternVL uses GotOCRImageProcessor
# We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on
if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values():
if video_processor_class_from_name(video_processor_class_inferred) is not None:
video_processor_class = video_processor_class_inferred
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
@@ -70,10 +70,6 @@ class AyaVisionProcessor(ProcessorMixin):
in a chat into a tokenizable string.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"

def __init__(
self,
image_processor=None,

@@ -49,9 +49,6 @@ class BarkProcessor(ProcessorMixin):

"""

tokenizer_class = "AutoTokenizer"
attributes = ["tokenizer"]

preset_shape = {
"semantic_prompt": 1, # 1D array of shape (X,)
"coarse_prompt": 2, # 2D array of shape (2,X)

@@ -53,10 +53,6 @@ class BlipProcessor(ProcessorMixin):
An instance of ['BertTokenizerFast`]. The tokenizer is a required input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")

def __init__(self, image_processor, tokenizer, **kwargs):
tokenizer.return_token_type_ids = False
super().__init__(image_processor, tokenizer)

@@ -60,10 +60,6 @@ class Blip2Processor(ProcessorMixin):
Number of tokens used by the Qformer as queries, should be same as in model's config.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
tokenizer.return_token_type_ids = False
if not hasattr(tokenizer, "image_token"):

@@ -54,9 +54,6 @@ class BridgeTowerProcessor(ProcessorMixin):
An instance of ['RobertaTokenizerFast`]. The tokenizer is a required input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "BridgeTowerImageProcessor"
tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast")
valid_processor_kwargs = BridgeTowerProcessorKwargs

def __init__(self, image_processor, tokenizer):

@@ -138,7 +138,7 @@ if __name__ == "__main__":
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model and processor to the 🤗 hub.",
help="Whether or not to push the converted model and processor to the Hugging Face hub.",
)

args = parser.parse_args()

@@ -46,8 +46,6 @@ class BrosProcessor(ProcessorMixin):
An instance of ['BertTokenizerFast`]. The tokenizer is a required input.
"""

attributes = ["tokenizer"]
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
valid_processor_kwargs = BrosProcessorKwargs

def __init__(self, tokenizer=None, **kwargs):

@@ -69,10 +69,6 @@ class ChameleonProcessor(ProcessorMixin):
The special token used to indicate image in the text.
"""

attributes = ["image_processor", "tokenizer"]
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
image_processor_class = "ChameleonImageProcessor"

def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
self.image_seq_length = image_seq_length
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token

@@ -34,10 +34,6 @@ class ChineseCLIPProcessor(ProcessorMixin):
The tokenizer is a required input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")

def __init__(self, image_processor=None, tokenizer=None, **kwargs):
super().__init__(image_processor, tokenizer)


@@ -42,9 +42,6 @@ class ClapProcessor(ProcessorMixin):
The tokenizer is a required input.
"""

feature_extractor_class = "ClapFeatureExtractor"
tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast")

def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)


@@ -33,10 +33,6 @@ class CLIPProcessor(ProcessorMixin):
The tokenizer is a required input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast")
tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor=None, tokenizer=None, **kwargs):
super().__init__(image_processor, tokenizer)

@@ -257,7 +257,9 @@ if __name__ == "__main__":
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
"--push_to_hub",
action="store_true",
help="Whether or not to push the converted model to the Hugging Face hub.",
)

args = parser.parse_args()

@@ -34,10 +34,6 @@ class CLIPSegProcessor(ProcessorMixin):
The tokenizer is a required input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = ("ViTImageProcessor", "ViTImageProcessorFast")
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")

def __init__(self, image_processor=None, tokenizer=None, **kwargs):
super().__init__(image_processor, tokenizer)


@@ -38,9 +38,6 @@ class ClvpProcessor(ProcessorMixin):
An instance of [`ClvpTokenizer`]. The tokenizer is a required input.
"""

feature_extractor_class = "ClvpFeatureExtractor"
tokenizer_class = "ClvpTokenizer"

def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)


@@ -129,7 +129,7 @@ class Cohere2VisionCausalLMOutputWithPast(ModelOutput):
@auto_docstring
class Cohere2VisionPreTrainedModel(PreTrainedModel):
config: Cohere2VisionConfig
base_model_prefix = ""
base_model_prefix = "model"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"

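For background, `base_model_prefix` tells `PreTrainedModel` utilities which attribute holds the backbone (the `base_model` property is roughly `getattr(self, self.base_model_prefix, self)`), so setting it to "model" here points weight loading and helpers at the composite model's `.model` attribute. A toy sketch of that lookup, not the actual class:

```python
# Toy illustration of how base_model_prefix resolves the backbone attribute.
class Toy:
    base_model_prefix = "model"

    def __init__(self, backbone):
        self.model = backbone

    @property
    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

assert Toy(backbone="backbone-object").base_model == "backbone-object"
```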
@@ -26,6 +26,7 @@ from transformers.models.aya_vision.modeling_aya_vision import (
AyaVisionForConditionalGeneration,
AyaVisionModel,
AyaVisionModelOutputWithPast,
AyaVisionPreTrainedModel,
)
from transformers.models.got_ocr2.image_processing_got_ocr2_fast import GotOcr2ImageProcessorFast

@@ -89,6 +90,10 @@ class Cohere2VisionCausalLMOutputWithPast(AyaVisionCausalLMOutputWithPast):
pass


class Cohere2VisionPreTrainedModel(AyaVisionPreTrainedModel):
base_model_prefix = "model"


class Cohere2VisionModel(AyaVisionModel):
_checkpoint_conversion_mapping = {}

@@ -340,7 +345,7 @@ class Cohere2VisionImageProcessorFast(GotOcr2ImageProcessorFast):

__all__ = [
"Cohere2VisionForConditionalGeneration",
"Cohere2VisionPreTrainedModel", # noqa: F822
"Cohere2VisionPreTrainedModel",
"Cohere2VisionModel",
"Cohere2VisionImageProcessorFast",
]