Mirror of https://github.com/huggingface/transformers.git
Synced 2025-11-18 16:54:55 +08:00
Compare commits
290 commits: lightweigh...fix-torcha
| SHA1 | Author | Date | |
|---|---|---|---|
| 055aef8690 | |||
| 0103627278 | |||
| 7d8df526e6 | |||
| 00b00448e7 | |||
| 3651460288 | |||
| f5a7c33dce | |||
| 3c8c7572e6 | |||
| 2f0a6aed58 | |||
| f93f35709c | |||
| 5881d8eb91 | |||
| ea5822db85 | |||
| 9fa1b7a2c4 | |||
| e033947a5c | |||
| 7b457fd04c | |||
| 09bcd2ee11 | |||
| 86a4e51647 | |||
| 5be67b96fc | |||
| 8755a4beef | |||
| db4fe31ddf | |||
| 94a53d4c66 | |||
| de74aebbc7 | |||
| 7b7c990364 | |||
| c137ea3323 | |||
| 0412832432 | |||
| bbf5b000e2 | |||
| f7d0183d2b | |||
| 2a00e493c2 | |||
| 3ffc59ef92 | |||
| d176b48973 | |||
| 76d66be5e5 | |||
| a052513335 | |||
| 3e69622256 | |||
| 44943fb87d | |||
| a0029f207b | |||
| 5c9d56cb07 | |||
| f8f0973415 | |||
| e4df75269a | |||
| e235eeddb7 | |||
| d841a04b3e | |||
| 443573aeb8 | |||
| 85ab08590a | |||
| 75d3afcb48 | |||
| 72eff97c4d | |||
| 9788014a93 | |||
| 9c0db728bd | |||
| 386e259b85 | |||
| e16da231ca | |||
| 18b02eea94 | |||
| 0e7d2d052d | |||
| dde5500d80 | |||
| 074a449f6b | |||
| 9fde9f7893 | |||
| 9a76a6eee3 | |||
| 32226787a9 | |||
| 8ff4ad56a5 | |||
| 78d46227f8 | |||
| 2fa058fe8a | |||
| f692f4bdcb | |||
| acbeeae720 | |||
| 399388d1fe | |||
| bdbc01a6a4 | |||
| d22363560f | |||
| c48e1edb49 | |||
| 0c2b667d13 | |||
| 1dabb4c334 | |||
| 0f022b59d9 | |||
| 0e51decd6d | |||
| e341529210 | |||
| f72f96d400 | |||
| cc0819540b | |||
| 84dd6eb26e | |||
| 82f94b8ae0 | |||
| 3fea865810 | |||
| e4cadfb1c2 | |||
| 6cb3794080 | |||
| 2526cc5d91 | |||
| 710b1fffcf | |||
| 07574dddd4 | |||
| a228fd0ad2 | |||
| ef8b6c3548 | |||
| b57d7897c4 | |||
| d9e7fe65c8 | |||
| 92c0229af4 | |||
| 58389a1ff0 | |||
| acc5b2452a | |||
| 5146dec408 | |||
| d91701f7ee | |||
| 8baa3fe987 | |||
| 2733ff69c4 | |||
| b8927d67ef | |||
| 8c16de161f | |||
| 57988f25a2 | |||
| c43495a51a | |||
| 912562c08a | |||
| 2ff765e9ed | |||
| e7165da04d | |||
| ead2ac3776 | |||
| 74a0e9c71b | |||
| e2aefee7fc | |||
| bd36211210 | |||
| 45271710d0 | |||
| 42fd4c4325 | |||
| db02b9d716 | |||
| 9601b82ce7 | |||
| ff108789ca | |||
| 5c54332e3b | |||
| 8936cc408f | |||
| 50714d8ca7 | |||
| c921cedee7 | |||
| dcad7030b2 | |||
| 1652c9c52f | |||
| a581fd75e7 | |||
| 89846e7d81 | |||
| 20d1b340c4 | |||
| 5e71bd4ae7 | |||
| 5794d27d1c | |||
| 0b95826c97 | |||
| 32b9273893 | |||
| 0fb23403e4 | |||
| 5d7507b16d | |||
| 8fd255c7f0 | |||
| ba3de5add4 | |||
| ba1a8b64c0 | |||
| 76b6a92d74 | |||
| f85f2397ec | |||
| 675b2bca69 | |||
| dc5a22c2af | |||
| 4f212de424 | |||
| e088408964 | |||
| d7c81717ae | |||
| da7dc100ac | |||
| 4894a25774 | |||
| 93862177d8 | |||
| 8f7b1d02bb | |||
| 8b924a3b12 | |||
| a170f290a8 | |||
| 00b95ee009 | |||
| 1c87945a3c | |||
| 02386ce7c6 | |||
| 8a8beff73e | |||
| 77ccbb17fd | |||
| ab6ee8aed4 | |||
| a8fb5540c9 | |||
| 23e3ed7489 | |||
| ce8c1c1978 | |||
| d923061e63 | |||
| 2ff85326fc | |||
| 80517f5322 | |||
| 7d78aa1b37 | |||
| 22fcdaf9c6 | |||
| f2938df853 | |||
| d1e84db344 | |||
| 4d7970991c | |||
| 6c88206d3b | |||
| 82a35bcc89 | |||
| 3baf4b7f6b | |||
| 9b6a7a445b | |||
| 85973fc9ad | |||
| c515eb6d91 | |||
| 4d34cedff5 | |||
| 9cb0432c2d | |||
| 0da6e92757 | |||
| 9022bc293e | |||
| b148577e3c | |||
| 7eda8aa764 | |||
| 606452d69e | |||
| a79de84819 | |||
| 20b6142aa7 | |||
| 29aa0515a0 | |||
| 52d85e0fb4 | |||
| e59b1fffab | |||
| 6b398e149f | |||
| 0ebb1b6219 | |||
| 7061956922 | |||
| 29e017d50a | |||
| 19f94d0f40 | |||
| e465bc0ae0 | |||
| 913171a9d8 | |||
| 07e265d10d | |||
| 3e4d8ea958 | |||
| 1d4411aa17 | |||
| e848ab6165 | |||
| 573af7594c | |||
| 2d84aba1da | |||
| 9f5ec4ac90 | |||
| ef5123b8ad | |||
| 6d0aa66327 | |||
| 7f196f9313 | |||
| b225885f58 | |||
| 904283dd1c | |||
| d34482c6a0 | |||
| e0fd1e42e3 | |||
| f4775fcac4 | |||
| 00846a2ef4 | |||
| 5d4d27e6e2 | |||
| b320474eae | |||
| 630934707d | |||
| c3c534fe67 | |||
| edf96f8451 | |||
| 00e36042a8 | |||
| 48c85c78da | |||
| 912dd2f7ba | |||
| 9bed48862c | |||
| 50a85efdcd | |||
| d9bb0e340e | |||
| fe9b047899 | |||
| 9f615bcc1c | |||
| a01ad8d63e | |||
| 8cf96946e7 | |||
| 28a1d22526 | |||
| 6c9fda4e0e | |||
| 4443658942 | |||
| 0402e564ce | |||
| 134959c142 | |||
| 17f25f9f3b | |||
| 3cde7b0606 | |||
| 22145750da | |||
| c53755fce7 | |||
| f1312dc91c | |||
| edeacc3867 | |||
| de09779953 | |||
| e1eb5a4adb | |||
| aa0ebbec82 | |||
| ac1af43293 | |||
| 653933c293 | |||
| a92cb1fe61 | |||
| ec49d7339d | |||
| 965b006613 | |||
| 9735c6e011 | |||
| a8998de322 | |||
| c3f5437233 | |||
| a5859af437 | |||
| 8e74adc4d0 | |||
| 62ccfd9b7f | |||
| 7efb487d31 | |||
| 0519e21dd3 | |||
| 6f6deb0f88 | |||
| 466df965f3 | |||
| 2fe87ce1dd | |||
| c6bb839d21 | |||
| fe220cf182 | |||
| 667133317e | |||
| 7b64815cc5 | |||
| b01dd4fd98 | |||
| 58fc7b5799 | |||
| c9417f9872 | |||
| b82c4f256f | |||
| a693417568 | |||
| b6027426f2 | |||
| b4ef14c23b | |||
| fb3422794d | |||
| fbea44e9e2 | |||
| d7d922acd5 | |||
| 30f41f2c9d | |||
| 9c9669360c | |||
| b20b69373e | |||
| afdb59ddf4 | |||
| 4b2058be0e | |||
| b2e97bf570 | |||
| f74d41f18f | |||
| 36a4b5d5ac | |||
| bde538dc0f | |||
| 7728fda7c7 | |||
| d36e62c12d | |||
| b8586194ce | |||
| 0e56676260 | |||
| d1c47d0e02 | |||
| e40427fc57 | |||
| 1aae8d97f2 | |||
| 0569ee8693 | |||
| e0da883e85 | |||
| f62bc7e0dd | |||
| bfb804756d | |||
| a08b927826 | |||
| 8ca058d64c | |||
| 01f8a7e419 | |||
| e956317273 | |||
| 213a64d4ae | |||
| f8d1f98dc1 | |||
| 9c07ead1fc | |||
| 86e48e242b | |||
| 8a3e3d43bb | |||
| 46b7632fbc | |||
| 0ff608d466 | |||
| 15ec137a1d | |||
| 993c2fbe74 | |||
| 7bb32d5f7f | |||
| 22734c5047 | |||
| 941738e5f3 | |||
| d76ebe4195 |
@ -20,4 +20,4 @@ jobs:
|
||||
contents: read
|
||||
with:
|
||||
workflow_name: ${{ inputs.workflow_name }}
|
||||
run_count: ${{ fromJSON(inputs.run_count) }}
|
||||
run_count: ${{ fromJSON(inputs.run_count) }}
|
||||
|
||||
3  .github/workflows/get-pr-info.yml (vendored)
@ -87,6 +87,9 @@ jobs:
|
||||
PR_FILES: ${{ steps.pr_info.outputs.files }}
|
||||
if: ${{ inputs.pr_number != '' }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Extract PR details
|
||||
id: pr_info
|
||||
uses: actions/github-script@v6
|
||||
|
||||
3  .github/workflows/get-pr-number.yml (vendored)
@ -13,6 +13,9 @@ jobs:
|
||||
outputs:
|
||||
PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Get PR number
|
||||
shell: bash
|
||||
env:
|
||||
|
||||
@ -13,6 +13,9 @@ jobs:
|
||||
name: Notify new model
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
12  .github/workflows/pr_build_doc_with_comment.yml (vendored)
@ -35,6 +35,9 @@ jobs:
|
||||
PR_MERGE_COMMIT_DATE: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_DATE }}
|
||||
PR_MERGE_COMMIT_TIMESTAMP: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_TIMESTAMP }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- run: |
|
||||
COMMENT_TIMESTAMP=$(date -d "${COMMENT_DATE}" +"%s")
|
||||
echo "COMMENT_DATE: $COMMENT_DATE"
|
||||
@ -54,6 +57,9 @@ jobs:
|
||||
statuses: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Create Run
|
||||
id: create_run
|
||||
env:
|
||||
@ -77,6 +83,9 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Reply to the comment
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -112,6 +121,9 @@ jobs:
|
||||
GITHUB_RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
|
||||
STATUS_OK: ${{ contains(fromJSON('["skipped", "success"]'), needs.create_run.result) }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Get `build-doc` job status
|
||||
run: |
|
||||
echo "${{ needs.build-doc.result }}"
|
||||
|
||||
8  .github/workflows/pr_slow_ci_suggestion.yml (vendored)
@ -23,6 +23,10 @@ jobs:
|
||||
outputs:
|
||||
jobs: ${{ steps.get_jobs.outputs.jobs_to_run }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
# This checkout to the main branch
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@ -89,6 +93,10 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Check and update comment if needed
|
||||
uses: actions/github-script@v7
|
||||
env:
|
||||
|
||||
4  .github/workflows/push-important-models.yml (vendored)
@ -11,6 +11,10 @@ jobs:
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
4  .github/workflows/release-conda.yml (vendored)
@ -18,6 +18,10 @@ jobs:
|
||||
shell: bash -l {0}
|
||||
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
24  .github/workflows/self-comment-ci.yml (vendored)
@ -46,6 +46,10 @@ jobs:
|
||||
PR_HEAD_SHA: ${{ needs.get-pr-info.outputs.PR_HEAD_SHA }}
|
||||
PR_MERGE_SHA: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_SHA }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Verify `merge_commit` timestamp is older than the issue comment timestamp
|
||||
env:
|
||||
COMMENT_DATE: ${{ github.event.comment.created_at }}
|
||||
@ -67,6 +71,10 @@ jobs:
|
||||
models: ${{ steps.models_to_run.outputs.models }}
|
||||
quantizations: ${{ steps.models_to_run.outputs.quantizations }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: "0"
|
||||
@ -109,6 +117,10 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Reply to the comment
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -131,6 +143,10 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Reply to the comment
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -152,6 +168,10 @@ jobs:
|
||||
statuses: write
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Create Run
|
||||
id: create_run
|
||||
env:
|
||||
@ -210,6 +230,10 @@ jobs:
|
||||
if: ${{ always() && needs.create_run.result == 'success' }}
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Show reports from jobs
|
||||
env:
|
||||
MODEL_REPORT: ${{ needs.model-ci.outputs.report }}
|
||||
|
||||
4  .github/workflows/self-nightly-caller.yml (vendored)
@ -30,6 +30,10 @@ jobs:
|
||||
name: Setup
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Setup
|
||||
run: |
|
||||
mkdir "setup_values"
|
||||
|
||||
@ -14,6 +14,9 @@ jobs:
|
||||
outputs:
|
||||
run_number: ${{ steps.get_number.outputs.run_number }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Get number
|
||||
id: get_number
|
||||
run: |
|
||||
|
||||
@ -10,5 +10,9 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Trigger scheduled AMD CI via workflow_run
|
||||
run: echo "Trigger scheduled AMD CI via workflow_run"
|
||||
|
||||
3  .github/workflows/self-scheduled-caller.yml (vendored)
@ -32,6 +32,9 @@ jobs:
|
||||
name: Setup
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Setup
|
||||
env:
|
||||
prev_workflow_run_id: ${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}
|
||||
|
||||
@ -32,6 +32,10 @@ jobs:
|
||||
name: Setup
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Setup
|
||||
run: |
|
||||
mkdir "setup_values"
|
||||
|
||||
16  .github/workflows/self-scheduled-intel-gaudi.yml (vendored)
@ -38,6 +38,10 @@ jobs:
|
||||
folder_slices: ${{ steps.set-matrix.outputs.folder_slices }}
|
||||
quantization_matrix: ${{ steps.set-matrix.outputs.quantization_matrix }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@ -122,6 +126,10 @@ jobs:
|
||||
--cap-add=sys_nice
|
||||
--shm-size=64G
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@ -191,6 +199,10 @@ jobs:
|
||||
--cap-add=sys_nice
|
||||
--shm-size=64G
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@ -263,6 +275,10 @@ jobs:
|
||||
--cap-add=sys_nice
|
||||
--shm-size=64G
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
21  .github/workflows/self-scheduled.yml (vendored)
@ -78,6 +78,9 @@ jobs:
|
||||
slice_ids: ${{ steps.set-matrix.outputs.slice_ids }}
|
||||
quantization_matrix: ${{ steps.set-matrix-quantization.outputs.quantization_matrix }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
@ -184,6 +187,9 @@ jobs:
|
||||
image: huggingface/transformers-all-latest-gpu
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
@ -256,6 +262,9 @@ jobs:
|
||||
image: huggingface/transformers-all-latest-gpu
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
@ -329,6 +338,9 @@ jobs:
|
||||
image: ${{ inputs.docker }}
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: ${{ inputs.working-directory-prefix }}/transformers
|
||||
env:
|
||||
@ -434,6 +446,9 @@ jobs:
|
||||
image: huggingface/transformers-quantization-latest-gpu
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Echo folder ${{ matrix.folders }}
|
||||
shell: bash
|
||||
env:
|
||||
@ -518,6 +533,9 @@ jobs:
|
||||
image: ${{ inputs.docker }}
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
@ -588,6 +606,9 @@ jobs:
|
||||
steps:
|
||||
# Checkout in order to run `utils/extract_warnings.py`. Avoid **explicit** checkout (i.e. don't specify `ref`) for
|
||||
# security reason.
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
- name: Checkout transformers
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
4  .github/workflows/slack-report.yml (vendored)
@ -38,6 +38,10 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
if: always() && !cancelled()
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Preliminary job status
|
||||
shell: bash
|
||||
# For the meaning of these environment variables, see the job `Setup`
|
||||
|
||||
8  .github/workflows/ssh-runner.yml (vendored)
@ -30,6 +30,10 @@ jobs:
|
||||
outputs:
|
||||
RUNNER: ${{ steps.set_runner.outputs.RUNNER }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Get runner to use
|
||||
shell: bash
|
||||
env:
|
||||
@ -58,6 +62,10 @@ jobs:
|
||||
container:
|
||||
image: ${{ github.event.inputs.docker_image }}
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
env:
|
||||
|
||||
27  .github/workflows/stale.yml (vendored)
@ -14,16 +14,21 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
pip install PyGithub
|
||||
- name: Close stale issues
|
||||
run: |
|
||||
python scripts/stale.py
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
pip install PyGithub
|
||||
- name: Close stale issues
|
||||
run: |
|
||||
python scripts/stale.py
|
||||
|
||||
4  .github/workflows/trufflehog.yml (vendored)
@ -10,6 +10,10 @@ jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
4  .github/workflows/update_metdata.yml (vendored)
@ -14,6 +14,10 @@ jobs:
|
||||
shell: bash -l {0}
|
||||
|
||||
steps:
|
||||
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
|
||||
with:
|
||||
config: ${{ vars.PERMISSIONS_CONFIG }}
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup environment
|
||||
|
||||
1  Makefile
@ -45,6 +45,7 @@ repo-consistency:
|
||||
python utils/check_modular_conversion.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_repo.py
|
||||
python utils/check_init_weights_data.py
|
||||
python utils/check_inits.py
|
||||
python utils/check_pipeline_typing.py
|
||||
python utils/check_config_docstrings.py
|
||||
|
||||
@ -508,16 +508,16 @@ BERT `_init_weights` Methode:
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
You can use additional custom schemes if some modules need a special initialization. For example, in
|
||||
@ -533,9 +533,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
The `_is_hf_initialized` flag is used internally to make sure we only initialize a submodule once. If you set it to
|
||||
|
||||
@@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```

The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.

@@ -339,9 +339,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
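To make the `.data`-free initialization style shown above concrete, here is a minimal, self-contained sketch (toy module sizes, a made-up `INITIALIZER_RANGE` constant standing in for `config.initializer_range`; an illustration of the pattern, not the Transformers implementation):

```python
import torch
from torch import nn

INITIALIZER_RANGE = 0.02  # stand-in for config.initializer_range


def init_weights(module: nn.Module) -> None:
    # Mirrors the in-place, `.data`-free style used in the diff above.
    if isinstance(module, nn.Linear):
        module.weight.normal_(mean=0.0, std=INITIALIZER_RANGE)
        if module.bias is not None:
            module.bias.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.zero_()
        module.weight.fill_(1.0)


model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
with torch.no_grad():  # in-place ops on parameters must run outside autograd tracking
    model.apply(init_weights)
```

In this sketch the `torch.no_grad()` context is what makes the in-place calls on parameters legal without going through `.data`.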
### Convert checkpoints to Transformers
|
||||
|
||||
@@ -159,7 +159,7 @@ conversation3 = [

 conversations = [conversation1, conversation2, conversation3]
 inputs = processor.apply_chat_template(
-    conversations,
+    conversation,
     add_generation_prompt=True,
     tokenize=True,
     return_dict=True,

@@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
 ```python
 class Llama4TextExperts(nn.Module):
     ...
-        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
 ```

 Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
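As a rough illustration of how such a packed parameter can be consumed, here is a self-contained sketch of a batched expert forward built on `torch.bmm` (toy sizes, hypothetical `PackedExperts` module; not the Llama 4 implementation):

```python
import torch
from torch import nn


class PackedExperts(nn.Module):
    # Toy module: every expert's gate and up projections live in one 3D parameter.
    def __init__(self, num_experts: int, hidden_size: int, expert_dim: int):
        super().__init__()
        self.gate_up_proj = nn.Parameter(torch.zeros(num_experts, hidden_size, 2 * expert_dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (num_experts, tokens_per_expert, hidden_size)
        gate_up = torch.bmm(hidden_states, self.gate_up_proj)  # (E, T, 2 * expert_dim)
        gate, up = gate_up.chunk(2, dim=-1)
        return nn.functional.silu(gate) * up  # (E, T, expert_dim)


experts = PackedExperts(num_experts=4, hidden_size=8, expert_dim=16)
print(experts(torch.randn(4, 3, 8)).shape)  # torch.Size([4, 3, 16])
```

A single 3D parameter lets one `bmm` cover all experts at once instead of looping over per-expert `nn.Linear` modules.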
@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
You can add further custom schemes if a particular module needs special initialization. For example,
|
||||
@ -431,9 +431,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
The `_is_hf_initialized` flag is used internally to make sure a submodule is initialized only once.
|
||||
|
||||
@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
You can also use a custom scheme if some modules need special initialization. For example, in `Wav2Vec2ForPreTraining` the last two linear layers must use the standard PyTorch `nn.Linear` initialization, while all other layers should use the initialization above. This is coded as follows:
|
||||
@ -371,9 +371,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
The `_is_hf_initialized` flag is used internally to make sure a submodule is initialized only once. By setting it to `True` for `module.project_q` and `module.project_hid`, we make sure the custom initialization we performed is not overwritten later, i.e. the `_init_weights` function is not applied to them.
|
||||
|
||||
@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
|
||||
|
||||
@ -502,16 +502,10 @@ class DummyBertLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -536,18 +530,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, DummyBertLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -265,7 +265,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
|
||||
if "RMSNorm" in module.__class__.__name__:
|
||||
module.weight.data.zero_()
|
||||
module.weight.zero_()
|
||||
|
||||
|
||||
class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):
|
||||
|
||||
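For context on the comment above: zero-initialization only amounts to an identity scale when the norm applies `(1 + weight)`. A minimal sketch of that style of RMSNorm (a generic illustration, not the model's own class):

```python
import torch
from torch import nn


class CenteredRMSNorm(nn.Module):
    # The weight is stored as an offset from 1, so zero-init means "no rescaling".
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * (1 + self.weight)  # zero weight -> multiply by exactly 1


norm = CenteredRMSNorm(8)
x = torch.randn(2, 8)
assert torch.allclose(norm(x), x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
```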
@ -104,9 +104,9 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
def token_type_ids_mask_function(
|
||||
@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
@ -440,7 +440,15 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
prefix = "model.language_model."
|
||||
prefixed_mapping = {
|
||||
f"{prefix}{target}": f"{prefix}{source}"
|
||||
for target, source in self.language_model._tied_weights_keys.items()
|
||||
}
|
||||
if isinstance(self._tied_weights_keys, dict):
|
||||
self._tied_weights_keys.update(prefixed_mapping)
|
||||
else:
|
||||
self._tied_weights_keys = prefixed_mapping
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
||||
@ -505,16 +505,10 @@ class RobertaLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -539,18 +533,18 @@ class RobertaPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, RobertaLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -846,11 +846,11 @@ class TestDetrPreTrainedModel(PreTrainedModel):
|
||||
nn.init.xavier_uniform_(module.output_proj.weight.data)
|
||||
nn.init.constant_(module.output_proj.bias.data, 0.0)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if hasattr(module, "reference_points") and not self.config.two_stage:
|
||||
|
||||
@@ -19,7 +19,15 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
         self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)

         if self.language_model._tied_weights_keys is not None:
-            self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
+            prefix = "model.language_model."
+            prefixed_mapping = {
+                f"{prefix}{target}": f"{prefix}{source}"
+                for target, source in self.language_model._tied_weights_keys.items()
+            }
+            if isinstance(self._tied_weights_keys, dict):
+                self._tied_weights_keys.update(prefixed_mapping)
+            else:
+                self._tied_weights_keys = prefixed_mapping

         self.post_init()
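To make the re-prefixing above easier to follow, here is a standalone sketch of what the dictionary transformation does to a wrapped child model's tied-weight mapping (toy keys only):

```python
# Child model maps target parameter name -> source parameter name.
child_tied_weights = {"lm_head.weight": "model.embed_tokens.weight"}

# Wrapping the child under `model.language_model` means both sides gain the prefix.
prefix = "model.language_model."
prefixed_mapping = {
    f"{prefix}{target}": f"{prefix}{source}"
    for target, source in child_tied_weights.items()
}
print(prefixed_mapping)
# {'model.language_model.lm_head.weight': 'model.language_model.model.embed_tokens.weight'}
```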
@ -18,7 +18,7 @@ import platform
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import click
|
||||
@ -98,39 +98,55 @@ If you're a new user, check this basic flag guide: https://huggingface.co/docs/t
|
||||
|
||||
|
||||
class RichInterface:
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
user_id: str,
|
||||
token_processors: Optional[list[Callable[[str], str]]] = None,
|
||||
sequence_processor: Optional[list[Callable[[str], str]]] = None,
|
||||
):
|
||||
def __init__(self, model_id: str, user_id: str):
|
||||
self._console = Console()
|
||||
self.model_id = model_id
|
||||
self.user_id = user_id
|
||||
|
||||
token_processors = token_processors or []
|
||||
sequence_processor = sequence_processor or []
|
||||
|
||||
self.token_processors = [self._special_token_processor, *token_processors]
|
||||
self.sequence_processor = [self._code_blocks_sequence_processor, *sequence_processor]
|
||||
|
||||
async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, int]:
|
||||
self._console.print(f"[bold blue]<{self.model_id}>:")
|
||||
with Live(console=self._console, refresh_per_second=4) as live:
|
||||
text = ""
|
||||
async for token in await stream:
|
||||
outputs = token.choices[0].delta.content
|
||||
|
||||
if not outputs:
|
||||
continue
|
||||
|
||||
text += outputs
|
||||
sequence = self._process_sequence(text)
|
||||
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
|
||||
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
|
||||
outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)
|
||||
|
||||
markdown = Markdown(sequence, code_theme="github-dark")
|
||||
text += outputs
|
||||
# Render the accumulated text as Markdown
|
||||
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
||||
# in rich. The chatbots output treat "\n" as a new line for
|
||||
# better compatibility with real-world text. However, rendering
|
||||
# in markdown would break the format. It is because standard markdown
|
||||
# treat a single "\n" in normal text as a space.
|
||||
# Our workaround is adding two spaces at the end of each line.
|
||||
# This is not a perfect solution, as it would
|
||||
# introduce trailing spaces (only) in code block, but it works well
|
||||
# especially for console output, because in general the console does not
|
||||
# care about trailing spaces.
|
||||
|
||||
lines = []
|
||||
for line in text.splitlines():
|
||||
lines.append(line)
|
||||
if line.startswith("```"):
|
||||
# Code block marker - do not add trailing spaces, as it would
|
||||
# break the syntax highlighting
|
||||
lines.append("\n")
|
||||
else:
|
||||
lines.append(" \n")
|
||||
|
||||
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
|
||||
|
||||
# Update the Live console output
|
||||
live.update(markdown, refresh=True)
|
||||
|
||||
self._console.print()
|
||||
|
||||
return text
|
||||
|
||||
def input(self) -> str:
|
||||
@ -164,46 +180,6 @@ class RichInterface:
|
||||
self._console.print(f"[bold blue]{config}")
|
||||
self._console.print()
|
||||
|
||||
def _special_token_processor(self, token):
|
||||
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
|
||||
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
|
||||
return re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", token)
|
||||
|
||||
def _code_blocks_sequence_processor(self, sequence):
|
||||
# Render the accumulated text as Markdown
|
||||
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
||||
# in rich. The chatbots output treat "\n" as a new line for
|
||||
# better compatibility with real-world text. However, rendering
|
||||
# in markdown would break the format. It is because standard markdown
|
||||
# treat a single "\n" in normal text as a space.
|
||||
# Our workaround is adding two spaces at the end of each line.
|
||||
# This is not a perfect solution, as it would
|
||||
# introduce trailing spaces (only) in code block, but it works well
|
||||
# especially for console output, because in general the console does not
|
||||
# care about trailing spaces.
|
||||
|
||||
lines = []
|
||||
for line in sequence.splitlines():
|
||||
lines.append(line)
|
||||
if line.startswith("```"):
|
||||
# Code block marker - do not add trailing spaces, as it would
|
||||
# break the syntax highlighting
|
||||
lines.append("\n")
|
||||
else:
|
||||
lines.append(" \n")
|
||||
|
||||
return "".join(lines).strip()
|
||||
|
||||
def _process_token(self, token):
|
||||
for token_processor in self.token_processors:
|
||||
token = token_processor(token)
|
||||
return token
|
||||
|
||||
def _process_sequence(self, sequence):
|
||||
for sequence_processor in self.sequence_processor:
|
||||
sequence = sequence_processor(sequence)
|
||||
return sequence
|
||||
|
||||
|
||||
class ChatCommand(typer.core.TyperCommand):
|
||||
"""Custom Click command to override missing parameter error message.
|
||||
@ -264,8 +240,6 @@ class Chat:
|
||||
help="Path to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config."
|
||||
),
|
||||
] = None,
|
||||
token_processors: Optional[list[Callable[[str], str]]] = None,
|
||||
sequence_processors: Optional[list[Callable[[str], str]]] = None,
|
||||
) -> None:
|
||||
"""Chat with a model from the command line."""
|
||||
self.base_url = base_url
|
||||
@ -273,9 +247,6 @@ class Chat:
|
||||
self.system_prompt = system_prompt
|
||||
self.save_folder = save_folder
|
||||
|
||||
self.token_processors = token_processors or []
|
||||
self.sequence_processors = sequence_processors or []
|
||||
|
||||
# Generation settings
|
||||
config = load_generation_config(generation_config)
|
||||
config.update(do_sample=True, max_new_tokens=256) # some default values
|
||||
@ -381,12 +352,7 @@ class Chat:
|
||||
# Main logic
|
||||
|
||||
async def _inner_run(self):
|
||||
interface = RichInterface(
|
||||
model_id=self.model_id,
|
||||
user_id=self.user,
|
||||
sequence_processor=self.sequence_processors,
|
||||
token_processors=self.token_processors,
|
||||
)
|
||||
interface = RichInterface(model_id=self.model_id, user_id=self.user)
|
||||
interface.clear()
|
||||
chat = new_chat_history(self.system_prompt)
|
||||
|
||||
|
||||
@ -14,7 +14,9 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import copy
|
||||
import datetime
|
||||
import enum
|
||||
import functools
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
@ -30,7 +32,8 @@ from threading import Thread
|
||||
from typing import Annotated, Optional, TypedDict, Union
|
||||
|
||||
import typer
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from tokenizers.decoders import DecodeStream
|
||||
|
||||
@ -717,45 +720,51 @@ class Serve:
|
||||
"""
|
||||
return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
|
||||
|
||||
@staticmethod
|
||||
def get_gen_models(cache_dir: Optional[str] = None) -> list[dict[str, any]]:
|
||||
@functools.cache
|
||||
def get_gen_models(self) -> list[dict[str, any]]:
|
||||
"""
|
||||
List generative models in the cache.
|
||||
This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
|
||||
model working with generate can work.
|
||||
|
||||
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
|
||||
integrations.
|
||||
"""
|
||||
generative_models = []
|
||||
models = [
|
||||
"Menlo/Jan-nano",
|
||||
"Menlo/Jan-nano-128k",
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"Qwen/Qwen2.5-3B-Instruct",
|
||||
"Qwen/Qwen2.5-7B-Instruct",
|
||||
"Qwen/Qwen2.5-14B-Instruct",
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
"meta-llama/Llama-3.3-70B-Instruct",
|
||||
"HuggingFaceTB/SmolVLM-Instruct",
|
||||
"ibm-granite/granite-vision-3.2-2b",
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
]
|
||||
|
||||
for repo in scan_cache_dir(cache_dir).repos:
|
||||
if repo.repo_type != "model":
|
||||
continue
|
||||
|
||||
refs = repo.refs
|
||||
for ref, revision_info in refs.items():
|
||||
files = revision_info.files
|
||||
config_path = next((f.file_path for f in files if f.file_name == "config.json"), None)
|
||||
|
||||
if not config_path:
|
||||
continue
|
||||
|
||||
config = json.loads(config_path.open().read())
|
||||
|
||||
if "architectures" not in config:
|
||||
continue
|
||||
|
||||
architectures = config["architectures"]
|
||||
if any(arch for arch in architectures if arch in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()):
|
||||
print(repo.repo_id, ref)
|
||||
author = repo.repo_id.split("/") if "/" in repo.repo_id else ""
|
||||
repo_id = repo.repo_id + (f"@{ref}" if ref != "main" else "")
|
||||
generative_models.append(
|
||||
{
|
||||
"owned_by": author,
|
||||
"id": repo_id,
|
||||
"object": "model",
|
||||
"created": repo.last_modified,
|
||||
}
|
||||
)
|
||||
|
||||
return generative_models
|
||||
if HF_HUB_OFFLINE:
|
||||
return [
|
||||
{
|
||||
"id": model,
|
||||
"object": "model",
|
||||
"created": datetime.datetime.now().timestamp(),
|
||||
"owned_by": model.split("/")[0],
|
||||
}
|
||||
for model in models
|
||||
]
|
||||
else:
|
||||
model_infos = [model_info(model) for model in models]
|
||||
return [
|
||||
{
|
||||
"id": model.id,
|
||||
"object": "model",
|
||||
"created": model.created_at.timestamp(),
|
||||
"owned_by": model.author,
|
||||
}
|
||||
for model in model_infos
|
||||
]
|
||||
|
||||
def continuous_batching_chat_completion(self, req: dict, request_id: str) -> StreamingResponse | JSONResponse:
|
||||
"""
|
||||
|
||||
@@ -876,7 +876,7 @@ class PreTrainedConfig(PushToHubMixin):
         if hasattr(self, "quantization_config"):
             serializable_config_dict["quantization_config"] = (
                 self.quantization_config.to_dict()
-                if not isinstance(self.quantization_config, dict)
+                if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
                 else self.quantization_config
             )
         self.dict_dtype_to_str(serializable_config_dict)
@@ -910,7 +910,7 @@ class PreTrainedConfig(PushToHubMixin):
         if hasattr(self, "quantization_config"):
             output["quantization_config"] = (
                 self.quantization_config.to_dict()
-                if not isinstance(self.quantization_config, dict)
+                if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
                 else self.quantization_config
             )
         self.dict_dtype_to_str(output)
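A short sketch of why the added `is not None` check matters in the guard above: `hasattr` alone does not protect against an attribute that is explicitly set to `None` (illustrative stand-in class, not the actual config):

```python
class FakeConfig:
    def __init__(self, quantization_config=None):
        self.quantization_config = quantization_config

    def to_dict(self):
        out = {}
        if hasattr(self, "quantization_config"):
            out["quantization_config"] = (
                self.quantization_config.to_dict()
                if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
                else self.quantization_config
            )
        return out


# Without the `is not None` check this would try to call `None.to_dict()`.
print(FakeConfig().to_dict())  # {'quantization_config': None}
```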
141  src/transformers/conversion_mapping.py (Normal file)
@ -0,0 +1,141 @@
|
||||
# coding=utf-8
|
||||
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
|
||||
from .utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def _build_checkpoint_conversion_mapping():
|
||||
mapping = {
|
||||
"mixtral": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w1.weight",
|
||||
"block_sparse_moe.experts.*.w3.weight",
|
||||
], # you give me a list of 2 keys, I collect a list of a list of tensors
|
||||
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w2.weight",
|
||||
],
|
||||
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
# WeightConverter(
|
||||
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
|
||||
# "self_attn.qkv_proj",
|
||||
# operations=[Concatenate(dim=0)], # more like stack?
|
||||
# ),
|
||||
WeightConverter("*.block_sparse_moe.", "*.mlp."),
|
||||
],
|
||||
"qwen2_moe": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"mlp.experts.*.gate_proj.weight",
|
||||
"mlp.experts.*.up_proj.weight",
|
||||
],
|
||||
target_keys="mlp.experts.gate_up_proj",
|
||||
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=["mlp.experts.*.down_proj.weight"],
|
||||
target_keys="mlp.experts.down_proj",
|
||||
operations=[MergeModulelist(dim=0)],
|
||||
),
|
||||
],
|
||||
"legacy": [
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.gamma",
|
||||
target_keys="LayerNorm.weight",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.beta",
|
||||
target_keys="LayerNorm.bias",
|
||||
),
|
||||
],
|
||||
}
|
||||
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="weight_g",
|
||||
target_keys="parametrizations.weight.original0",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="weight_v",
|
||||
target_keys="parametrizations.weight.original1",
|
||||
),
|
||||
]
|
||||
else:
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original0",
|
||||
target_keys="weight_g",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original1",
|
||||
target_keys="weight_v",
|
||||
),
|
||||
]
|
||||
|
||||
mapping["phimoe"] = mapping["mixtral"].copy()
|
||||
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
|
||||
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
|
||||
mapping["dot1"] = mapping["qwen2_moe"].copy()
|
||||
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["jamba"] = mapping["qwen2_moe"].copy()
|
||||
mapping["lfm2_moe"] = mapping["mixtral"].copy()
|
||||
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["minimax"] = mapping["mixtral"].copy()
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
_checkpoint_conversion_mapping_cache = None
|
||||
|
||||
|
||||
def get_checkpoint_conversion_mapping():
|
||||
global _checkpoint_conversion_mapping_cache
|
||||
if _checkpoint_conversion_mapping_cache is None:
|
||||
_checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
|
||||
globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
|
||||
return _checkpoint_conversion_mapping_cache
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name == "_checkpoint_conversion_mapping":
|
||||
return get_checkpoint_conversion_mapping()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
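To visualize what the merge-then-concatenate recipes above produce for MoE experts, here is a plain-tensor sketch of the general idea (toy shapes, ordinary `torch` calls standing in for the converter operations; the real operations handle device placement, sharding, and the exact target layout):

```python
import torch

num_experts, hidden_size, expert_dim = 4, 8, 16

# Toy per-expert tensors, laid out as (hidden_size, expert_dim) for simplicity.
gate_per_expert = [torch.randn(hidden_size, expert_dim) for _ in range(num_experts)]
up_per_expert = [torch.randn(hidden_size, expert_dim) for _ in range(num_experts)]

# Step 1 (merge the module list): one stacked tensor per projection.
gate = torch.stack(gate_per_expert, dim=0)  # (num_experts, hidden_size, expert_dim)
up = torch.stack(up_per_expert, dim=0)      # (num_experts, hidden_size, expert_dim)

# Step 2 (concatenate): fuse gate and up into a single packed gate_up projection.
gate_up_proj = torch.cat([gate, up], dim=-1)  # (num_experts, hidden_size, 2 * expert_dim)
print(gate_up_proj.shape)  # torch.Size([4, 8, 32])
```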
761  src/transformers/core_model_loading.py (Normal file)
@ -0,0 +1,761 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Core helpers for loading model checkpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping, MutableSet, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer, DTensor, Replicate
|
||||
from .quantizers import HfQuantizer
|
||||
from .utils import is_torch_greater_or_equal, logging
|
||||
from .utils.quantization_config import QuantizationMethod
|
||||
|
||||
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
|
||||
if _is_dtensor_available:
|
||||
from torch.distributed.tensor import DTensor
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
str_to_torch_dtype = {
|
||||
"BOOL": torch.bool,
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I32": torch.int32,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
}
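A minimal sketch of how this table is meant to be used: safetensors headers report dtypes as strings such as `"BF16"`, and the mapping turns them into torch dtypes before any data is materialized (the import path follows the new file added in this diff).

```py
import torch

from transformers.core_model_loading import str_to_torch_dtype

# allocate a meta tensor matching a checkpoint entry declared as BF16
empty = torch.empty(8, dtype=str_to_torch_dtype["BF16"], device="meta")
print(empty.dtype)  # torch.bfloat16
```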
|
||||
|
||||
|
||||
|
||||
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
|
||||
"""
|
||||
    Convert a glob with '*' into a regex *source* string (we do not rely on `glob.translate` here).
    '*' matches (\\d+) if digits_only else (.+).
|
||||
"""
|
||||
star = r"(\d+)" if digits_only else r"(.+)"
|
||||
return re.escape(glob).replace(r"\*", star)
|
||||
|
||||
|
||||
def build_glob_alt(
|
||||
globs: list[str],
|
||||
) -> tuple[re.Pattern, dict[str, str]]:
|
||||
r"""
|
||||
    Build one compiled regex alternation with a named group per glob. This allows running a single
    `re.match` and reading the name of the group that matched to recover the original glob pattern.
|
||||
Returns (compiled_regex, name->glob map).
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"])
|
||||
>>> print(reg)
|
||||
    re.compile('(?P<g0>.*mlp\\.(\\d+)\\.w1)|(?P<g1>.*mlp\\.(\\d+)\\.w2)')
|
||||
>>> print(map_)
|
||||
    {'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'}
|
||||
>>> match_ = reg.match("model.layers.0.mlp.0.w1.weight")
|
||||
>>> print(match_.lastgroup)
|
||||
'g0'
|
||||
>>> print(map_[match_.lastgroup])
|
||||
mlp.*.w1
|
||||
```
|
||||
"""
|
||||
name_map: dict[str, str] = {}
|
||||
parts: list[str] = []
|
||||
prefix_src = r".*"
|
||||
|
||||
for i, g in enumerate(globs):
|
||||
name = f"g{i}"
|
||||
name_map[name] = g
|
||||
pat_src = _glob_to_regex_src(g)
|
||||
parts.append(f"(?P<{name}>{prefix_src}{pat_src})")
|
||||
|
||||
alt_src = "|".join(parts)
|
||||
return re.compile(alt_src), name_map
|
||||
|
||||
|
||||
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Match the key against the alternation; return the original glob string that matched.
|
||||
"""
|
||||
m = alt.match(key)
|
||||
if not m:
|
||||
return None
|
||||
return name_map.get(m.lastgroup)
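A small, self-contained sketch of how these two helpers are meant to be combined, for example to resolve which dtype-plan pattern a concrete checkpoint key falls under (the pattern and key below are illustrative):

```py
from transformers.core_model_loading import build_glob_alt, match_glob

dtype_plan = {"model.layers.*.mlp.gate.weight": "float32"}
alt, name_map = build_glob_alt(list(dtype_plan.keys()))

matched = match_glob("model.layers.3.mlp.gate.weight", alt, name_map)
print(matched)              # model.layers.*.mlp.gate.weight
print(dtype_plan[matched])  # float32
```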
|
||||
|
||||
|
||||
class ConversionOps:
|
||||
"""Base class for weight conversion operations."""
|
||||
|
||||
# The inverse operation class, will be used when saving the checkpoint
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Chunk(ConversionOps):
|
||||
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
|
||||
if chunks is None and sizes is None:
|
||||
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
|
||||
if chunks is not None and chunks <= 0:
|
||||
raise ValueError("`chunks` must be a strictly positive integer.")
|
||||
self.dim = dim
|
||||
self.chunks = chunks
|
||||
self.sizes = list(sizes) if sizes is not None else None
|
||||
self.reverse_op = Concatenate
|
||||
|
||||
def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
|
||||
if not isinstance(value, torch.Tensor):
|
||||
raise TypeError("Chunk expects a torch.Tensor as input.")
|
||||
if self.sizes is not None:
|
||||
return list(torch.split(value, self.sizes, dim=self.dim))
|
||||
return list(torch.chunk(value, self.chunks, dim=self.dim))
|
||||
|
||||
|
||||
class Concatenate(ConversionOps):
|
||||
"""Concatenate tensors along `dim` using a reusable buffer."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
self.dim = dim
|
||||
self.reverse_op = Chunk
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
|
||||
if isinstance(value[0], list):
|
||||
value = [v[0] for v in value]
|
||||
tensors = value
|
||||
if not tensors:
|
||||
raise ValueError("Fuse requires at least one tensor to concatenate.")
|
||||
|
||||
return torch.cat(tuple(tensors), dim=self.dim)
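`Chunk` and `Concatenate` are registered as each other's `reverse_op`; a quick round-trip sketch (the tensor shape is made up):

```py
import torch

from transformers.core_model_loading import Chunk, Concatenate

fused = torch.randn(8, 4)                        # e.g. a fused gate_up projection
gate, up = Chunk(dim=0, chunks=2).convert(fused)
refused = Concatenate(dim=0).convert([gate, up])
assert torch.equal(fused, refused)               # the two ops invert each other
```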
|
||||
|
||||
|
||||
class MergeModulelist(Concatenate):
|
||||
"""
|
||||
Merge a list of tensors into a single tensor along the first dimension.
|
||||
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
super().__init__(dim=dim)
|
||||
self.reverse_op = SplitModulelist
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
|
||||
merged = []
|
||||
for group in value:
|
||||
if not isinstance(group, Sequence) or len(group) == 0:
|
||||
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
|
||||
group = [k for k in group if k.ndim]
|
||||
merged.append(torch.stack(group, dim=self.dim))
|
||||
return merged
|
||||
|
||||
|
||||
class SplitModulelist(ConversionOps):
|
||||
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
|
||||
|
||||
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
|
||||
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
|
||||
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
|
||||
self.sizes = [list(sub) for sub in sizes]
|
||||
self.dim = dim
|
||||
self.reverse_op = MergeModulelist
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
|
||||
if not isinstance(value, Sequence):
|
||||
raise TypeError("SplitModulelist expects a sequence of tensors.")
|
||||
if len(value) != len(self.sizes):
|
||||
raise ValueError("Number of tensors does not match the provided split specifications.")
|
||||
|
||||
result: list[list[torch.Tensor]] = []
|
||||
for tensor, split_sizes in zip(value, self.sizes):
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
|
||||
splits = torch.split(tensor, split_sizes, dim=self.dim)
|
||||
result.append(list(splits))
|
||||
return result
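A sketch of `MergeModulelist` on a fake per-expert weight list, the typical case where an `nn.ModuleList` of experts is fused into one stacked 3D parameter (sizes are illustrative):

```py
import torch

from transformers.core_model_loading import MergeModulelist

# one group holding the down_proj weight of 8 experts
per_expert = [[torch.randn(4, 3) for _ in range(8)]]
(merged,) = MergeModulelist(dim=0).convert(per_expert)
print(merged.shape)  # torch.Size([8, 4, 3])
```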
|
||||
|
||||
|
||||
class PermuteForRope(ConversionOps):
|
||||
"""
|
||||
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
dim1, dim2 = tensor.shape
|
||||
        n_heads = getattr(self.config, "num_attention_heads", 1)
|
||||
|
||||
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
|
||||
return tensor
|
||||
|
||||
@torch.no_grad
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config
|
||||
) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]:
|
||||
self.config = config
|
||||
out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value]
|
||||
return out
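A sketch of the RoPE permutation on a toy projection weight. It assumes the `getattr` call above reads `num_attention_heads` off the config, so a plain namespace is enough as a stand-in config:

```py
import torch
from types import SimpleNamespace

from transformers.core_model_loading import PermuteForRope

config = SimpleNamespace(num_attention_heads=4)
q_proj = torch.randn(32, 16)                  # (num_heads * head_dim, hidden_size)
(permuted,) = PermuteForRope().convert([q_proj], config)
print(permuted.shape)                         # torch.Size([32, 16]): same shape, rows reordered
```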
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WeightConverter:
|
||||
r"""
|
||||
    A weight converter that acts on a pattern of source keys.
|
||||
The keys need to be collected based on the target keys.
|
||||
|
||||
    Wildcards are matched as glob patterns, so be precise about what you want to match. For example,
    `model.layers.*.experts.*` acts on the experts of every layer:
        {"model.layers.*.experts.*": []}
    while `model.layers.1.experts.*` only acts on the experts of layer 1:
        {"model.layers.1.experts.*": []}
|
||||
- source_keys: str | list[str] (wildcards '*' match digits)
|
||||
- target_keys: str | list[str] | None
|
||||
- distributed_operation / operations / quantization_operations are ALWAYS lists.
|
||||
"""
|
||||
|
||||
source_keys: Union[str, list[str]]
|
||||
target_keys: Optional[Union[str, list[str]]] = None
|
||||
operations: list[ConversionOps] = field(default_factory=list, repr=False)
|
||||
|
||||
distributed_operation: Optional[TensorParallelLayer] = None
|
||||
quantization_operation: Optional[ConversionOps] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.source_keys, list):
|
||||
self.source_keys = [self.source_keys]
|
||||
targets_were_none = False
|
||||
if not isinstance(self.target_keys, list):
|
||||
if self.target_keys is None:
|
||||
self.target_keys = list(self.source_keys)
|
||||
targets_were_none = True
|
||||
else:
|
||||
self.target_keys = [self.target_keys]
|
||||
|
||||
if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2:
|
||||
raise ValueError(
|
||||
f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one."
|
||||
)
|
||||
|
||||
for pattern in self.source_keys:
|
||||
if any(ch in pattern for ch in set("^$+?{}[]|()")):
|
||||
raise AssertionError(f"'{pattern}' is not glob")
|
||||
for pattern in self.target_keys:
|
||||
if any(ch in pattern for ch in set("^$+?{}[]|()")):
|
||||
raise AssertionError(f"'{pattern}' is not glob")
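A sketch of how conversion entries might be declared with this dataclass. The glob patterns below are illustrative, not taken from an actual model's conversion mapping:

```py
from transformers.core_model_loading import Chunk, MergeModulelist, WeightConverter

converters = [
    # many-to-one: stack every expert's down_proj into a single 3D tensor
    WeightConverter(
        source_keys="model.layers.*.experts.*.down_proj.weight",
        target_keys="model.layers.*.mlp.experts.down_proj",
        operations=[MergeModulelist(dim=0)],
    ),
    # one-to-many: split a fused gate_up projection into two target keys
    WeightConverter(
        source_keys="model.layers.*.mlp.gate_up_proj.weight",
        target_keys=[
            "model.layers.*.mlp.gate_proj.weight",
            "model.layers.*.mlp.up_proj.weight",
        ],
        operations=[Chunk(dim=0, chunks=2)],
    ),
]
```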
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConversionEntry:
|
||||
weight_converter: WeightConverter
|
||||
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
|
||||
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
|
||||
|
||||
# Factory function to create LoadedParameter subclasses dynamically
|
||||
def get_loaded_parameter_class(base_cls):
|
||||
"""
|
||||
    base_cls: an nn.Parameter subclass (or nn.Parameter) or a Tensor
    Returns a new class that wraps `base_cls` and silently skips any in-place mutation of the loaded value.
|
||||
|
||||
"""
|
||||
class LoadedParam(base_cls):
|
||||
_inplace_methods = [
|
||||
'add_', 'mul_', 'clamp_', 'zero_', 'fill_', 'normal_', 'uniform_',
|
||||
'copy_', 'erfinv_', 'log_', "__getitem__", "neg_", "exp_", "sub_"
|
||||
]
|
||||
def __new__(cls, from_existing, **kwargs):
|
||||
if isinstance(from_existing, torch.nn.Parameter):
|
||||
inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
|
||||
else:
|
||||
inst = super().__new__(cls, from_existing)
|
||||
inst._original_type = from_existing
|
||||
# Explicitly override all in-place methods per instance
|
||||
for method_name in inst._inplace_methods:
|
||||
setattr(inst, method_name, MethodType(inst._skip, inst))
|
||||
|
||||
return inst
|
||||
|
||||
def _skip(self, *args, **kwargs):
|
||||
"""Helper to skip in-place operations."""
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return f"LoadedParameter(data={self.data})"
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return super().data
|
||||
|
||||
@data.setter
|
||||
def data(self, new):
|
||||
pass
|
||||
|
||||
def __lt__(self, other): return torch.Tensor.__lt__(self, other)
|
||||
def __le__(self, other): return torch.Tensor.__le__(self, other)
|
||||
def __gt__(self, other): return torch.Tensor.__gt__(self, other)
|
||||
def __ge__(self, other): return torch.Tensor.__ge__(self, other)
|
||||
def __eq__(self, other): return torch.Tensor.__eq__(self, other)
|
||||
def __ne__(self, other): return torch.Tensor.__ne__(self, other)
|
||||
def __iadd__(self, *args, **kwargs): return self
|
||||
def __isub__(self, *args, **kwargs): return self
|
||||
def __imul__(self, *args, **kwargs): return self
|
||||
def __imatmul__(self, *args, **kwargs): return self
|
||||
def __itruediv__(self, *args, **kwargs): return self
|
||||
def __ifloordiv__(self, *args, **kwargs): return self
|
||||
def __imod__(self, *args, **kwargs): return self
|
||||
def __ipow__(self, *args, **kwargs): return self
|
||||
def __iand__(self, *args, **kwargs): return self
|
||||
def __ior__(self, *args, **kwargs): return self
|
||||
def __ixor__(self, *args, **kwargs): return self
|
||||
def __ilshift__(self, *args, **kwargs): return self
|
||||
def __irshift__(self, *args, **kwargs): return self
|
||||
|
||||
return LoadedParam
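A sketch of what the dynamically created class protects against: once a parameter has been loaded, in-place mutation (for instance from `_init_weights`) becomes a no-op.

```py
import torch

from transformers.core_model_loading import get_loaded_parameter_class

param = torch.nn.Parameter(torch.ones(2))
loaded = get_loaded_parameter_class(param.__class__)(from_existing=param)

loaded.zero_()   # silently skipped: the freshly loaded value must not be overwritten
print(loaded)    # the data is still all ones
```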
|
||||
|
||||
def _materialize_copy(tensor, dtype=None):
|
||||
tensor = tensor[...]
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
|
||||
def _job():
|
||||
return _materialize_copy(tensor, dtype)
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def spawn_tp_materialize(thread_pool, tensor, dtype, sharding_method, tensor_idx) -> Future:
    def _job():
        return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
|
||||
|
||||
return thread_pool.submit(_job)
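A sketch of the non-TP path: materialization is just an asynchronous copy (and optional cast) scheduled on the shared thread pool.

```py
import torch
from concurrent.futures import ThreadPoolExecutor

from transformers.core_model_loading import spawn_materialize

pool = ThreadPoolExecutor(max_workers=2)
future = spawn_materialize(pool, torch.ones(4), dtype=torch.bfloat16)
print(future.result().dtype)  # torch.bfloat16
pool.shutdown()
```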
|
||||
|
||||
|
||||
def dot_natural_key(s: str):
|
||||
parts = s.split(".")
|
||||
for i, p in enumerate(parts):
|
||||
# whole-segment digits -> int; otherwise leave as str
|
||||
if p.isdigit():
|
||||
parts[i] = int(p)
|
||||
return parts
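The natural key keeps `layers.2` ahead of `layers.10`, which plain lexicographic sorting would not:

```py
from transformers.core_model_loading import dot_natural_key

keys = ["model.layers.10.weight", "model.layers.2.weight", "model.layers.1.weight"]
print(sorted(keys, key=dot_natural_key))
# ['model.layers.1.weight', 'model.layers.2.weight', 'model.layers.10.weight']
```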
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_to_misc(
|
||||
layer_name: str,
|
||||
misc: MutableMapping[str, str],
|
||||
extras: Any = None,
|
||||
op: Union[list[ConversionOps], ConversionOps, None] = None,
|
||||
):
|
||||
# A simple helper to handle errors with contextual messages.
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
|
||||
def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
|
||||
if curr_op is None:
|
||||
return None
|
||||
if isinstance(curr_op, (list, tuple, set)):
|
||||
names = [o.__class__.__name__ for o in curr_op if o is not None]
|
||||
if not names:
|
||||
return None
|
||||
return ", ".join(names)
|
||||
return curr_op.__class__.__name__
|
||||
|
||||
op_name = _format_op_name(op)
|
||||
if isinstance(extras, tuple) and len(extras) == 2:
|
||||
values, target_keys = extras
|
||||
descriptor = f"{op_name} " if op_name else ""
|
||||
misc[layer_name] = (
|
||||
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
|
||||
)
|
||||
elif isinstance(extras, str):
|
||||
suffix = f" via {op_name}" if op_name else ""
|
||||
misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
|
||||
elif extras is None and op_name:
|
||||
misc[layer_name] = f"{op_name}: {e}"
|
||||
else:
|
||||
misc[layer_name] = f"{extras} |Error: {e}"
|
||||
raise SkipLayer()
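A sketch of the error-collection behavior: the context manager records a contextual message under the layer name and converts the exception into a `SkipLayer` (layer and parameter names are made up):

```py
from transformers.core_model_loading import SkipLayer, log_to_misc

misc = {}
try:
    with log_to_misc("model.layers.0.mlp", misc, extras="gate_proj.weight"):
        raise RuntimeError("boom")
except SkipLayer:
    pass
print(misc)
# {'model.layers.0.mlp': 'boom\nError when processing parameter gate_proj.weight'}
```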
|
||||
|
||||
|
||||
def set_param_for_module(
|
||||
model: torch.nn.Module,
|
||||
layer_name: str,
|
||||
param_value: torch.Tensor,
|
||||
meta_model_state_dict: MutableMapping[str, Any],
|
||||
empty_param: torch.Tensor,
|
||||
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
|
||||
missing_keys: MutableSet[str],
|
||||
misc: MutableMapping[str, Any],
|
||||
distributed_operation: Optional[TensorParallelLayer],
|
||||
hf_quantizer,
|
||||
):
|
||||
with log_to_misc(layer_name, misc, layer_name):
|
||||
module_path, _, param_name = layer_name.rpartition(".")
|
||||
module_obj = model.get_submodule(module_path) if module_path else model
|
||||
if isinstance(param_value, list):
|
||||
param_value = param_value[0]
|
||||
elif not isinstance(param_value, torch.nn.Parameter):
|
||||
param_value = param_value[...]
|
||||
ref = meta_model_state_dict.get(layer_name, empty_param)
|
||||
|
||||
|
||||
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
|
||||
if not isinstance(param_value, torch.nn.Parameter):
|
||||
if distributed_operation is not None:
|
||||
param_value = DTensor.from_local(
|
||||
param_value,
|
||||
distributed_operation.device_mesh,
|
||||
getattr(distributed_operation, "shard", Replicate()),
|
||||
run_check=False,
|
||||
shape=ref.size(),
|
||||
stride=ref.stride(),
|
||||
)
|
||||
if not use_dtensor:
|
||||
# we convert to local
|
||||
param_value = param_value.to_local()
|
||||
|
||||
if param_name not in module_obj._buffers:
|
||||
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
|
||||
|
||||
# to skip any inplace method that modifies the param data
|
||||
param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)
|
||||
|
||||
# skip mismatch for hf_quantizer for now
|
||||
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
|
||||
mismatch_keys.add((layer_name, param_value.shape, ref.shape))
|
||||
setattr(module_obj._parameters[param_name], "_is_hf_initialized", False) # Needs to be initialized
|
||||
missing_keys.discard(layer_name)
|
||||
else:
|
||||
missing_keys.discard(layer_name)
|
||||
            param_value._is_hf_initialized = True  # important: otherwise `_init_weights` would re-initialize it (e.g. when a bias is missing)
|
||||
setattr(module_obj, param_name, param_value)
|
||||
|
||||
|
||||
class SkipLayer(Exception):
|
||||
"""Control-flow sentinel: abort processing of the current layer only."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def convert_and_load_state_dict_in_model(
|
||||
model,
|
||||
state_dict,
|
||||
weight_mapping,
|
||||
tp_plan,
|
||||
hf_quantizer,
|
||||
dtype=None,
|
||||
device_map=None,
|
||||
dtype_plan=None,
|
||||
device_mesh=None,
|
||||
loading_task_model_from_base_state_dict: bool = False,
|
||||
loading_base_model_from_task_state_dict: bool = False,
|
||||
):
|
||||
"""
|
||||
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
|
||||
collecting tensors per *layer instance* (the concrete indices captured from '*').
|
||||
"""
|
||||
|
||||
prefix = model.base_model_prefix
|
||||
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
|
||||
device_map = device_map or {} # {exact_target_key: device}
|
||||
dtype_plan = dtype_plan or {} # {glob_pattern: dtype}
|
||||
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
|
||||
meta_model_state_dict = model.state_dict()
|
||||
missing_keys = set(meta_model_state_dict.keys())
|
||||
|
||||
misc = {}
|
||||
mismatch_keys = set()
|
||||
unexpected_keys = set()
|
||||
# Global thread_pool
|
||||
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
|
||||
|
||||
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
|
||||
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
|
||||
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
|
||||
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
|
||||
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
|
||||
|
||||
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
|
||||
# 1. Create the conversion entries
|
||||
by_conversion_pattern: dict[str, ConversionEntry] = {}
|
||||
for original_key, tensor in state_dict:
|
||||
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
|
||||
if matched_pattern is not None:
|
||||
converter = source_to_target[matched_pattern] # TODO make sure its the ref
|
||||
sub_with_extractor = partial(re.sub, _glob_to_regex_src(matched_pattern), string=original_key)
|
||||
entry_key = "|".join(converter.target_keys)
|
||||
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
|
||||
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
|
||||
converter_key = sub_with_extractor(matched_pattern)
|
||||
else:
|
||||
converter = WeightConverter(original_key)
|
||||
converter_key = entry_key = target_key = original_key
|
||||
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
|
||||
|
||||
_dtype = dtype
|
||||
new_target_key = [] # test_load_with_mismatched_shapes for AutoModel.from_pretrained(AutoForCausal, vocab=10)
|
||||
for t in target_key.split("|"):
|
||||
if t.startswith(prefix) and meta_model_state_dict.get(t.replace(f"{prefix}.", "")) is not None:
|
||||
t = t.replace(f"{prefix}.", "")
|
||||
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
|
||||
t = f"{prefix}.{t}"
|
||||
new_target_key.append(t)
|
||||
empty_param = meta_model_state_dict.get(t)
|
||||
# If it does not exist, it's unexpected
|
||||
if empty_param is None:
|
||||
if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t):
|
||||
pass
|
||||
else:
|
||||
unexpected_keys.add(t)
|
||||
continue
|
||||
|
||||
if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
|
||||
converter.quantization_operation = hf_quantizer.get_quantize_ops()
|
||||
# TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype
|
||||
k_dtype = tensor.get_dtype()
|
||||
dtype = str_to_torch_dtype[k_dtype]
|
||||
empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta")
|
||||
_, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer)
|
||||
else:
|
||||
_dtype = dtype
|
||||
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
|
||||
if matched_dtype_pattern is not None:
|
||||
_dtype = dtype_plan[matched_dtype_pattern]
|
||||
elif empty_param.dtype != _dtype:
|
||||
_dtype = empty_param.dtype
|
||||
|
||||
first_target_key = new_target_key[0]
|
||||
target_key = "|".join(new_target_key)
|
||||
|
||||
future = None
|
||||
if device_mesh:
|
||||
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
|
||||
empty_param = meta_model_state_dict.get(first_target_key)
|
||||
if getattr(converter, "distributed_operation", {}) is None:
|
||||
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
|
||||
converter.distributed_operation = tp_layer(
|
||||
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
|
||||
)
|
||||
                # VERY IMPORTANT: this tells us whether we have already collected shards for this key or not.
|
||||
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
|
||||
future = spawn_tp_materialize(
|
||||
thread_pool,
|
||||
tensor,
|
||||
_dtype,
|
||||
converter.distributed_operation,
|
||||
shard_index,
|
||||
)
|
||||
|
||||
if future is None: # If not TP, async materialize the tensors. TODO handle disk offload?
|
||||
future = spawn_materialize(thread_pool, tensor, _dtype)
|
||||
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
|
||||
|
||||
# 2. Actually convert the ckpt
|
||||
inverse_converters = {}
|
||||
keys = list(by_conversion_pattern.keys())
|
||||
|
||||
with logging.tqdm(total=len(keys), desc="Loading weights") as pbar:
|
||||
for key in keys[::-1]: # revert to process simple keys first
|
||||
group = by_conversion_pattern.pop(key)
|
||||
converter = group.weight_converter
|
||||
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
|
||||
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
|
||||
concrete_target_keys = layer_name.split("|")
|
||||
try:
|
||||
if bool(set(concrete_target_keys) - unexpected_keys):
|
||||
with log_to_misc(layer_name, misc):
|
||||
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
|
||||
|
||||
for op in operations:
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
values = op.convert(values, model.config)
|
||||
|
||||
values = [values] if not isinstance(values, list) else values
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
realized_value = {
|
||||
k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
|
||||
}
|
||||
|
||||
for k in list(realized_value.keys()).copy():
|
||||
if op := converter.quantization_operation:
|
||||
with log_to_misc(layer_name, misc, op=op):
|
||||
realized_value.update(
|
||||
op.convert({k: realized_value.pop(k)}, model=model, missing_keys=missing_keys)
|
||||
)
|
||||
|
||||
for k, output_value in realized_value.items():
|
||||
for src in converter.source_keys: # what should happen to k when we meet k at saving
|
||||
inverse_converters[k] = {src: converter}
|
||||
set_param_for_module(
|
||||
model,
|
||||
k,
|
||||
output_value,
|
||||
meta_model_state_dict,
|
||||
empty_param,
|
||||
mismatch_keys,
|
||||
missing_keys,
|
||||
misc,
|
||||
converter.distributed_operation,
|
||||
hf_quantizer
|
||||
)
|
||||
                except SkipLayer:
                    continue
|
||||
del group
|
||||
|
||||
# Update progress bar
|
||||
pbar.update()
|
||||
pbar.refresh()
|
||||
|
||||
model.inverse_converters = inverse_converters
|
||||
thread_pool.shutdown(wait=False)
|
||||
return missing_keys, unexpected_keys, mismatch_keys, misc
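A rough end-to-end sketch of calling this entry point directly. In practice `from_pretrained` drives it; here a tiny meta-initialized model is fed its own state dict, with no conversion mapping, no TP plan and no quantizer (the config values are arbitrary):

```py
import torch

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.core_model_loading import convert_and_load_state_dict_in_model

config = AutoConfig.for_model(
    "llama", hidden_size=16, intermediate_size=32, num_hidden_layers=1,
    num_attention_heads=2, num_key_value_heads=2, vocab_size=64,
)
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

# pretend checkpoint: real tensors with the same shapes as the meta parameters
state_dict = {k: torch.zeros(v.shape) for k, v in model.state_dict().items()}

missing, unexpected, mismatched, misc = convert_and_load_state_dict_in_model(
    model, state_dict, weight_mapping=None, tp_plan=None, hf_quantizer=None
)
print(len(missing), len(unexpected), len(mismatched))  # expected: 0 0 0
```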
|
||||
|
||||
|
||||
# TODO this is not done yet!
|
||||
def revert_weight_conversion(model, state_dict):
|
||||
mapping = getattr(model, "", {}) # IDK why but setting this will fail all llava.
|
||||
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
|
||||
original_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
for pattern, inverse_converter in reverse_key_mapping:
|
||||
# TODO FIXME you name it
|
||||
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
|
||||
replacement = re.sub(r"\(.*\)", "", replacement)
|
||||
key, n_replace = re.subn(pattern, replacement, key)
|
||||
# Early exit of the loop
|
||||
if n_replace > 0:
|
||||
break
|
||||
original_state_dict[key] = value
|
||||
state_dict = original_state_dict
|
||||
return state_dict
|
||||
|
||||
def _infer_parameter_dtype(
|
||||
model: torch.nn.Module,
|
||||
param_name: str,
|
||||
empty_param: torch.Tensor,
|
||||
hf_quantizer: Optional[HfQuantizer] = None,
|
||||
) -> tuple[bool, Optional[torch.dtype]]:
|
||||
try:
|
||||
old_param = model.get_parameter_or_buffer(param_name)
|
||||
except Exception as e:
|
||||
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.QUARK,
|
||||
QuantizationMethod.MXFP4,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
}:
|
||||
return True, None
|
||||
else:
|
||||
raise e
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||
# in int/uint/bool and not cast them.
|
||||
casting_dtype = None
|
||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
||||
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||
# dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
|
||||
if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, param_name):
|
||||
casting_dtype = model.config._pre_quantization_dtype
|
||||
else:
|
||||
casting_dtype = old_param.dtype
|
||||
return old_param is not None and old_param.is_contiguous(), casting_dtype
|
||||
@ -19,7 +19,6 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import is_torch_xpu_available
|
||||
from ...utils.logging import logging
|
||||
from ...utils.metrics import traced
|
||||
|
||||
@ -36,13 +35,6 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
|
||||
total_memory = torch.cuda.get_device_properties(device).total_memory
|
||||
reserved_memory = torch.cuda.memory_reserved(device)
|
||||
allocated_memory = torch.cuda.memory_allocated(device)
|
||||
elif is_torch_xpu_available():
|
||||
device = torch.device("xpu")
|
||||
torch.xpu.empty_cache()
|
||||
torch.xpu.synchronize()
|
||||
total_memory = torch.xpu.get_device_properties(device).total_memory
|
||||
reserved_memory = torch.xpu.memory_reserved(device)
|
||||
allocated_memory = torch.xpu.memory_allocated(device)
|
||||
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
device = torch.device("mps")
|
||||
# MPS memory reporting (PyTorch 2.0+)
|
||||
|
||||
@ -1635,7 +1635,12 @@ class GenerationMixin(ContinuousMixin):
|
||||
|
||||
# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
|
||||
if (
|
||||
value is not None
|
||||
and key not in model_args
|
||||
and key not in TransformersKwargs.__optional_keys__
|
||||
and key != "debug_io"
|
||||
):
|
||||
unused_model_args.append(key)
|
||||
|
||||
if unused_model_args:
|
||||
|
||||
@ -383,10 +383,11 @@ class BayesianDetectorModel(PreTrainedModel):
|
||||
)
|
||||
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Parameter):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
module.weight.normal_(mean=0.0, std=0.02)
|
||||
|
||||
def _compute_posterior(
|
||||
self,
|
||||
|
||||
@ -821,26 +821,14 @@ def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_widt
|
||||
return image
|
||||
|
||||
|
||||
def _cast_tensor_to_float(x):
|
||||
if x.is_floating_point():
|
||||
return x
|
||||
return x.float()
|
||||
|
||||
|
||||
def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
|
||||
"""
|
||||
Helper function to flatten a single level of nested image and batch structures and group by shape.
|
||||
Args:
|
||||
nested_images (list):
|
||||
A list of images or a single tensor
|
||||
paired_inputs (Any, *optional*):
|
||||
Zero or more lists that mirror the structure of `nested_images` (flat list, or list of lists when
|
||||
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
|
||||
same shape key. These paired values are grouped alongside `nested_images` but are not stacked in the output, so
|
||||
they do not need to be tensors.
|
||||
is_nested (bool, *optional*, defaults to False):
|
||||
Whether the images are nested.
|
||||
Returns:
|
||||
tuple[dict, ...]:
|
||||
- A dictionary with shape as key and list of images with that shape as value
|
||||
- A dictionary with shape as key and list of paired values with that shape as value
|
||||
- A dictionary mapping original indices to (shape, index) tuples
|
||||
- A dictionary mapping original indices to (shape, index) tuples for each paired input
|
||||
"""
|
||||
"""Helper function to flatten a single level of nested image and batch structures and group by shape."""
|
||||
grouped_images = defaultdict(list)
|
||||
grouped_images_index = {}
|
||||
paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
|
||||
@ -892,20 +880,27 @@ def _reconstruct_nested_structure(indices, processed_images):
|
||||
return result
|
||||
|
||||
|
||||
def _iterate_items(items, is_nested: bool):
|
||||
"""
|
||||
Helper function to iterate over items yielding (key, item) pairs.
|
||||
def _disable_grouping_output_nested(images, *paired_inputs):
|
||||
"""Build the disable_grouping output tuple for a single-level nested structure."""
|
||||
outer_range = range(len(images))
|
||||
inner_ranges = [range(len(images[i])) for i in outer_range]
|
||||
|
||||
For nested structures, yields ((row_index, col_index), item).
|
||||
For flat structures, yields (index, item).
|
||||
"""
|
||||
if is_nested:
|
||||
for i, row in enumerate(items):
|
||||
for j, item in enumerate(row):
|
||||
yield (i, j), item
|
||||
else:
|
||||
for i, item in enumerate(items):
|
||||
yield i, item
|
||||
# Precompute all (i, j) pairs
|
||||
ij_pairs = [(i, j) for i in outer_range for j in inner_ranges[i]]
|
||||
|
||||
images_dict = {(i, j): images[i][j].unsqueeze(0) for (i, j) in ij_pairs}
|
||||
paired_dicts = [{(i, j): paired_list[i][j].unsqueeze(0) for (i, j) in ij_pairs} for paired_list in paired_inputs]
|
||||
index_map = {(i, j): ((i, j), 0) for (i, j) in ij_pairs}
|
||||
return images_dict, *paired_dicts, index_map
|
||||
|
||||
|
||||
def _disable_grouping_output_flat(images, *paired_inputs):
|
||||
"""Build the disable_grouping output tuple for a flat list structure."""
|
||||
idx_range = range(len(images))
|
||||
images_dict = {i: images[i].unsqueeze(0) for i in idx_range}
|
||||
paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in idx_range} for paired_list in paired_inputs]
|
||||
index_map = {i: (i, 0) for i in idx_range}
|
||||
return images_dict, *paired_dicts, index_map
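A sketch of the flat helper on two differently sized images: when grouping is disabled, each image becomes its own singleton batch (the import path is an assumption based on the surrounding hunks):

```py
import torch

from transformers.image_processing_utils_fast import _disable_grouping_output_flat  # path assumed

images = [torch.zeros(3, 2, 2), torch.zeros(3, 4, 4)]
images_dict, index_map = _disable_grouping_output_flat(images)
print(images_dict[0].shape, images_dict[1].shape)  # torch.Size([1, 3, 2, 2]) torch.Size([1, 3, 4, 4])
print(index_map)                                   # {0: (0, 0), 1: (1, 0)}
```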
|
||||
|
||||
|
||||
def group_images_by_shape(
|
||||
@ -925,7 +920,7 @@ def group_images_by_shape(
|
||||
Args:
|
||||
images (Union[list["torch.Tensor"], "torch.Tensor"]):
|
||||
A list of images or a single tensor
|
||||
paired_inputs (Any, *optional*):
|
||||
*paired_inputs (Any):
|
||||
Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
|
||||
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
|
||||
same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
|
||||
@ -949,14 +944,10 @@ def group_images_by_shape(
|
||||
disable_grouping = device == "cpu"
|
||||
|
||||
if disable_grouping:
|
||||
return (
|
||||
{key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)},
|
||||
*[
|
||||
{key: item.unsqueeze(0) for key, item in _iterate_items(paired_list, is_nested)}
|
||||
for paired_list in paired_inputs
|
||||
],
|
||||
{key: (key, 0) for key, _ in _iterate_items(images, is_nested)},
|
||||
)
|
||||
if is_nested:
|
||||
return _disable_grouping_output_nested(images, *paired_inputs)
|
||||
else:
|
||||
return _disable_grouping_output_flat(images, *paired_inputs)
|
||||
|
||||
# Handle single level nested structure
|
||||
grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
|
||||
@ -999,3 +990,14 @@ def reorder_images(
|
||||
]
|
||||
|
||||
return _reconstruct_nested_structure(grouped_images_index, processed_images)
|
||||
|
||||
|
||||
class NumpyToTensor:
|
||||
"""
|
||||
Convert a numpy array to a PyTorch tensor.
|
||||
"""
|
||||
|
||||
def __call__(self, image: np.ndarray):
|
||||
# Same as in PyTorch, we assume incoming numpy images are in HWC format
|
||||
# c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
|
||||
return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()
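A sketch of the HWC to CHW conversion this class performs (the import path is an assumption based on the surrounding hunks):

```py
import numpy as np

from transformers.image_processing_utils_fast import NumpyToTensor  # path assumed

image = np.zeros((224, 224, 3), dtype=np.uint8)   # HWC, e.g. a decoded JPEG
tensor = NumpyToTensor()(image)
print(tensor.shape)  # torch.Size([3, 224, 224])
```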
|
||||
|
||||
@ -36,6 +36,7 @@ _import_structure = {
|
||||
"get_keys_to_not_convert",
|
||||
"replace_with_bnb_linear",
|
||||
"validate_bnb_backend_availability",
|
||||
"Bnb4bitQuantize",
|
||||
],
|
||||
"deepspeed": [
|
||||
"HfDeepSpeedConfig",
|
||||
@ -177,6 +178,7 @@ if TYPE_CHECKING:
|
||||
unpack_weights,
|
||||
)
|
||||
from .bitsandbytes import (
|
||||
Bnb4bitQuantize,
|
||||
dequantize_and_replace,
|
||||
get_keys_to_not_convert,
|
||||
replace_with_bnb_linear,
|
||||
|
||||
@ -435,6 +435,7 @@ def _get_device_map(
|
||||
if max_memory is not None and device_name in max_memory:
|
||||
inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])
|
||||
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(
|
||||
model,
|
||||
max_memory=inferred_max_memory,
|
||||
@ -512,10 +513,8 @@ def accelerate_disk_offload(
|
||||
checkpoint_files,
|
||||
device_map,
|
||||
checkpoint_keys,
|
||||
key_renaming_mapping,
|
||||
sharded_metadata,
|
||||
dtype,
|
||||
reverse_key_renaming_mapping,
|
||||
):
|
||||
disk_only_shard_files = []
|
||||
if disk_offload_folder is not None:
|
||||
@ -534,19 +533,13 @@ def accelerate_disk_offload(
|
||||
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
|
||||
else:
|
||||
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
|
||||
# Fix the weight map keys according to the key mapping
|
||||
weight_map = {
|
||||
key_renaming_mapping[k]: v
|
||||
for k, v in sharded_metadata["weight_map"].items()
|
||||
if k in key_renaming_mapping
|
||||
}
|
||||
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
|
||||
# Find potential checkpoints containing only offloaded weights
|
||||
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
|
||||
disk_offload_index = {
|
||||
name: {
|
||||
"safetensors_file": file,
|
||||
"weight_name": reverse_key_renaming_mapping[name],
|
||||
"weight_name": name,
|
||||
"dtype": str_dtype,
|
||||
}
|
||||
for name, file in weight_map.items()
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict
|
||||
from inspect import signature
|
||||
from typing import Optional
|
||||
|
||||
from ..quantizers.quantizers_utils import get_module_from_name
|
||||
from ..utils import (
|
||||
get_available_devices,
|
||||
is_accelerate_available,
|
||||
@ -24,10 +26,52 @@ if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||
from accelerate.utils import find_tied_parameters
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
from ..core_model_loading import ConversionOps
|
||||
|
||||
|
||||
class Bnb4bitQuantize(ConversionOps):
|
||||
def __init__(self, hf_quantizer):
|
||||
self.hf_quantizer = hf_quantizer
|
||||
|
||||
    def convert(self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
|
||||
target_key, value = tuple(input_dict.items())[0]
|
||||
value = value[0] if isinstance(value, list) else value
|
||||
|
||||
full_name = target_key
|
||||
# update param name to get the weights instead of the quantized stats
|
||||
target_key = self.hf_quantizer.get_param_name(target_key)
|
||||
module, _ = get_module_from_name(model, target_key)
|
||||
|
||||
if not self.hf_quantizer.pre_quantized:
|
||||
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
|
||||
# Since weights are saved in the correct "orientation", we skip transposing when loading.
|
||||
if issubclass(module.source_cls, Conv1D):
|
||||
value = value.T
|
||||
old_value = model.get_parameter_or_buffer(target_key)
|
||||
new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device)
|
||||
return {target_key : new_value}
|
||||
else:
|
||||
module_name = target_key.rsplit(".", 1)[0]
|
||||
# Save the states for later quantization when they are all gathered
|
||||
if not hasattr(self.hf_quantizer, "param_quant_stats"):
|
||||
self.hf_quantizer.param_quant_stats = defaultdict(dict)
|
||||
self.hf_quantizer.param_quant_stats[module_name].update({full_name: value})
|
||||
# We are ready for quantization in this case (note, the +1 is for the weight itself)
|
||||
if len(self.hf_quantizer.param_quant_stats[module_name]) == len(self.hf_quantizer.bnb_keys) + 1:
|
||||
weight = self.hf_quantizer.param_quant_stats[module_name].pop(f"{module_name}.weight")
|
||||
new_value = bnb.nn.Params4bit.from_prequantized(
|
||||
data=weight,
|
||||
quantized_stats=self.hf_quantizer.param_quant_stats[module_name],
|
||||
requires_grad=False,
|
||||
device=value.device,
|
||||
module=module
|
||||
)
|
||||
del self.hf_quantizer.param_quant_stats[module_name]
|
||||
return {target_key : new_value}
|
||||
return {}
|
||||
|
||||
def _replace_with_bnb_linear(
|
||||
model,
|
||||
@ -151,52 +195,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
|
||||
return model
|
||||
|
||||
|
||||
def get_keys_to_not_convert(model):
|
||||
r"""
|
||||
    A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM modules
|
||||
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
|
||||
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
|
||||
int8.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model
|
||||
"""
|
||||
# Create a copy of the model and tie the weights, then
|
||||
# check if it contains tied weights
|
||||
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
|
||||
tied_model.tie_weights()
|
||||
|
||||
tied_params = find_tied_parameters(tied_model)
|
||||
tied_keys = sum(tied_params, [])
|
||||
has_tied_params = len(tied_keys) > 0
|
||||
|
||||
# If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision
|
||||
if not has_tied_params:
|
||||
output_emb = model.get_output_embeddings()
|
||||
if output_emb is not None:
|
||||
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
|
||||
return list_last_module
|
||||
|
||||
# otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
|
||||
list_modules = list(model.named_parameters())
|
||||
list_last_module = [list_modules[-1][0]]
|
||||
# add last module together with tied weights
|
||||
intersection = set(list_last_module) - set(tied_keys)
|
||||
list_untouched = list(set(tied_keys)) + list(intersection)
|
||||
|
||||
# remove ".weight" from the keys
|
||||
names_to_remove = [".weight", ".bias"]
|
||||
filtered_module_names = []
|
||||
for name in list_untouched:
|
||||
for name_to_remove in names_to_remove:
|
||||
if name_to_remove in name:
|
||||
name = name.replace(name_to_remove, "")
|
||||
filtered_module_names.append(name)
|
||||
|
||||
return filtered_module_names
|
||||
|
||||
|
||||
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
|
||||
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
|
||||
"""
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
# specific language governing permissions and limitations under the License.
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -23,7 +24,13 @@ from ..cache_utils import (
|
||||
StaticCache,
|
||||
)
|
||||
from ..generation.configuration_utils import GenerationConfig
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from ..masking_utils import (
|
||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
||||
_ignore_causal_mask_sdpa,
|
||||
_is_torch_greater_or_equal_than_2_5,
|
||||
prepare_padding_mask,
|
||||
)
|
||||
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ..pytorch_utils import (
|
||||
is_torch_greater_or_equal,
|
||||
is_torch_greater_or_equal_than_2_3,
|
||||
@ -222,6 +229,10 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
|
||||
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
|
||||
)
|
||||
self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
self.model.model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -757,6 +768,11 @@ def convert_and_export_with_cache(
|
||||
|
||||
import torch.export._trace
|
||||
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
with torch.no_grad():
|
||||
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
|
||||
example_input_ids = (
|
||||
@ -1020,6 +1036,11 @@ def export_with_dynamic_cache(
|
||||
if not is_torch_greater_or_equal_than_2_3:
|
||||
raise ImportError("torch >= 2.3 is required.")
|
||||
|
||||
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
|
||||
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
|
||||
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
|
||||
model.config._attn_implementation = "sdpa_without_vmap"
|
||||
|
||||
register_dynamic_cache_export_support()
|
||||
|
||||
with torch.no_grad():
|
||||
@ -1088,3 +1109,92 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
|
||||
value = value_list[idx] if idx < len(value_list) else None
|
||||
cache.update(key, value, idx)
|
||||
return cache
|
||||
|
||||
|
||||
def sdpa_mask_without_vmap(
|
||||
batch_size: int,
|
||||
cache_position: torch.Tensor,
|
||||
kv_length: int,
|
||||
kv_offset: int = 0,
|
||||
mask_function: Optional[Callable] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
local_size: Optional[int] = None,
|
||||
allow_is_causal_skip: bool = True,
|
||||
allow_torch_fix: bool = True,
|
||||
**kwargs,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
|
||||
the element should take part in the attention computation, and False that it should not.
|
||||
|
||||
This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
The batch size of the input sequence.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
kv_length (`int`):
|
||||
The size that the key and value states will have during the attention computation.
|
||||
kv_offset (`int`, optional):
|
||||
An optional offset to indicate at which first position the key and values states will refer to.
|
||||
mask_function (`Callable`):
|
||||
The mask factory function describing the mask pattern.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
||||
local_size (`int`, optional):
|
||||
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
|
||||
to try to skip mask creation if possible.
|
||||
allow_is_causal_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
|
||||
`torch.sdpa` instead. Default to `True`.
|
||||
allow_torch_fix (`bool`, optional):
|
||||
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
|
||||
versions. We need an arg to skip it when using eager. By default `True`.
|
||||
|
||||
"""
|
||||
|
||||
q_length = cache_position.shape[0]
|
||||
# Potentially pad the 2D mask, and slice it correctly
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
||||
|
||||
# Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
|
||||
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
|
||||
return None
|
||||
|
||||
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device)
|
||||
kv_arange += kv_offset
|
||||
reshaped_cache_position = cache_position.view(-1, 1)
|
||||
|
||||
# This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
|
||||
# the config through kwargs anyway, so it allows to rely on it
|
||||
# Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
|
||||
# but this is more efficient
|
||||
sliding_window = getattr(kwargs["config"], "sliding_window", None)
|
||||
chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
|
||||
|
||||
if sliding_window is not None and chunk_size is not None:
|
||||
raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
|
||||
|
||||
# Simplest and most efficient way to obtain a causal mask
|
||||
causal_mask = kv_arange <= reshaped_cache_position
|
||||
# If using sliding window, add the sliding mask
|
||||
if sliding_window is not None:
|
||||
sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
|
||||
causal_mask *= sliding_mask_overlay
|
||||
# If using chunk attention, add the chunked mask
|
||||
elif chunk_size is not None:
|
||||
chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
|
||||
causal_mask *= chunked_mask_overlay
|
||||
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
|
||||
if padding_mask is not None:
|
||||
causal_mask = causal_mask * padding_mask[:, None, None, :]
|
||||
|
||||
# Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
|
||||
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
|
||||
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
|
||||
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
|
||||
return causal_mask
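A sketch of the mask this produces for a 4-token prefill with no padding; `allow_is_causal_skip=False` forces the mask to actually be materialized, and a bare namespace stands in for the model config:

```py
import torch
from types import SimpleNamespace

from transformers.integrations.executorch import sdpa_mask_without_vmap

mask = sdpa_mask_without_vmap(
    batch_size=1,
    cache_position=torch.arange(4),
    kv_length=4,
    allow_is_causal_skip=False,
    config=SimpleNamespace(),   # no sliding window, no chunked attention
)
print(mask[0, 0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```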
|
||||
|
||||
@ -13,8 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from ..core_model_loading import ConversionOps
|
||||
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
|
||||
|
||||
|
||||
@ -30,6 +33,18 @@ if is_accelerate_available():
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
try:
|
||||
_FP8_DTYPE = torch.float8_e4m3fn
|
||||
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
|
||||
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
|
||||
_FP8_IS_INT = False
|
||||
except AttributeError:
|
||||
_FP8_DTYPE = torch.int8
|
||||
_FP8_MIN, _FP8_MAX = -127, 127
|
||||
_FP8_IS_INT = True
|
||||
logger.warning_once(
|
||||
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
|
||||
)
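For reference, the clamp range the fallback replaces: on builds where `float8_e4m3fn` exists its representable range is plus/minus 448, versus the int8 emulation's plus/minus 127 (a quick check, not part of the diff):

```py
import torch

info = torch.finfo(torch.float8_e4m3fn)
print(info.min, info.max)  # -448.0 448.0
```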
|
||||
|
||||
|
||||
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
|
||||
@ -332,6 +347,12 @@ class FP8Linear(nn.Linear):
|
||||
if self.weight.element_size() > 1:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
else:
|
||||
if isinstance(self.weight, torch.distributed.tensor.DTensor):
|
||||
weight = self.weight._local_tensor.contiguous()
|
||||
scale_inv = self.weight_scale_inv._local_tensor.contiguous()
|
||||
else:
|
||||
weight = self.weight.contiguous()
|
||||
scale_inv = self.weight_scale_inv.contiguous()
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
@ -339,9 +360,9 @@ class FP8Linear(nn.Linear):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
self.weight,
|
||||
weight,
|
||||
scale,
|
||||
self.weight_scale_inv,
|
||||
scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
@ -350,9 +371,124 @@ class FP8Linear(nn.Linear):
|
||||
torch_accelerator_module.synchronize()
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
output = torch.nan_to_num(output, nan=0.0)
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
def _ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
class FP8Expert(nn.Module):
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
def __init__(self, config, block_size, device):
|
||||
super().__init__()
|
||||
|
||||
from ..activations import ACT2FN
|
||||
|
||||
self.block_size = block_size
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.intermediate_dim = config.intermediate_size
|
||||
|
||||
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
|
||||
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
|
||||
|
||||
self.gate_up_proj = nn.Parameter(
|
||||
torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
|
||||
)
|
||||
self.down_proj = nn.Parameter(
|
||||
torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
|
||||
)
|
||||
|
||||
# Create inverse scale tiles only when using 1-byte types (fp8)
|
||||
if self.gate_up_proj.element_size() == 1:
|
||||
bo, bi = self.block_size
|
||||
|
||||
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
|
||||
gu_scale_o = _ceil_div(Wg_out, bo)
|
||||
gu_scale_i = _ceil_div(Wg_in, bi)
|
||||
self.gate_up_proj_scales_inv = nn.Parameter(
|
||||
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
|
||||
)
|
||||
|
||||
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
|
||||
dp_scale_o = _ceil_div(Wd_out, bo)
|
||||
dp_scale_i = _ceil_div(Wd_in, bi)
|
||||
self.down_proj_scales_inv = nn.Parameter(
|
||||
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
|
||||
)
|
||||
else:
|
||||
# Match FP8Linear behavior when not using 1-byte weights
|
||||
self.register_parameter("gate_up_proj_scale_inv", None)
|
||||
self.register_parameter("down_proj_scale_inv", None)
|
||||
|
||||
# (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
|
||||
self.register_parameter("gate_up_bias", None)
|
||||
self.register_parameter("down_bias", None)
|
||||
|
||||
# Activation used in the MLP (same as your config / ACT2FN)
|
||||
# Keep a handle here; actual usage happens in forward of your MoE block
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
top_k_index: torch.Tensor,
|
||||
top_k_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
num_experts = top_k_weights.shape[1]
|
||||
with torch.no_grad():
|
||||
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
|
||||
expert_mask = expert_mask.permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
for expert_idx in expert_hit:
|
||||
expert_idx = expert_idx[0]
|
||||
if expert_idx == num_experts:
|
||||
continue
|
||||
_, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current_state = hidden_states.index_select(0, token_idx)
|
||||
gate, up = self.linear(
|
||||
current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx]
|
||||
).chunk(2, dim=-1)
|
||||
current_hidden_states = self.act_fn(gate) * up
|
||||
current_hidden_states = self.linear(
|
||||
current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx]
|
||||
)
|
||||
|
||||
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
|
||||
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
|
||||
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
|
||||
|
||||
return final_hidden_states
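# Illustrative sketch (not part of the diff): how the one_hot routing mask built above selects tokens per expert.
# The shapes below are assumptions made only for this example.
import torch
top_k_index = torch.tensor([[0, 2], [1, 2]])  # 2 tokens, each routed to 2 of 3 experts
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=3 + 1).permute(2, 1, 0)
# expert_mask[e, k, t] == 1 when token t picked expert e in slot k; torch.where then yields the token ids
print(torch.where(expert_mask[2])[1])  # tensor([0, 1]) -> both tokens hit expert 2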
|
||||
|
||||
def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
|
||||
if weight.element_size() > 1:
|
||||
return F.linear(input, weight, None)
|
||||
else:
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
with torch_accelerator_module.device(input.device):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
weight,
|
||||
scale,
|
||||
weight_scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# preceding operations are ready before proceeding
|
||||
torch_accelerator_module.synchronize()
|
||||
return output.to(dtype=input.dtype)
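# Illustrative sketch (not part of the diff): what the per-tile `*_scales_inv` parameters above represent.
# Assumes a 2D fp8 weight and a (block_out, block_in) block size; the helper name is made up for this example.
import torch

def dequantize_block_fp8(weight_fp8, scale_inv, block_size=(128, 128)):
    bo, bi = block_size
    out_f, in_f = weight_fp8.shape
    # Broadcast each tile's inverse scale over the block it covers, then undo the quantization scaling
    expanded = scale_inv.repeat_interleave(bo, dim=0)[:out_f].repeat_interleave(bi, dim=1)[:, :in_f]
    return weight_fp8.to(torch.float32) * expanded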
|
||||
|
||||
|
||||
# TODO: we do need this, but it should not be recursive...
|
||||
def _replace_with_fp8_linear(
|
||||
model,
|
||||
tp_plan=None,
|
||||
@ -361,40 +497,48 @@ def _replace_with_fp8_linear(
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
"""Replace Linear layers with FP8Linear."""
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
iterator = list(model.named_parameters()).copy()
|
||||
for name, empty_tensor in iterator:
|
||||
current_key_name = name
|
||||
name = name.rsplit(".", 1)[0] if "." in name else name
|
||||
module = model.get_submodule(name)
|
||||
|
||||
for name, module in model.named_children():
|
||||
current_key_name.append(name)
|
||||
|
||||
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
model._modules[name] = FP8Linear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
current_key_name_str = re.sub(r"\d+", "*", current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
if (
|
||||
"gate_up_proj" in current_key_name
|
||||
or "down_proj" in current_key_name
|
||||
and "experts" in current_key_name
|
||||
): # Experts!
|
||||
in_features = empty_tensor.size(-2)
|
||||
out_features = empty_tensor.size(-1)
|
||||
model.set_submodule(
|
||||
name,
|
||||
FP8Expert(
|
||||
config=model.config,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
device=empty_tensor.device,
|
||||
),
|
||||
)
|
||||
has_been_replaced = True
|
||||
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_fp8_linear(
|
||||
module,
|
||||
tp_plan,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
|
||||
current_key_name.pop(-1)
|
||||
elif isinstance(module, nn.Linear):
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
model.set_submodule(
|
||||
name,
|
||||
FP8Linear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
),
|
||||
)
|
||||
has_been_replaced = True
|
||||
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||
|
||||
return model, has_been_replaced
|
||||
|
||||
@ -405,7 +549,7 @@ def replace_with_fp8_linear(
|
||||
quantization_config=None,
|
||||
):
|
||||
"""Helper function to replace model layers with FP8 versions."""
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
modules_to_not_convert += ["lm_head"]
|
||||
|
||||
if quantization_config.modules_to_not_convert is not None:
|
||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||
@ -424,3 +568,124 @@ def replace_with_fp8_linear(
|
||||
)
|
||||
|
||||
return model
|
||||
class Fp8Quantize(ConversionOps):
|
||||
"""
|
||||
A quantization operation that creates two tensors, a quantized weight and its scale, out of a single weight.
|
||||
"""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, hf_quantizer):
|
||||
self.hf_quantizer = hf_quantizer
|
||||
self.reverse_op = Fp8Dequantize
|
||||
|
||||
def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
|
||||
# Unpack single key/value (value may be wrapped in a list)
|
||||
target_keys, value = tuple(input_dict.items())[0]
|
||||
value = value[0] if isinstance(value, list) else value
|
||||
|
||||
# Resolve block size (support dict-like or attr-like quant_config)
|
||||
block_size = None
|
||||
if self.hf_quantizer.quantization_config is not None:
|
||||
if isinstance(self.hf_quantizer.quantization_config, dict):
|
||||
block_size = self.hf_quantizer.quantization_config.get("weight_block_size")
|
||||
else:
|
||||
block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None)
|
||||
if block_size is None:
|
||||
block_size = (value.shape[-2], value.shape[-1])
|
||||
|
||||
block_m, block_n = block_size
|
||||
rows, cols = value.shape[-2], value.shape[-1]
|
||||
|
||||
# Enforce exact tiling (matrix dims must be exact multiples of the block size)
|
||||
if rows % block_m != 0 or cols % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
|
||||
)
|
||||
|
||||
# Leading dims can be empty (2D) or include num_experts/... (3D+)
|
||||
leading_shape = value.shape[:-2]
|
||||
rows_tiles = rows // block_m
|
||||
cols_tiles = cols // block_n
|
||||
|
||||
original_shape = value.shape
|
||||
value_fp32 = value.to(torch.float32)
|
||||
|
||||
# Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
|
||||
reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
|
||||
|
||||
# Per-tile max-abs over the block dims
|
||||
# dims: block_m is at -3, block_n is at -1 after the reshape
|
||||
max_abs = reshaped.abs().amax(dim=(-3, -1))
|
||||
safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
|
||||
|
||||
# Tile scale (we store the inverse scale, like FP8Linear's weight_scale_inv)
|
||||
scales = _FP8_MAX / safe_max_abs
|
||||
scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
|
||||
|
||||
# Broadcast scales back over the block dims and quantize
|
||||
# max_abs/scales shape: (..., rows_tiles, cols_tiles)
|
||||
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
|
||||
scaled = reshaped * scales_broadcast
|
||||
|
||||
if _FP8_IS_INT:
|
||||
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
||||
else:
|
||||
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
quantized = quantized.reshape(original_shape)
|
||||
|
||||
inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
|
||||
if target_keys.endswith("weight"):
|
||||
scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
|
||||
else:
|
||||
scale_key = target_keys + "_scales_inv"
|
||||
|
||||
# Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
|
||||
return {
|
||||
target_keys: quantized,
|
||||
scale_key: inv_scales,
|
||||
}
|
||||
|
||||
class Fp8Dequantize(ConversionOps):
|
||||
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
|
||||
|
||||
def __init__(self, block_size: Optional[tuple[int, int]] = None):
|
||||
self.block_size = block_size
|
||||
self.reverse_op = Fp8Quantize
|
||||
|
||||
def convert(
|
||||
self,
|
||||
value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]],
|
||||
*,
|
||||
context: dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
if isinstance(value, dict):
|
||||
tensors = list(value.values())
|
||||
else:
|
||||
tensors = list(value) if isinstance(value, Sequence) else [value]
|
||||
if len(tensors) != 2:
|
||||
raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
|
||||
quantized, scales = tensors
|
||||
if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
|
||||
raise TypeError("Fp8Dequantize expects tensors as inputs.")
|
||||
|
||||
quantized_fp32 = quantized.to(torch.float32)
|
||||
rows, cols = quantized_fp32.shape[-2:]
|
||||
block_size = self.block_size
|
||||
if block_size is None:
|
||||
quant_config = context.get("quantization_config")
|
||||
block_size = getattr(quant_config, "weight_block_size", None)
|
||||
if block_size is None:
|
||||
block_size = (rows, cols)
|
||||
block_m, block_n = block_size
|
||||
if rows % block_m != 0 or cols % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
|
||||
)
|
||||
|
||||
reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
|
||||
expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
|
||||
expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
|
||||
dequantized = reshaped * expanded_scales
|
||||
return dequantized.reshape(quantized_fp32.shape)
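# Hedged usage sketch (not part of the diff): round-tripping a weight through the two ops above,
# assuming the module-level FP8 constants are in scope. The quantizer below is a hypothetical stand-in
# that only provides the attributes Fp8Quantize reads.
import torch

class _ToyQuantizer:
    quantization_config = {"weight_block_size": (128, 128)}

w = torch.randn(256, 512)
packed = Fp8Quantize(_ToyQuantizer()).convert({"model.layers.0.mlp.up_proj.weight": [w]})
w_fp8 = packed["model.layers.0.mlp.up_proj.weight"]
scale_inv = packed["model.layers.0.mlp.up_proj.weight_scale_inv"]
restored = Fp8Dequantize(block_size=(128, 128)).convert([w_fp8, scale_inv], context={})
print((restored - w).abs().max())  # small but nonzero: the fp8 round trip is lossy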
|
||||
|
||||
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
|
||||
from ..utils import is_accelerate_available, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -114,9 +114,6 @@ def convert_moe_packed_tensors(
|
||||
if not blocks.is_cuda and torch.cuda.is_available():
|
||||
blocks = blocks.cuda()
|
||||
scales = scales.cuda()
|
||||
elif (blocks.device.type != "xpu") and is_torch_xpu_available():
|
||||
blocks = blocks.to("xpu")
|
||||
scales = scales.to("xpu")
|
||||
|
||||
scales = scales.to(torch.int32) - 127 # TODO that's because 128=2**7
|
||||
|
||||
@ -354,8 +351,6 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
|
||||
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
|
||||
if target_device == "cpu" and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif target_device == "cpu" and is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
|
||||
delattr(module, blocks_attr)
|
||||
delattr(module, scales_attr)
|
||||
@ -400,7 +395,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito
|
||||
else:
|
||||
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
|
||||
if getattr(target_device, "type", target_device) == "cpu":
|
||||
target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
||||
target_device = "cuda"
|
||||
blocks = blocks.to(target_device).contiguous()
|
||||
scales = scales.to(target_device).contiguous()
|
||||
with on_device(target_device):
|
||||
|
||||
@ -236,7 +236,7 @@ class PeftAdapterMixin:
|
||||
**adapter_kwargs,
|
||||
)
|
||||
peft_config.inference_mode = not is_trainable
|
||||
|
||||
# TODO: WE NEED TO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
|
||||
# Create and add fresh new adapters into the model.
|
||||
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ import operator
|
||||
import os
|
||||
import re
|
||||
from functools import partial, reduce
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -306,7 +307,7 @@ def repack_weights(
|
||||
return final_ordered_tensor
|
||||
|
||||
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None):
|
||||
"""
|
||||
Generalized tensor sharding across a multi-dimensional device mesh.
|
||||
Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
|
||||
@ -358,32 +359,57 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
rank (int): Global rank of the current process/device.
|
||||
dim (int): Dimension along which to shard the tensor.
|
||||
"""
|
||||
param_dim = empty_param.dim()
|
||||
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
param_dim = empty_param.ndim
|
||||
# Flatten the mesh to get the total number of devices
|
||||
mesh_shape = device_mesh.shape
|
||||
world_size = reduce(operator.mul, mesh_shape)
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
|
||||
dim = 0
|
||||
elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
|
||||
dim = 0
|
||||
|
||||
shard_size = math.ceil(empty_param.size(dim) / world_size)
|
||||
start = rank * shard_size
|
||||
end = min(start + shard_size, empty_param.size(dim))
|
||||
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
if rank >= world_size:
|
||||
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
|
||||
|
||||
shard_size = math.ceil(empty_param.shape[dim] / world_size)
|
||||
start = rank * shard_size
|
||||
# we have the full tensor, not just one shard of it.
# in that case, we assume the weight was properly saved; because we TP, a colwise layer should not
# use this path: it should be packed_colwise to signal that it reads from a packed tensor (which also
# takes care of the ModuleList case).
# here we take care of potential chunking / layer split / layer chunking.
# The only "hard" case is when we collect q, k, v and merge them into qkv: even then we still shard
# along dim=0, so nothing changes. The remaining case is when the empty param has 3 dims and the
# shard dim is 0: we then send the whole tensor to a given device (using the input tensor_idx).
|
||||
dimensions = param.get_shape()
|
||||
|
||||
# Construct slicing index dynamically
|
||||
end = min(start + shard_size, empty_param.shape[dim])
|
||||
slice_indices = [slice(None)] * param_dim
|
||||
if start < empty_param.shape[dim]:
|
||||
if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
|
||||
# special case: we don't "shard", we just send this entire tensor to the correct rank.
|
||||
if start <= tensor_idx < end:
|
||||
# this tensor does need to be materialized on this device:
|
||||
return param[:]
|
||||
else:
|
||||
return torch.empty([], dtype=torch.int64, device=rank)
|
||||
|
||||
slice_indices = [slice(None)] * len(param.get_shape())
|
||||
|
||||
if start < param.get_shape()[dim]:
|
||||
slice_indices[dim] = slice(start, end)
|
||||
return param[tuple(slice_indices)]
|
||||
dimensions = list(param.shape)
|
||||
param = param[tuple(slice_indices)]
|
||||
if isinstance(param, list): # TODO handle the modulelist case!
|
||||
param = [p[:] for p in param]
|
||||
return param
|
||||
|
||||
dimensions[dim] = 0
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64)
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory....
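# Illustrative sketch (not part of the diff): the rank-to-slice arithmetic used by get_tensor_shard.
import math

def shard_bounds(dim_size, world_size, rank):
    # ceil-divide the dimension into equal shards, clamping the last one
    shard_size = math.ceil(dim_size / world_size)
    start = rank * shard_size
    return start, min(start + shard_size, dim_size)

# e.g. a dimension of 10 split over 4 ranks -> [(0, 3), (3, 6), (6, 9), (9, 10)]
print([shard_bounds(10, 4, r) for r in range(4)])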
|
||||
|
||||
|
||||
def distribute_module(
|
||||
@ -410,6 +436,19 @@ class TensorParallelLayer:
|
||||
"""
|
||||
|
||||
use_dtensor = True
|
||||
device_mesh = None
|
||||
rank = None
|
||||
|
||||
# Used to compare the shape of the original tensor
|
||||
empty_param = None
|
||||
|
||||
# Used to init the corresponding DTensor
|
||||
shard = None
|
||||
|
||||
def __init__(self, device_mesh=None, rank=None, empty_param=None):
|
||||
self.rank = rank
|
||||
self.device_mesh = device_mesh
|
||||
self.empty_param = empty_param
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
|
||||
@ -439,12 +478,12 @@ class GatherParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (input_layouts or Replicate(),)
|
||||
self.output_layouts = output_layouts
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -465,6 +504,21 @@ class GatherParallel(TensorParallelLayer):
|
||||
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
|
||||
return outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
shard = [Replicate()]
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
||||
distribute_module(
|
||||
module,
|
||||
@ -493,6 +547,23 @@ class IsolatedParallel(TensorParallelLayer):
|
||||
# TODO: figure out dynamo support for instance method and switch this to instance method
|
||||
return outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
mesh = device_mesh or self.device_mesh
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
if mesh is not None:
|
||||
parameter = parameter / mesh.size()
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
@ -515,8 +586,8 @@ class ReplicateParallel(TensorParallelLayer):
|
||||
This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
|
||||
"""
|
||||
|
||||
def __init__(self, *, use_dtensor=True, use_local_output=True):
|
||||
super().__init__()
|
||||
def __init__(self, use_dtensor=True, use_local_output=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -537,12 +608,33 @@ class ReplicateParallel(TensorParallelLayer):
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
shard = [Replicate()]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
param = param.contiguous()
|
||||
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
|
||||
return param
|
||||
parameter, shard = self.shard_tensor(
|
||||
param,
|
||||
param_type=param_type,
|
||||
param_casting_dtype=param_casting_dtype,
|
||||
to_contiguous=to_contiguous,
|
||||
rank=rank,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
if self.use_dtensor:
|
||||
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
||||
return parameter
|
||||
|
||||
|
||||
class ColwiseParallel(TensorParallelLayer):
|
||||
@ -552,13 +644,13 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
use_dtensor=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (input_layouts or Replicate(),)
|
||||
self.output_layouts = (output_layouts or Shard(-1),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -578,18 +670,34 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
|
||||
return input_tensor
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = self.rank
|
||||
if param_type == "bias":
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
|
||||
shard = [Shard(-1)]
|
||||
else:
|
||||
shard = [Shard(-2)]
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
# weight would become Shard(1)
|
||||
if param_type == "bias":
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
|
||||
shard = [Shard(-1)]
|
||||
else:
|
||||
shard = [Shard(-2)]
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
|
||||
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh)
|
||||
if to_contiguous:
|
||||
parameter = parameter.contiguous()
|
||||
if self.use_dtensor:
|
||||
@ -608,6 +716,26 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
|
||||
class PackedColwiseParallel(ColwiseParallel):
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
return get_packed_weights(param, empty_param, device_mesh, rank, -2).to(param_casting_dtype), [Shard(-2)]
|
||||
|
||||
def create_nn_parameter(
|
||||
self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh
|
||||
):
|
||||
return nn.Parameter(param, requires_grad=param.is_floating_point())
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
@ -642,18 +770,41 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
use_dtensor=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (input_layouts or Shard(-1),)
|
||||
self.output_layouts = (output_layouts or Replicate(),)
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
if param_type == "bias":
|
||||
shard = [Replicate()]
|
||||
parameter = param[...]
|
||||
else:
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
|
||||
shard = [Shard(-1)]
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
|
||||
# means Rowwise as nn.Linear is input * weight^T + bias, where
|
||||
@ -725,6 +876,21 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
|
||||
|
||||
class PackedRowwiseParallel(RowwiseParallel):
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
@ -783,8 +949,8 @@ class SequenceParallel(TensorParallelLayer):
|
||||
to ensure that they are replicated.
|
||||
"""
|
||||
|
||||
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
|
||||
super().__init__()
|
||||
def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Shard(1),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
@ -793,6 +959,21 @@ class SequenceParallel(TensorParallelLayer):
|
||||
self.sequence_sharding = (Shard(sequence_dim),)
|
||||
self.use_local_output = use_local_output
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
shard = [Replicate()]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
input_tensor = inputs[0]
|
||||
@ -827,10 +1008,34 @@ class GroupedGemmParallel(TensorParallelLayer):
|
||||
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.use_dtensor = False
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
empty_param = self.empty_param
|
||||
ep_rank = self.rank
|
||||
device_mesh = self.device_mesh
|
||||
|
||||
global_num_experts = empty_param.shape[0]
|
||||
if global_num_experts % device_mesh.size() != 0:
|
||||
raise ValueError(
|
||||
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
|
||||
)
|
||||
local_num_experts = global_num_experts // device_mesh.size()
|
||||
parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
|
||||
self.shard = None
|
||||
return parameter, None
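# Illustrative sketch (not part of the diff): which expert rows an EP rank keeps in shard_tensor above.
global_num_experts, mesh_size = 8, 4
local_num_experts = global_num_experts // mesh_size
for ep_rank in range(mesh_size):
    # rank 0 -> experts [0, 2), rank 1 -> [2, 4), rank 2 -> [4, 6), rank 3 -> [6, 8)
    print(ep_rank, ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts)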
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
ep_rank = rank
|
||||
global_num_experts = empty_param.shape[0]
|
||||
@ -851,8 +1056,8 @@ class RouterParallel(TensorParallelLayer):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.use_dtensor = False
|
||||
|
||||
@staticmethod
|
||||
@ -917,6 +1122,20 @@ class RouterParallel(TensorParallelLayer):
|
||||
) # masking class for one hot
|
||||
return router_scores, router_indices
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# TODO: i'd like for this to be the default
|
||||
param = param[...].to(param_casting_dtype)
|
||||
@ -1059,6 +1278,9 @@ def shard_and_distribute_module(
|
||||
if current_shard_plan is not None:
|
||||
try:
|
||||
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
|
||||
tp_layer.empty_param = empty_param
|
||||
tp_layer.device_mesh = device_mesh
|
||||
tp_layer.rank = rank
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
|
||||
127
src/transformers/integrations/torchao.py
Normal file
@ -0,0 +1,127 @@
|
||||
import importlib.metadata
|
||||
import re
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from transformers.utils.import_utils import is_torchao_available
|
||||
from transformers.utils import logging
|
||||
|
||||
from ..core_model_loading import ConversionOps
|
||||
from ..quantizers.quantizers_utils import get_module_from_name
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
|
||||
from torchao.prototype.safetensors.safetensors_support import (
|
||||
unflatten_tensor_state_dict,
|
||||
)
|
||||
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class TorchAoQuantize(ConversionOps):
|
||||
def __init__(self, hf_quantizer):
|
||||
self.hf_quantizer = hf_quantizer
|
||||
|
||||
def convert(
|
||||
self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
|
||||
) -> dict[str, torch.Tensor]:
|
||||
target_key, value = tuple(input_dict.items())[0]
|
||||
value = value[0] if isinstance(value, list) else value
|
||||
|
||||
full_name = target_key
|
||||
# update param name to get the weights instead of the quantized stats
|
||||
target_key = self.hf_quantizer.get_param_name(target_key)
|
||||
module, _ = get_module_from_name(model, target_key)
|
||||
|
||||
from torchao.quantization import quantize_
|
||||
|
||||
# Those are the pre quantized weights
|
||||
if ":" in target_key:
|
||||
target_key = target_key.rsplit(":", 1)[0]
|
||||
module, tensor_name = get_module_from_name(model, target_key)
|
||||
|
||||
if self.hf_quantizer.pre_quantized:
|
||||
# If it's a bias, no need to do anything special (except removing the ":_data" part of the key, which was
# already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
|
||||
is_unsafe_serialization = ":" not in full_name
|
||||
if tensor_name == "bias" or is_unsafe_serialization:
|
||||
return {target_key: value}
|
||||
# Sanity check for the new serialization format
|
||||
elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
|
||||
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
|
||||
|
||||
# Save the states for later quantization when they are all gathered
|
||||
if not hasattr(self.hf_quantizer, "ao_params"):
|
||||
self.hf_quantizer.ao_params = defaultdict(dict)
|
||||
self.hf_quantizer.ao_params[target_key].update({full_name: value})
|
||||
|
||||
# We are ready for quantization in this case (we retrieved all the needed keys)
|
||||
if len(self.hf_quantizer.ao_params[target_key]) == len(self.hf_quantizer.weight_ao_keys):
|
||||
new_param = unflatten_tensor_state_dict(
|
||||
self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata
|
||||
)[target_key]
|
||||
del self.hf_quantizer.ao_params[target_key]
|
||||
return {target_key: new_param}
|
||||
|
||||
# Add repr to the module
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
module.extra_repr = types.MethodType(self.hf_quantizer._linear_extra_repr, module)
|
||||
return {}
|
||||
else:
|
||||
module._parameters[tensor_name] = torch.nn.Parameter(value, requires_grad=value.requires_grad).to(
|
||||
value.device
|
||||
)
|
||||
# if we are quantizing tied parameters, to avoid tying the quantized weights
|
||||
# the correct order to do it is
|
||||
# 1. load the weight to model
|
||||
# 2. run tie_weights to populate the weights
|
||||
# 3. quantize
|
||||
mm: Any = model
|
||||
input_embed = mm.get_input_embeddings() if hasattr(mm, "get_input_embeddings") else None
|
||||
if self.hf_quantizer.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
|
||||
if hasattr(mm, "tie_weights"):
|
||||
mm.tie_weights()
|
||||
if hasattr(mm, "config") and hasattr(mm.config, "get_text_config"):
|
||||
setattr(mm.config.get_text_config(decoder=True), "tie_word_embeddings", False)
|
||||
|
||||
# handle ModuleFqnToConfig, introduced in torchao 0.12.0+
|
||||
if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
|
||||
from torchao.quantization import ModuleFqnToConfig
|
||||
|
||||
config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
|
||||
if isinstance(config, ModuleFqnToConfig):
|
||||
module_fqn, _ = target_key.rsplit(".", 1)
|
||||
c = None
|
||||
if module_fqn in config.module_fqn_to_config:
|
||||
assert not module_fqn.startswith("re:"), (
|
||||
"module fqn should not start with`re:`, which is used for specifying regex"
|
||||
)
|
||||
c = config.module_fqn_to_config[module_fqn]
|
||||
else:
|
||||
for maybe_module_fqn_pattern in config.module_fqn_to_config:
|
||||
if not maybe_module_fqn_pattern.startswith("re:"):
|
||||
continue
|
||||
elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
|
||||
# we'll apply the config for first fully matched pattern
|
||||
c = config.module_fqn_to_config[maybe_module_fqn_pattern]
|
||||
break
|
||||
else:
|
||||
c = config.module_fqn_to_config.get("_default", None)
|
||||
|
||||
if c is not None:
|
||||
# filter_fn: not filtering out any modules
|
||||
quantize_(module, c, filter_fn=lambda x, fqn: True)
|
||||
module._is_hf_initialized = True
|
||||
missing_keys.discard(target_key)
|
||||
return {}
|
||||
quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
|
||||
module._is_hf_initialized = True
|
||||
missing_keys.discard(target_key)
|
||||
return {}
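# Hedged sketch (not part of the new file): how the ModuleFqnToConfig lookup above resolves a config,
# using a plain dict in place of the real torchao object.
import re

module_fqn_to_config = {
    "model.embed_tokens": "config-a",
    "re:model\\.layers\\.\\d+\\.self_attn\\..*": "config-b",
    "_default": "config-default",
}

def resolve(module_fqn):
    if module_fqn in module_fqn_to_config:
        return module_fqn_to_config[module_fqn]
    for pattern in module_fqn_to_config:
        if pattern.startswith("re:") and re.fullmatch(pattern[3:], module_fqn):
            return module_fqn_to_config[pattern]  # first fully matched regex pattern wins
    return module_fqn_to_config.get("_default", None)

print(resolve("model.layers.3.self_attn.q_proj"))  # config-b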
|
||||
@ -82,10 +82,8 @@ def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int)
|
||||
def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
"""
|
||||
This creates a full bidirectional mask.
|
||||
|
||||
NOTE: It is important to keep an index-based version for non-vmap expansion.
|
||||
"""
|
||||
return q_idx >= 0
|
||||
return q_idx.new_ones((), dtype=torch.bool)
|
||||
|
||||
|
||||
def sliding_window_overlay(sliding_window: int) -> Callable:
|
||||
@ -112,6 +110,18 @@ def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
|
||||
return inner_mask
|
||||
|
||||
|
||||
def _legacy_chunked_overlay(chunk_size: int) -> Callable:
|
||||
"""
|
||||
Same as the above function, but does not correctly account for left padding tokens.
|
||||
Only kept for compatibility with older torch versions (< 2.6).
|
||||
"""
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
return kv_idx // chunk_size == q_idx // chunk_size
|
||||
|
||||
return inner_mask
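# Illustrative sketch (not part of the diff): the block-diagonal pattern produced by the chunk overlay above.
import torch
chunk_size = 2
q_idx = torch.arange(4)[:, None]
kv_idx = torch.arange(4)[None, :]
# Tokens may only attend within their own chunk; AND-ing with the causal mask then restricts it further.
print((kv_idx // chunk_size) == (q_idx // chunk_size))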
|
||||
|
||||
|
||||
def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
|
||||
"""
|
||||
This returns the mask_function used to create a sliding window mask.
|
||||
@ -123,6 +133,8 @@ def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) ->
|
||||
"""
|
||||
This returns the mask_function used to create a chunked attention mask.
|
||||
"""
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
return and_masks(_legacy_chunked_overlay(chunk_size), causal_mask_function)
|
||||
return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function)
|
||||
|
||||
|
||||
@ -163,17 +175,52 @@ def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offs
|
||||
return inner_mask
|
||||
|
||||
|
||||
def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
|
||||
"""
|
||||
Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
|
||||
the batch and head indices as well if `bh_indices=True`.
|
||||
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
|
||||
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
|
||||
|
||||
Args:
|
||||
mask_function (`Callable`):
|
||||
The mask_function to vmap.
|
||||
bh_indices (`bool`, optional):
|
||||
Whether to vmap over the batch and head indices as well, or only q and kv indices.
|
||||
|
||||
Returns:
|
||||
Callable: The vmapped function.
|
||||
"""
|
||||
# We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions
|
||||
dimensions = [(None, None, None, 0), (None, None, 0, None)]
|
||||
if bh_indices:
|
||||
# We extend broadcasting over the [batch_idx, head_idx] dimensions
|
||||
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
|
||||
|
||||
for dims in dimensions:
|
||||
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
|
||||
return mask_function
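# Hedged usage sketch (not part of the diff): expanding an index-based causal mask with nested torch.vmap,
# mirroring the two (q_idx, kv_idx) vmap levels above; the mask utilities may additionally wrap the call
# in TransformGetItemToIndex when the mask function slices tensors.
import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

fn = causal
for dims in [(None, None, None, 0), (None, None, 0, None)]:
    fn = torch.vmap(fn, in_dims=dims, out_dims=0)

print(fn(None, None, torch.arange(4), torch.arange(6)).shape)  # torch.Size([4, 6])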
|
||||
|
||||
|
||||
def prepare_padding_mask(
|
||||
attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int
|
||||
attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.
|
||||
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it, and slicing
|
||||
according to the `kv_offset` if `_slice` is `True`.
|
||||
"""
|
||||
local_padding_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
# Pad it if necessary
|
||||
if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
|
||||
local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
|
||||
# For flex, we should not slice them, only use an offset
|
||||
if _slice:
|
||||
# Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`,
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
mask_indices = torch.arange(kv_length, device=local_padding_mask.device)
|
||||
mask_indices += kv_offset
|
||||
local_padding_mask = local_padding_mask[:, mask_indices]
|
||||
return local_padding_mask
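# Illustrative sketch (not part of the diff): the offset-index trick above, versus data-dependent slicing.
import torch
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
kv_length, kv_offset = 3, 2
mask_indices = torch.arange(kv_length) + kv_offset
# Same result as attention_mask[:, kv_offset : kv_offset + kv_length], but torch.compile friendly
print(attention_mask[:, mask_indices])  # tensor([[1, 1, 0]])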
|
||||
|
||||
|
||||
@ -235,39 +282,7 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> boo
|
||||
return False
|
||||
|
||||
|
||||
def _vmap_expansion_sdpa(mask_function: Callable) -> Callable:
|
||||
"""
|
||||
Used to vmap our mask_functions over all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
|
||||
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
|
||||
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
|
||||
"""
|
||||
# We vmap the function over all 4 dimensions, broadcasting [b_idx, h_idx, q_idx, kv_idx]
|
||||
dimensions = [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]
|
||||
for dims in dimensions:
|
||||
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
|
||||
return mask_function
|
||||
|
||||
|
||||
def _non_vmap_expansion_sdpa(
|
||||
batch_indices: torch.Tensor, head_indices: torch.Tensor, q_indices: torch.Tensor, kv_indices: torch.Tensor
|
||||
):
|
||||
"""
|
||||
Used to broadcast our mask_functions over all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
|
||||
Allows the usage of any index-based mask function without relying on vmap.
|
||||
|
||||
NOTE: This is limited to index based functions only and is not guaranteed to work otherwise.
|
||||
|
||||
Reference:
|
||||
- https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
|
||||
"""
|
||||
batch_indices = batch_indices[:, None, None, None]
|
||||
head_indices = head_indices[None, :, None, None]
|
||||
q_indices = q_indices[None, None, :, None]
|
||||
kv_indices = kv_indices[None, None, None, :]
|
||||
return batch_indices, head_indices, q_indices, kv_indices
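# Illustrative sketch (not part of the diff): building the same 4D mask with plain broadcasting instead of vmap.
import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx  # purely index-based, so broadcasting is enough

b = torch.arange(2)[:, None, None, None]
h = torch.arange(1)[None, :, None, None]
q = torch.arange(4)[None, None, :, None]
kv = torch.arange(6)[None, None, None, :]
mask = causal(b, h, q, kv)             # broadcasts to (1, 1, 4, 6); batch/head are unused by this pattern
print(mask.expand(2, -1, 4, 6).shape)  # expanded afterwards, as the sdpa mask builder does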
|
||||
|
||||
|
||||
def sdpa_mask(
|
||||
def sdpa_mask_recent_torch(
|
||||
batch_size: int,
|
||||
cache_position: torch.Tensor,
|
||||
kv_length: int,
|
||||
@ -277,8 +292,6 @@ def sdpa_mask(
|
||||
local_size: Optional[int] = None,
|
||||
allow_is_causal_skip: bool = True,
|
||||
allow_is_bidirectional_skip: bool = False,
|
||||
allow_torch_fix: bool = True,
|
||||
use_vmap: bool = False,
|
||||
**kwargs,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
@ -311,12 +324,6 @@ def sdpa_mask(
|
||||
allow_is_bidirectional_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
|
||||
i.e. full attention without any padding. Default to `False`.
|
||||
allow_torch_fix (`bool`, optional):
|
||||
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
|
||||
versions. We need an arg to skip it when using eager. By default `True`.
|
||||
use_vmap (`bool`, optional):
|
||||
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
|
||||
index-based (for the cost of speed performance). By default `False`.
|
||||
|
||||
|
||||
## Creating a simple causal mask:
|
||||
@ -384,8 +391,97 @@ def sdpa_mask(
|
||||
|
||||
"""
|
||||
q_length = cache_position.shape[0]
|
||||
# Potentially pad the 2D mask, and slice it correctly
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
|
||||
|
||||
# Potentially pad the 2D mask
|
||||
# Under specific conditions, we can avoid materializing the mask
|
||||
# 1. Causal masks can rely on the `is_causal` argument
|
||||
# 2. Bidirectional do not need any further processing (no bias)
|
||||
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
|
||||
return None
|
||||
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
|
||||
return None
|
||||
|
||||
# vmap can incur performance issues (as reported in #41566) for the bidirectional mask, since we only need to expand the
|
||||
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
|
||||
if mask_function is bidirectional_mask_function:
|
||||
if padding_mask is not None:
|
||||
# used for slicing without data-dependent slicing
|
||||
mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset
|
||||
return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
|
||||
else:
|
||||
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)
|
||||
|
||||
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device)
|
||||
kv_arange += kv_offset
|
||||
|
||||
# Potentially add the padding 2D mask
|
||||
if padding_mask is not None:
|
||||
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
|
||||
|
||||
batch_arange = torch.arange(batch_size, device=cache_position.device)
|
||||
head_arange = torch.arange(1, device=cache_position.device)
|
||||
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
|
||||
# scalar tensor (it internally calls `.item()`, which vmap does not allow, but this context manager works around it).
|
||||
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
|
||||
with TransformGetItemToIndex():
|
||||
causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
def sdpa_mask_older_torch(
|
||||
batch_size: int,
|
||||
cache_position: torch.Tensor,
|
||||
kv_length: int,
|
||||
kv_offset: int = 0,
|
||||
mask_function: Callable = causal_mask_function,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
local_size: Optional[int] = None,
|
||||
allow_is_causal_skip: bool = True,
|
||||
allow_torch_fix: bool = True,
|
||||
allow_is_bidirectional_skip: bool = False,
|
||||
**kwargs,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
NOTE: This function is only used when torch version is torch<2.6 - see `sdpa_mask_recent_torch` otherwise.
|
||||
|
||||
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
|
||||
the element should take part in the attention computation, and False that it should not.
|
||||
If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend
|
||||
to any other tokens (due to padding) will be fully attended to instead, in order to avoid `nan` propagation (this does
|
||||
not change the final result).
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
The batch size of the input sequence.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
kv_length (`int`):
|
||||
The size that the key and value states will have during the attention computation.
|
||||
kv_offset (`int`, optional):
|
||||
An optional offset to indicate at which first position the key and values states will refer to.
|
||||
mask_function (`Callable`):
|
||||
The mask factory function describing the mask pattern.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
||||
local_size (`int`, optional):
|
||||
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
|
||||
to try to skip mask creation if possible.
|
||||
allow_is_causal_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
|
||||
`torch.sdpa` instead. Default to `True`.
|
||||
allow_torch_fix (`bool`, optional):
|
||||
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
|
||||
versions. We need an arg to skip it when using eager. By default `True`.
|
||||
allow_is_bidirectional_skip (`bool`, optional):
|
||||
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
|
||||
i.e. full attention without any padding. Default to `False`.
|
||||
"""
|
||||
q_length = cache_position.shape[0]
|
||||
# Potentially pad the 2D mask, and slice it correctly
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
||||
|
||||
# Under specific conditions, we can avoid materializing the mask
|
||||
@ -396,45 +492,38 @@ def sdpa_mask(
|
||||
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
|
||||
return None
|
||||
|
||||
# Potentially add the padding 2D mask
|
||||
if padding_mask is not None:
|
||||
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
|
||||
# vmap can incur performance issues (as reported in #41566) for the bidirectional mask, since we only need to expand the
|
||||
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
|
||||
if mask_function is bidirectional_mask_function:
|
||||
if padding_mask is not None:
|
||||
return padding_mask[:, None, None, :].expand(-1, -1, q_length, -1)
|
||||
else:
|
||||
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)
|
||||
|
||||
batch_arange = torch.arange(batch_size, device=cache_position.device)
|
||||
head_arange = torch.arange(1, device=cache_position.device)
|
||||
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
|
||||
# but without data-dependent slicing (i.e. torch.compile friendly)
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
|
||||
kv_arange = torch.arange(kv_length, device=cache_position.device)
|
||||
kv_arange += kv_offset
|
||||
|
||||
# Actual mask creation
|
||||
# Option 1: Fast non-vmap mask creation (default)
|
||||
if not use_vmap:
|
||||
# Apply mask function element-wise through broadcasting
|
||||
attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange))
|
||||
# Expand the mask to match batch size and query length if they weren't used in the mask function
|
||||
attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
|
||||
|
||||
# Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
|
||||
elif _is_torch_greater_or_equal_than_2_6:
|
||||
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
|
||||
# scalar tensor (it internally calls `.item()`, which vmap does not allow, but this context manager works around it).
|
||||
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
|
||||
with TransformGetItemToIndex():
|
||||
attention_mask = _vmap_expansion_sdpa(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
|
||||
|
||||
# Option 3: Error out since it indicates that the user did something custom, which they shouldn't have (torch<2.6)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The vmap functionality for mask creation is only supported from torch>=2.6. "
|
||||
"Please update your torch version or use `use_vmap=False` with index-based masks."
|
||||
)
|
||||
# This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well,
|
||||
# as vmap cannot handle slicing a tensor from a scalar tensor (it internally calls `.item()`, which vmap does not allow).
# However, in more recent versions of PyTorch, a trick was introduced to handle it - which is the reason we have
|
||||
# `sdpa_mask_recent_torch`, as it allows more general `mask_function`
|
||||
causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
|
||||
if padding_mask is not None:
|
||||
causal_mask = causal_mask * padding_mask[:, None, None, :]
|
||||
|
||||
# Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
|
||||
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
|
||||
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
|
||||
attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
|
||||
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
|
||||
return causal_mask
|
||||
|
||||
return attention_mask
|
||||
|
||||
# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
|
||||
# (especially mask_function indexing a tensor, such as the padding mask function)
|
||||
sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch
|
||||
|
||||
|
||||
def eager_mask(
|
||||
@ -445,7 +534,6 @@ def eager_mask(
|
||||
mask_function: Callable = causal_mask_function,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
use_vmap: bool = False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -468,14 +556,10 @@ def eager_mask(
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
||||
dtype (`torch.dtype`, optional):
|
||||
The dtype to use for the mask. By default, `torch.float32`.
|
||||
use_vmap (`bool`, optional):
|
||||
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
|
||||
index-based (for the cost of speed performance). By default `False`.
|
||||
"""
|
||||
# The masks for eager attention are simply the boolean masks from sdpa, cast to 0 and -inf
|
||||
_ = kwargs.pop("allow_is_causal_skip", None)
|
||||
_ = kwargs.pop("allow_is_bidirectional_skip", None)
|
||||
_ = kwargs.pop("allow_torch_fix", None)
|
||||
mask = sdpa_mask(
|
||||
batch_size=batch_size,
|
||||
cache_position=cache_position,
|
||||
@ -486,7 +570,6 @@ def eager_mask(
|
||||
allow_is_causal_skip=False,
|
||||
allow_is_bidirectional_skip=False,
|
||||
allow_torch_fix=False,
|
||||
use_vmap=use_vmap,
|
||||
**kwargs,
|
||||
)
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
@ -572,7 +655,7 @@ def flex_attention_mask(
|
||||
if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
|
||||
attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
|
||||
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
||||
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
|
||||
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
|
||||
|
||||
# Add the offsets on top (because flex interface only allows length, not start and end indices)
|
||||
@ -768,11 +851,6 @@ def create_causal_mask(
|
||||
mask_factory_function = causal_mask_function
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Defaulting to using non-vmap based mask creations except when detecting
|
||||
# users passing custom mask functions (as we cannot guarantee that they
|
||||
# are properly index-based as required by our implementation).
|
||||
use_vmap = False
|
||||
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
||||
if _is_torch_xpu_available:
|
||||
@ -789,16 +867,14 @@ def create_causal_mask(
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
|
||||
# If we detected packing format
|
||||
if packed_sequence_mask is not None:
|
||||
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
|
||||
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
||||
allow_is_causal_skip = False
|
||||
|
||||
@ -813,7 +889,6 @@ def create_causal_mask(
|
||||
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
@ -867,10 +942,6 @@ def create_bidirectional_mask(
|
||||
|
||||
# Allow skipping the mask creation except we have additional masking operators (and/or masks)
|
||||
allow_is_bidirectional_skip = True
|
||||
# Defaulting to using non-vmap based mask creations except when detecting
|
||||
# users passing custom mask functions (as we cannot guarantee that they
|
||||
# are properly index-based as required by our implementation).
|
||||
use_vmap = False
|
||||
|
||||
# Allow slight deviations from the base mask
|
||||
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
|
||||
@ -880,13 +951,11 @@ def create_bidirectional_mask(
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_bidirectional_skip = False
|
||||
use_vmap = True
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_bidirectional_skip = False
|
||||
use_vmap = True
|
||||
|
||||
# We now create the mask
|
||||
attention_mask = mask_interface(
|
||||
@ -901,7 +970,6 @@ def create_bidirectional_mask(
|
||||
allow_is_bidirectional_skip=allow_is_bidirectional_skip,
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
||||
)
|
||||
return attention_mask
|
||||
|
||||
@ -964,10 +1032,6 @@ def create_sliding_window_causal_mask(
|
||||
mask_factory_function = sliding_window_causal_mask_function(sliding_window)
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Defaulting to using non-vmap based mask creations except when detecting
|
||||
# users passing custom mask functions (as we cannot guarantee that they
|
||||
# are properly index-based as required by our implementation).
|
||||
use_vmap = False
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
||||
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
|
||||
@ -980,16 +1044,14 @@ def create_sliding_window_causal_mask(
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
|
||||
# If we detected packing format
|
||||
if packed_sequence_mask is not None:
|
||||
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
|
||||
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
||||
allow_is_causal_skip = False
|
||||
|
||||
@ -1005,7 +1067,6 @@ def create_sliding_window_causal_mask(
|
||||
local_size=sliding_window, # Additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
@ -1079,13 +1140,20 @@ def create_chunked_causal_mask(
|
||||
left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
|
||||
else:
|
||||
left_padding_tokens = torch.zeros(batch_size, device=cache_position.device, dtype=int)
|
||||
# Raise a warning for older versions if the problematic left-padding situation arises
|
||||
if (
|
||||
not _is_torch_greater_or_equal_than_2_6
|
||||
and kv_length + kv_offset > chunk_size
|
||||
and (left_padding_tokens > 0).any()
|
||||
):
|
||||
logger.warning_once(
|
||||
"Due to limitations of your current torch version, we cannot correctly account for the left-padding "
|
||||
"when computing the chunked attention pattern. This will lead to a wrong attention mask for the padded "
|
||||
"sequences. Behavior will be undefined. Please upgrade to `torch>=2.6` to solve this issue."
|
||||
)
|
||||
mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Defaulting to using non-vmap based mask creations except when detecting
|
||||
# users passing custom mask functions (as we cannot guarantee that they
|
||||
# are properly index-based as required by our implementation).
|
||||
use_vmap = False
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
||||
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
|
||||
@ -1098,16 +1166,14 @@ def create_chunked_causal_mask(
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
use_vmap = True
|
||||
|
||||
# If we detected packing format
|
||||
if packed_sequence_mask is not None:
|
||||
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
|
||||
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
||||
allow_is_causal_skip = False
|
||||
|
||||
@ -1123,7 +1189,6 @@ def create_chunked_causal_mask(
|
||||
local_size=chunk_size, # Additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
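The hunks above keep two expansion paths for index-based mask functions: a vmap-based one for torch>=2.6 and a broadcasting fallback for older versions. As a minimal sketch of the vmap idea (not the repository's `_vmap_for_bhqkv`/`_vmap_expansion_sdpa` helpers; the two-argument signature and variable names below are simplifications made up for illustration, only `torch.vmap` itself is a real API):

import torch

def causal_mask_function(q_idx, kv_idx):
    # Purely index-based rule: a query may only attend to keys at or before its own position.
    return kv_idx <= q_idx

def expand_mask_with_vmap(mask_function, q_positions, kv_positions):
    # vmap over the key dimension first, then over the query dimension,
    # producing a (q_len, kv_len) boolean mask without materializing index grids by hand.
    fn = torch.vmap(torch.vmap(mask_function, in_dims=(None, 0)), in_dims=(0, None))
    return fn(q_positions, kv_positions)

q_positions = torch.arange(4)   # e.g. cache_position
kv_positions = torch.arange(6)  # e.g. kv_arange
mask = expand_mask_with_vmap(causal_mask_function, q_positions, kv_positions)
print(mask.shape)  # torch.Size([4, 6])
print(mask[0])     # tensor([ True, False, False, False, False, False])

The older-torch path in the diff achieves the same result by broadcasting the index tensors directly, which is faster but only works when the mask function stays strictly index-based.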
@@ -406,13 +406,14 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_flex_attn = True

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.data.fill_(math.log(1 / 0.07))
module.logit_scale.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
module.cls_token.normal_(mean=0.0, std=self.config.initializer_range)


@auto_docstring(

@@ -449,13 +449,14 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_flex_attn = True

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.data.fill_(math.log(1 / 0.07))
module.logit_scale.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
module.cls_token.normal_(mean=0.0, std=self.config.initializer_range)


@auto_docstring(
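The Aimv2 hunks show the pattern that repeats across the model files below: `_init_weights` is decorated with `@torch.no_grad()` and the `.data` indirection is dropped in favour of calling the in-place initializers directly on the parameters. A standalone sketch of that idiom, with a made-up module and config field purely for illustration:

import math
import torch
from torch import nn

class TinyPreTrainedModel(nn.Module):
    # Hypothetical stand-in for a PreTrainedModel subclass; only the init pattern matters here.
    def __init__(self, initializer_range: float = 0.02):
        super().__init__()
        self.initializer_range = initializer_range
        self.proj = nn.Linear(16, 16)
        self.logit_scale = nn.Parameter(torch.zeros(()))

    @torch.no_grad()  # no autograd tracking during init, so the `.data` indirection can be dropped
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.normal_(mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                module.bias.zero_()
        if hasattr(module, "logit_scale") and isinstance(module.logit_scale, nn.Parameter):
            module.logit_scale.fill_(math.log(1 / 0.07))  # CLIP-style temperature init, as in the hunk above

model = TinyPreTrainedModel()
model.apply(model._init_weights)  # nn.Module.apply visits every submodule (and the model itself)
print(float(model.logit_scale))   # ~2.659

Under `torch.no_grad()`, in-place writes to leaf parameters are allowed, which is exactly what makes the `.data` workaround unnecessary.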
@@ -302,21 +302,22 @@ class AlbertPreTrainedModel(PreTrainedModel):
"attentions": AlbertAttention,
}

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, AlbertMLMHead):
module.bias.data.zero_()
module.bias.zero_()


@dataclass
@@ -425,7 +426,10 @@ class AlbertModel(AlbertPreTrainedModel):
"""
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_tied_weights_keys = {
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
"predictions.decoder.bias": "predictions.bias",
}

def __init__(self, config: AlbertConfig):
super().__init__(config)
@@ -525,7 +529,6 @@ class AlbertMLMHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
self.activation = ACT2FN[config.hidden_act]
self.decoder.bias = self.bias

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
@@ -537,14 +540,6 @@ class AlbertMLMHead(nn.Module):

return prediction_scores

def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias


class AlbertSOPHead(nn.Module):
def __init__(self, config: AlbertConfig):
@@ -561,7 +556,10 @@ class AlbertSOPHead(nn.Module):

@auto_docstring
class AlbertForMaskedLM(AlbertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_tied_weights_keys = {
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
"predictions.decoder.bias": "predictions.bias",
}

def __init__(self, config):
super().__init__(config)

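The other recurring change in these hunks is `_tied_weights_keys` moving from a flat list of parameter names to a dict that maps each tied parameter to the parameter whose storage it should share (e.g. `predictions.decoder.weight` -> `albert.embeddings.word_embeddings.weight`), which lets the per-model `_tie_weights` overrides above be deleted. A rough, generic sketch of how such a mapping can be resolved; the helper below is illustrative, not the actual Transformers implementation:

import torch
from torch import nn

def tie_weights_from_mapping(model: nn.Module, tied_weights_keys: dict) -> None:
    # For every "target": "source" entry, make the target attribute point at the source Parameter,
    # so both names resolve to the same storage and stay in sync when one of them is reloaded or resized.
    for target_name, source_name in tied_weights_keys.items():
        source = model.get_parameter(source_name)
        *parent_path, attr = target_name.split(".")
        parent = model.get_submodule(".".join(parent_path)) if parent_path else model
        setattr(parent, attr, source)

class ToyLM(nn.Module):
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)

model = ToyLM()
tie_weights_from_mapping(model, {"lm_head.weight": "embed_tokens.weight"})
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()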
@@ -823,24 +823,25 @@ class AlignPreTrainedModel(PreTrainedModel):
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True

@torch.no_grad()
def _init_weights(self, module: nn.Module):
"""Initialize the weights"""
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, AlignModel):
nn.init.xavier_uniform_(module.text_projection.weight)
module.text_projection.bias.data.zero_()
module.temperature.data.fill_(self.config.temperature_init_value)
module.text_projection.bias.zero_()
module.temperature.fill_(self.config.temperature_init_value)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)


@auto_docstring(

@@ -59,6 +59,9 @@ class AlignProcessor(ProcessorMixin):

"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "EfficientNetImageProcessor"
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
valid_processor_kwargs = AlignProcessorKwargs

def __init__(self, image_processor, tokenizer):

@@ -770,6 +770,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_module = []

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor
@@ -797,23 +798,21 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
module.text_projection.weight,
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
)
module.text_projection._is_hf_initialized = True
nn.init.normal_(
module.visual_projection.weight,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
)
module.visual_projection._is_hf_initialized = True
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
module.weight.normal_(mean=0.0, std=self.config.initializer_factor)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
module.weight.normal_(mean=0.0, std=self.config.initializer_factor)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()


class AltCLIPVisionTransformer(nn.Module):

@@ -35,6 +35,10 @@ class AltCLIPProcessor(ProcessorMixin):
The tokenizer is a required input.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast")
tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")

@deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)

@@ -429,7 +429,7 @@ class ApertusModel(ApertusPreTrainedModel):

@auto_docstring
class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}


@@ -434,7 +434,7 @@ class ArceeModel(ArceePreTrainedModel):

@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

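For the `*ForCausalLM` classes here, the practical effect of `_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}` is that the output projection reuses the input embedding matrix when word-embedding tying is enabled for the checkpoint. A quick way to check this on a loaded model; the checkpoint path is a placeholder, and `tie_word_embeddings` is only present on configs that define it:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/any-causal-lm")  # placeholder checkpoint
embed = model.get_input_embeddings().weight
head = model.get_output_embeddings().weight
# Tied weights share storage, so both names point at the same underlying tensor.
print(getattr(model.config, "tie_word_embeddings", None), embed.data_ptr() == head.data_ptr())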
@@ -585,10 +585,11 @@ class AriaTextPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)


@auto_docstring
@@ -608,6 +609,7 @@ class AriaPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaProjector):
@@ -760,7 +762,7 @@ class AriaTextModel(AriaTextPreTrainedModel):

@auto_docstring
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

@@ -1053,7 +1055,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

def __init__(self, config: AriaConfig):
super().__init__(config)

@@ -906,6 +906,10 @@ class AriaProcessor(ProcessorMixin):
A dictionary indicating size conversions for images.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "AriaImageProcessor"
tokenizer_class = "AutoTokenizer"

def __init__(
self,
image_processor=None,
@@ -1187,10 +1191,11 @@ class AriaTextPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}

@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)


class AriaPreTrainedModel(LlamaPreTrainedModel):
@@ -1199,6 +1204,7 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
_can_compile_fullgraph = False  # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = True

@torch.no_grad()
def _init_weights(self, module):
PreTrainedModel._init_weights(self, module)
if isinstance(module, AriaProjector):
@@ -1216,7 +1222,7 @@ class AriaTextModel(LlamaModel):


class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

def __init__(self, config: AriaTextConfig):
super().__init__(config)
@@ -1355,6 +1361,8 @@ class AriaModel(LlavaModel):
"""
)
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

def get_image_features(
self,
pixel_values: torch.FloatTensor,

@@ -67,6 +67,10 @@ class AriaProcessor(ProcessorMixin):
A dictionary indicating size conversions for images.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "AriaImageProcessor"
tokenizer_class = "AutoTokenizer"

def __init__(
self,
image_processor=None,

@@ -300,23 +300,26 @@ class ASTPreTrainedModel(PreTrainedModel):
"attentions": ASTSelfAttention,
}

@torch.no_grad()
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
module.weight.copy_(
nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to(
module.weight.dtype
)
)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, ASTEmbeddings):
module.cls_token.data.zero_()
module.position_embeddings.data.zero_()
module.distillation_token.data.zero_()
module.cls_token.zero_()
module.position_embeddings.zero_()
module.distillation_token.zero_()


@auto_docstring

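The AST hunk keeps the fp32 upcast around `nn.init.trunc_normal_` (there is no half-precision CPU kernel for it) but writes the result back through `copy_` instead of reassigning `.data`. A standalone sketch of that idiom, with a hypothetical helper name:

import torch
from torch import nn

@torch.no_grad()
def init_trunc_normal_(weight: torch.Tensor, std: float = 0.02) -> None:
    # Initialize in fp32 to avoid the missing half-precision `trunc_normal_` kernel on CPU,
    # then copy the result back into the (possibly fp16/bf16) parameter in place.
    weight.copy_(nn.init.trunc_normal_(weight.to(torch.float32), mean=0.0, std=std).to(weight.dtype))

linear = nn.Linear(8, 8).to(torch.bfloat16)
init_trunc_normal_(linear.weight)
print(linear.weight.dtype)  # torch.bfloat16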
@ -223,7 +223,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("layoutlm", "LayoutLMConfig"),
|
||||
("layoutlmv2", "LayoutLMv2Config"),
|
||||
("layoutlmv3", "LayoutLMv3Config"),
|
||||
("layoutxlm", "LayoutLMv2Config"),
|
||||
("led", "LEDConfig"),
|
||||
("levit", "LevitConfig"),
|
||||
("lfm2", "Lfm2Config"),
|
||||
|
||||
@ -41,7 +41,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("audio-spectrogram-transformer", "ASTFeatureExtractor"),
|
||||
("clap", "ClapFeatureExtractor"),
|
||||
("clvp", "ClvpFeatureExtractor"),
|
||||
("csm", "EncodecFeatureExtractor"),
|
||||
("dac", "DacFeatureExtractor"),
|
||||
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
|
||||
("dia", "DiaFeatureExtractor"),
|
||||
@ -50,20 +49,14 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("granite_speech", "GraniteSpeechFeatureExtractor"),
|
||||
("hubert", "Wav2Vec2FeatureExtractor"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
|
||||
("markuplm", "MarkupLMFeatureExtractor"),
|
||||
("mctct", "MCTCTFeatureExtractor"),
|
||||
("mimi", "EncodecFeatureExtractor"),
|
||||
("moonshine", "Wav2Vec2FeatureExtractor"),
|
||||
("moshi", "EncodecFeatureExtractor"),
|
||||
("musicgen", "EncodecFeatureExtractor"),
|
||||
("musicgen_melody", "MusicgenMelodyFeatureExtractor"),
|
||||
("parakeet_ctc", "ParakeetFeatureExtractor"),
|
||||
("parakeet_encoder", "ParakeetFeatureExtractor"),
|
||||
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
|
||||
("pop2piano", "Pop2PianoFeatureExtractor"),
|
||||
("qwen2_5_omni", "WhisperFeatureExtractor"),
|
||||
("qwen2_audio", "WhisperFeatureExtractor"),
|
||||
("qwen3_omni_moe", "WhisperFeatureExtractor"),
|
||||
("seamless_m4t", "SeamlessM4TFeatureExtractor"),
|
||||
("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
|
||||
("sew", "Wav2Vec2FeatureExtractor"),
|
||||
@ -73,7 +66,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("unispeech", "Wav2Vec2FeatureExtractor"),
|
||||
("unispeech-sat", "Wav2Vec2FeatureExtractor"),
|
||||
("univnet", "UnivNetFeatureExtractor"),
|
||||
("voxtral", "WhisperFeatureExtractor"),
|
||||
("wav2vec2", "Wav2Vec2FeatureExtractor"),
|
||||
("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
|
||||
("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
|
||||
|
||||
@ -62,9 +62,7 @@ else:
|
||||
("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("altclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("aria", ("AriaImageProcessor", None)),
|
||||
("aya_vision", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
|
||||
("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
|
||||
("bit", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||
("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
|
||||
@ -75,8 +73,6 @@ else:
|
||||
("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")),
|
||||
("colpali", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
|
||||
("colqwen2", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
|
||||
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
@ -99,10 +95,8 @@ else:
|
||||
("efficientformer", ("EfficientFormerImageProcessor", None)),
|
||||
("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
|
||||
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
|
||||
("emu3", ("Emu3ImageProcessor", None)),
|
||||
("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
|
||||
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
|
||||
("florence2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
|
||||
("fuyu", ("FuyuImageProcessor", "FuyuImageProcessorFast")),
|
||||
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
|
||||
@ -120,13 +114,11 @@ else:
|
||||
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
|
||||
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
|
||||
("internvl", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
|
||||
("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
|
||||
("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")),
|
||||
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
|
||||
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
|
||||
("layoutxlm", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessor")),
|
||||
("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
|
||||
("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
|
||||
("lightglue", ("LightGlueImageProcessor", "LightGlueImageProcessorFast")),
|
||||
@ -149,7 +141,6 @@ else:
|
||||
("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
|
||||
("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
|
||||
("omdet-turbo", ("DetrImageProcessor", "DetrImageProcessorFast")),
|
||||
("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
|
||||
("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
|
||||
("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
|
||||
@ -164,17 +155,14 @@ else:
|
||||
("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
|
||||
("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")),
|
||||
("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
|
||||
("qwen2_5_omni", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("qwen3_omni_moe", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
|
||||
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
|
||||
("sam", ("SamImageProcessor", "SamImageProcessorFast")),
|
||||
("sam2", (None, "Sam2ImageProcessorFast")),
|
||||
("sam2_video", (None, "Sam2ImageProcessorFast")),
|
||||
("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
|
||||
("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
|
||||
("seggpt", ("SegGptImageProcessor", None)),
|
||||
@ -192,14 +180,12 @@ else:
|
||||
("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
|
||||
("timesformer", ("VideoMAEImageProcessor", None)),
|
||||
("timm_wrapper", ("TimmWrapperImageProcessor", None)),
|
||||
("trocr", ("ViTImageProcessor", "ViTImageProcessorFast")),
|
||||
("tvlt", ("TvltImageProcessor", None)),
|
||||
("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
|
||||
("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
|
||||
("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
|
||||
("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
|
||||
("video_llama_3", ("VideoLlama3ImageProcessor", "VideoLlama3ImageProcessorFast")),
|
||||
("video_llava", ("VideoLlavaImageProcessor", None)),
|
||||
("videomae", ("VideoMAEImageProcessor", None)),
|
||||
("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
|
||||
("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
@ -538,9 +524,10 @@ class AutoImageProcessor:
|
||||
)
|
||||
use_fast = False
|
||||
if use_fast:
|
||||
# Check if the fast image processor class exists
|
||||
image_processor_class_fast = get_image_processor_class_from_name(image_processor_type)
|
||||
if image_processor_class_fast is None:
|
||||
for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
|
||||
if image_processor_type in image_processors:
|
||||
break
|
||||
else:
|
||||
image_processor_type = image_processor_type[:-4]
|
||||
use_fast = False
|
||||
logger.warning_once(
|
||||
|
||||
@ -107,7 +107,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("mllama", "MllamaProcessor"),
|
||||
("mm-grounding-dino", "GroundingDinoProcessor"),
|
||||
("moonshine", "Wav2Vec2Processor"),
|
||||
("omdet-turbo", "OmDetTurboProcessor"),
|
||||
("oneformer", "OneFormerProcessor"),
|
||||
("ovis2", "Ovis2Processor"),
|
||||
("owlv2", "Owlv2Processor"),
|
||||
|
||||
@ -72,7 +72,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
),
|
||||
),
|
||||
("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("altclip", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
|
||||
@ -157,7 +156,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("cohere2_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
@ -226,7 +224,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
),
|
||||
),
|
||||
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("donut", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"dpr",
|
||||
(
|
||||
@ -241,7 +238,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("esm", ("EsmTokenizer", None)),
|
||||
("evolla", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"exaone4",
|
||||
(
|
||||
@ -256,13 +252,10 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None),
|
||||
),
|
||||
("flaubert", ("FlaubertTokenizer", None)),
|
||||
("flava", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("florence2", ("BartTokenizer", "BartTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("fsmt", ("FSMTTokenizer", None)),
|
||||
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("fuyu", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"gemma",
|
||||
(
|
||||
@ -311,7 +304,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("got_ocr2", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
@ -322,7 +314,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
|
||||
("granite", ("GPT2Tokenizer", None)),
|
||||
("granite_speech", ("GPT2Tokenizer", None)),
|
||||
("granitemoe", ("GPT2Tokenizer", None)),
|
||||
("granitemoehybrid", ("GPT2Tokenizer", None)),
|
||||
("granitemoeshared", ("GPT2Tokenizer", None)),
|
||||
@ -362,14 +353,11 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
),
|
||||
),
|
||||
("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("kyutai_speech_to_text", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("lfm2", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("lfm2_vl", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"llama",
|
||||
@ -410,7 +398,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("markuplm", ("MarkupLMTokenizer", "MarkupLMTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"mbart",
|
||||
(
|
||||
@ -497,7 +484,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
"NllbTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("nougat", (None, "NougatTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"nystromformer",
|
||||
(
|
||||
@ -519,7 +505,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None),
|
||||
),
|
||||
("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("ovis2", (None, "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
@ -545,7 +530,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
None,
|
||||
),
|
||||
),
|
||||
("perception_lm", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"persimmon",
|
||||
(
|
||||
@ -555,7 +539,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
),
|
||||
("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("phi4_multimodal", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("phobert", ("PhobertTokenizer", None)),
|
||||
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
|
||||
@ -569,7 +552,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
),
|
||||
),
|
||||
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("pop2piano", ("Pop2PianoTokenizer", None)),
|
||||
("prophetnet", ("ProphetNetTokenizer", None)),
|
||||
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
@ -676,7 +658,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
),
|
||||
),
|
||||
("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("smolvlm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
|
||||
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
|
||||
@ -711,7 +692,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("tapas", ("TapasTokenizer", None)),
|
||||
("tapex", ("TapexTokenizer", None)),
|
||||
("transfo-xl", ("TransfoXLTokenizer", None)),
|
||||
("trocr", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"udop",
|
||||
@ -727,14 +707,9 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
"T5TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("video_llama_3", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"vision_text_dual_encoder",
|
||||
("PreTrainedTokenizer", "PreTrainedTokenizerFast" if is_tokenizers_available() else None),
|
||||
),
|
||||
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("vits", ("VitsTokenizer", None)),
|
||||
(
|
||||
@ -750,7 +725,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
|
||||
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
|
||||
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
|
||||
("wav2vec2_with_lm", ("Wav2Vec2CTCTokenizer", None)),
|
||||
("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
@ -1186,7 +1160,7 @@ class AutoTokenizer:
|
||||
The configuration corresponding to the model to register.
|
||||
slow_tokenizer_class ([`PretrainedTokenizer`], *optional*):
|
||||
The slow tokenizer to register.
|
||||
fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
|
||||
fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*):
|
||||
The fast tokenizer to register.
|
||||
"""
|
||||
if slow_tokenizer_class is None and fast_tokenizer_class is None:
|
||||
|
||||
@ -60,7 +60,6 @@ else:
|
||||
("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
|
||||
("sam2_video", "Sam2VideoVideoProcessor"),
|
||||
("smolvlm", "SmolVLMVideoProcessor"),
|
||||
("video_llama_3", "VideoLlama3VideoProcessor"),
|
||||
("video_llava", "VideoLlavaVideoProcessor"),
|
||||
("videomae", "VideoMAEVideoProcessor"),
|
||||
("vjepa2", "VJEPA2VideoProcessor"),
|
||||
@ -292,7 +291,7 @@ class AutoVideoProcessor:
|
||||
|
||||
# Some models have different image processors, e.g. InternVL uses GotOCRImageProcessor
|
||||
# We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on
|
||||
if video_processor_class_from_name(video_processor_class_inferred) is not None:
|
||||
if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values():
|
||||
video_processor_class = video_processor_class_inferred
|
||||
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
|
||||
image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
|
||||
|
||||
@ -826,21 +826,22 @@ class AutoformerPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "past_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, AutoformerSinusoidalPositionalEmbedding):
|
||||
module._init_weight()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
|
||||
# copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
||||
def _update_full_mask(
|
||||
|
||||
@ -338,7 +338,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
|
||||
def __init__(self, config: AyaVisionConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -70,6 +70,10 @@ class AyaVisionProcessor(ProcessorMixin):
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
|
||||
@ -1126,12 +1126,13 @@ class BambaPreTrainedModel(PreTrainedModel):
|
||||
# Note: only supports HybridMambaAttentionDynamicCache
|
||||
_is_stateful = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, BambaMixer):
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
module.dt_bias.fill_(1.0)
|
||||
module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1)))
|
||||
module.D.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -1383,7 +1384,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
|
||||
@ -800,12 +800,13 @@ class BambaPreTrainedModel(PreTrainedModel):
|
||||
# Note: only supports HybridMambaAttentionDynamicCache
|
||||
_is_stateful = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, BambaMixer):
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
module.dt_bias.fill_(1.0)
|
||||
module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1)))
|
||||
module.D.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -329,19 +329,20 @@ class BarkPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = False
|
||||
_supports_flash_attn = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear,)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
@ -910,6 +911,9 @@ class BarkFineModel(BarkPreTrainedModel):
|
||||
# non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._tied_weights_keys = {}
|
||||
for i in range(self.config.n_codes_total - self.config.n_codes_given):
|
||||
self._tied_weights_keys[f"lm_heads.{i}.weight"] = f"input_embeds_layers.{i + 1}.weight"
|
||||
|
||||
# initialize a modified non causal GPT-like model
|
||||
# note that for there is one embedding layer and one lm_head for each codebook of Encodec
|
||||
@ -1025,25 +1029,6 @@ class BarkFineModel(BarkPreTrainedModel):
|
||||
|
||||
return model_embeds
|
||||
|
||||
def _tie_weights(self):
|
||||
if getattr(self.config, "tie_word_embeddings", True):
|
||||
self._tied_weights_keys = []
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
input_embeddings = self.get_input_embeddings()
|
||||
|
||||
for i in range(self.config.n_codes_total - self.config.n_codes_given):
|
||||
# self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
|
||||
self._tie_embedding_weights(output_embeddings[i], input_embeddings[i + 1])
|
||||
self._tied_weights_keys.append(f"lm_heads.{i}.weight")
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings list and the output embeddings list.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if hasattr(module, "_tie_weights"):
|
||||
module._tie_weights()
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1580,14 +1565,6 @@ class BarkModel(BarkPreTrainedModel, GenerationMixin):
|
||||
|
||||
return audio
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings list and the output embeddings list.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if hasattr(module, "_tie_weights"):
|
||||
module._tie_weights()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BarkFineModel",
|
||||
|
||||
@ -49,6 +49,9 @@ class BarkProcessor(ProcessorMixin):
|
||||
|
||||
"""
|
||||
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
attributes = ["tokenizer"]
|
||||
|
||||
preset_shape = {
|
||||
"semantic_prompt": 1, # 1D array of shape (X,)
|
||||
"coarse_prompt": 2, # 2D array of shape (2,X)
|
||||
|
||||
@ -476,19 +476,20 @@ class BartPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
         std = self.config.init_std
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
+            module.weight.fill_(1.0)
+            module.bias.zero_()
 
     @property
     def dummy_inputs(self):
@@ -527,7 +528,7 @@ class BartEncoder(BartPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BartConfig):
         super().__init__(config)
 
         self.dropout = config.dropout
@@ -538,12 +539,9 @@ class BartEncoder(BartPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = BartScaledWordEmbedding(
-                config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
-            )
+        self.embed_tokens = BartScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )
 
         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -674,7 +672,7 @@ class BartDecoder(BartPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BartConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -682,12 +680,9 @@ class BartDecoder(BartPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = BartScaledWordEmbedding(
-                config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
-            )
+        self.embed_tokens = BartScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )
 
         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -899,7 +894,10 @@ class BartDecoder(BartPreTrainedModel):
 
 @auto_docstring
 class BartModel(BartPreTrainedModel):
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = {
+        "decoder.embed_tokens.weight": "shared.weight",
+        "encoder.embed_tokens.weight": "shared.weight",
+    }
 
     def __init__(self, config: BartConfig):
         super().__init__(config)
@@ -908,24 +906,12 @@ class BartModel(BartPreTrainedModel):
         embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
         self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
 
-        self.encoder = BartEncoder(config, self.shared)
-        self.decoder = BartDecoder(config, self.shared)
+        self.encoder = BartEncoder(config)
+        self.decoder = BartDecoder(config)
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    def _tie_weights(self):
-        if self.config.tie_word_embeddings:
-            # Some model checkpoints like "facebook/bart-large-cnn"'s embedding weight is in decoder.embed_tokens, need check here, see issue #36247
-            if self.shared.weight.device == torch.device(
-                "meta"
-            ) and self.decoder.embed_tokens.weight.device != torch.device("meta"):
-                self._tie_embedding_weights(self.encoder.embed_tokens, self.decoder.embed_tokens)
-                self._tie_embedding_weights(self.shared, self.decoder.embed_tokens)
-            else:
-                self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
-                self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)
-
     def get_input_embeddings(self):
         return self.shared
 
@@ -1052,7 +1038,9 @@ class BartModel(BartPreTrainedModel):
 )
 class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.shared.weight",
+    }
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
 
     def __init__(self, config: BartConfig):
@@ -1086,11 +1074,6 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
         new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
         self.register_buffer("final_logits_bias", new_bias)
 
-    def _tie_weights(self):
-        if self.config.tie_word_embeddings:
-            self.model._tie_weights()
-            self._tie_embedding_weights(self.lm_head, self.model.shared)
-
     @auto_docstring
     def forward(
         self,
@@ -1240,8 +1223,6 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
     """
 )
 class BartForSequenceClassification(BartPreTrainedModel):
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
-
     def __init__(self, config: BartConfig, **kwargs):
         super().__init__(config, **kwargs)
         self.model = BartModel(config)
@@ -1374,8 +1355,6 @@ class BartForSequenceClassification(BartPreTrainedModel):
 
 @auto_docstring
 class BartForQuestionAnswering(BartPreTrainedModel):
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
-
     def __init__(self, config):
         super().__init__(config)
 
@@ -1513,7 +1492,9 @@ class BartDecoderWrapper(BartPreTrainedModel):
     """
 )
 class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.decoder.embed_tokens.weight",
+    }
 
     def __init__(self, config):
         config.is_decoder = True
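
Throughout these hunks `_tied_weights_keys` changes from a flat list of tied parameter names to a mapping of "parameter to tie" -> "parameter that provides the weight" (for example `lm_head.weight` -> `model.shared.weight`). The sketch below only illustrates what such a mapping expresses; the `tie_from_mapping` helper and the toy model are invented for this example and are not the transformers implementation.

import torch
from torch import nn


def tie_from_mapping(model: nn.Module, tied_weights_keys: dict) -> None:
    """Hypothetical helper: point each target parameter at its source parameter."""
    for target, source in tied_weights_keys.items():
        *target_path, target_name = target.split(".")
        source_param = model.get_parameter(source)
        target_module = model.get_submodule(".".join(target_path))
        # After this, both attribute paths refer to the same tensor object.
        setattr(target_module, target_name, source_param)


class ToySeq2Seq(nn.Module):
    # Dict form: the key is the parameter to tie, the value owns the weight.
    _tied_weights_keys = {
        "lm_head.weight": "shared.weight",
        "decoder_embed.weight": "shared.weight",
    }

    def __init__(self, vocab_size: int = 10, d_model: int = 4):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, d_model)
        self.decoder_embed = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)


model = ToySeq2Seq()
tie_from_mapping(model, model._tied_weights_keys)
assert model.lm_head.weight is model.shared.weight
assert model.decoder_embed.weight is model.shared.weight
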
@@ -692,31 +692,32 @@ class BeitPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
     _supports_sdpa = True
 
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, BeitEmbeddings):
-            module.cls_token.data.zero_()
+            module.cls_token.zero_()
             if module.mask_token is not None:
-                module.mask_token.data.zero_()
+                module.mask_token.zero_()
             if module.position_embeddings is not None:
-                module.position_embeddings.data.zero_()
+                module.position_embeddings.zero_()
         elif isinstance(module, BeitRelativePositionBias):
-            module.relative_position_bias_table.data.zero_()
+            module.relative_position_bias_table.zero_()
         elif isinstance(module, BeitLayer):
             if module.lambda_1 is not None:
-                module.lambda_1.data.fill_(self.config.layer_scale_init_value)
-                module.lambda_2.data.fill_(self.config.layer_scale_init_value)
+                module.lambda_1.fill_(self.config.layer_scale_init_value)
+                module.lambda_2.fill_(self.config.layer_scale_init_value)
 
 
 @auto_docstring
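
The `_init_weights` hunks all make the same two changes: the method gains a `@torch.no_grad()` decorator and the in-place initializers act on the parameters themselves instead of on `.data`. Below is a minimal, self-contained sketch of that pattern with a made-up module (not Beit itself), just to show that the direct in-place calls are legal inside `no_grad`.

import torch
from torch import nn


class TinyBlock(nn.Module):
    """Made-up module with Linear / Embedding / LayerNorm parameters to initialize."""

    def __init__(self, vocab_size: int = 8, hidden: int = 4, padding_idx: int = 0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden, padding_idx=padding_idx)
        self.proj = nn.Linear(hidden, hidden)
        self.norm = nn.LayerNorm(hidden)


@torch.no_grad()
def init_weights(module: nn.Module, std: float = 0.02) -> None:
    # Same shape as the diffs above: operate on the parameters directly,
    # no `.data`, inside `no_grad` so autograd does not object to in-place ops.
    if isinstance(module, nn.Linear):
        module.weight.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.weight.fill_(1.0)
        module.bias.zero_()


block = TinyBlock()
block.apply(init_weights)  # nn.Module.apply visits every submodule, as PreTrainedModel does
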
@@ -506,16 +506,9 @@ class BertLMPredictionHead(nn.Module):
 
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
 
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -569,21 +562,22 @@ class BertPreTrainedModel(PreTrainedModel):
         "cross_attentions": BertCrossAttention,
     }
 
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, BertLMPredictionHead):
-            module.bias.data.zero_()
+            module.bias.zero_()
 
 
 @dataclass
@@ -770,7 +764,10 @@ class BertModel(BertPreTrainedModel):
     """
 )
 class BertForPreTraining(BertPreTrainedModel):
-    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = {
+        "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
+        "cls.predictions.decoder.bias": "cls.predictions.bias",
+    }
 
     def __init__(self, config):
         super().__init__(config)
@@ -864,7 +861,10 @@ class BertForPreTraining(BertPreTrainedModel):
     """
 )
 class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = {
+        "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
+        "cls.predictions.decoder.bias": "cls.predictions.bias",
+    }
 
     def __init__(self, config):
         super().__init__(config)
@@ -948,7 +948,10 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
 
 @auto_docstring
 class BertForMaskedLM(BertPreTrainedModel):
-    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = {
+        "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
+        "cls.predictions.decoder.bias": "cls.predictions.bias",
+    }
 
     def __init__(self, config):
         super().__init__(config)
@@ -456,21 +456,22 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
         "cross_attentions": BertGenerationCrossAttention,
     }
 
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, BertGenerationOnlyLMHead):
-            module.bias.data.zero_()
+            module.bias.zero_()
 
 
 @auto_docstring(
@@ -629,20 +630,11 @@ class BertGenerationOnlyLMHead(nn.Module):
         super().__init__()
         self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-        self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
         logits = self.decoder(hidden_states)
         return logits
 
-    def _tie_weights(self):
-        # For accelerate compatibility and to not break backward compatibility
-        if self.decoder.bias.device.type == "meta":
-            self.decoder.bias = self.bias
-        else:
-            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-            self.bias = self.decoder.bias
-
 
 @auto_docstring(
     custom_intro="""
@@ -650,7 +642,10 @@ class BertGenerationOnlyLMHead(nn.Module):
     """
 )
 class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+    _tied_weights_keys = {
+        "lm_head.decoder.weight": "bert.embeddings.word_embeddings.weight",
+        "lm_head.decoder.bias": "lm_head.bias",
+    }
 
     def __init__(self, config):
         super().__init__(config)
@@ -1464,16 +1464,9 @@ class BigBirdLMPredictionHead(nn.Module):
 
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
 
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1521,21 +1514,22 @@ class BigBirdPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
 
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, BigBirdLMPredictionHead):
-            module.bias.data.zero_()
+            module.bias.zero_()
 
 
 @dataclass
@@ -1899,7 +1893,10 @@ class BigBirdModel(BigBirdPreTrainedModel):
 
 
 class BigBirdForPreTraining(BigBirdPreTrainedModel):
-    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+    _tied_weights_keys = {
+        "cls.predictions.decoder.bias": "cls.predictions.bias",
+        "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
+    }
 
     def __init__(self, config):
         super().__init__(config)
@@ -1999,7 +1996,10 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
 
 @auto_docstring
 class BigBirdForMaskedLM(BigBirdPreTrainedModel):
-    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+    _tied_weights_keys = {
+        "cls.predictions.decoder.bias": "cls.predictions.bias",
+        "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
+    }
 
     def __init__(self, config):
         super().__init__(config)
@@ -2141,7 +2141,10 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
     """
 )
 class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+    _tied_weights_keys = {
+        "cls.predictions.decoder.bias": "cls.predictions.bias",
+        "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
+    }
 
     def __init__(self, config):
         super().__init__(config)
@@ -1539,19 +1539,20 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
 
     _can_compile_fullgraph = True
 
+    @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
+            module.weight.fill_(1.0)
+            module.bias.zero_()
 
     @property
     def dummy_inputs(self):
@@ -1574,7 +1575,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BigBirdPegasusConfig):
         super().__init__(config)
 
         self.attention_type = config.attention_type
@@ -1592,9 +1593,6 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
             config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
         )
 
-        if embed_tokens is not None:
-            self.embed_tokens.weight = embed_tokens.weight
-
         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
@@ -1849,7 +1847,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BigBirdPegasusConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -1861,9 +1859,6 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
             config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
        )
 
-        if embed_tokens is not None:
-            self.embed_tokens.weight = embed_tokens.weight
-
         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
             config.d_model,
@@ -2075,7 +2070,10 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
 
 @auto_docstring
 class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = {
+        "encoder.embed_tokens.weight": "shared.weight",
+        "decoder.embed_tokens.weight": "shared.weight",
+    }
 
     def __init__(self, config: BigBirdPegasusConfig):
         super().__init__(config)
@@ -2086,8 +2084,8 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
             vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
         )
 
-        self.encoder = BigBirdPegasusEncoder(config, self.shared)
-        self.decoder = BigBirdPegasusDecoder(config, self.shared)
+        self.encoder = BigBirdPegasusEncoder(config)
+        self.decoder = BigBirdPegasusDecoder(config)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -2100,11 +2098,6 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
-    def _tie_weights(self):
-        if self.config.tie_word_embeddings:
-            self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
-            self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)
-
     def get_encoder(self):
         return self.encoder
 
@@ -2213,7 +2206,9 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
 class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.shared.weight",
+    }
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
 
     def __init__(self, config: BigBirdPegasusConfig):
@@ -2247,11 +2242,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
         new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
         self.register_buffer("final_logits_bias", new_bias)
 
-    def _tie_weights(self):
-        if self.config.tie_word_embeddings:
-            self.model._tie_weights()
-            self._tie_embedding_weights(self.lm_head, self.model.shared)
-
     @auto_docstring
     # Ignore copy
     def forward(
@@ -2374,8 +2364,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
     """
 )
 class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
-
     def __init__(self, config: BigBirdPegasusConfig, **kwargs):
         super().__init__(config, **kwargs)
         self.model = BigBirdPegasusModel(config)
@@ -2497,8 +2485,6 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
 
 @auto_docstring
 class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
-
     def __init__(self, config):
         super().__init__(config)
 
@@ -2621,8 +2607,6 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
 
 
 class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
-
     def __init__(self, config):
         config.is_decoder = True
         config.is_encoder_decoder = False
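
In the seq2seq models above, the encoder and decoder `__init__` no longer accept an `embed_tokens` argument; each sub-module builds its own embedding from the config, and the sharing with `self.shared` is declared through `_tied_weights_keys` and resolved after construction. The skeleton below sketches that construction order with invented toy classes; the direct weight assignment at the end only stands in for whatever tying the real `post_init()` performs.

from torch import nn


class ToyEncoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, padding_idx: int):
        super().__init__()
        # Built from the config values only -- no embed_tokens argument anymore.
        self.embed_tokens = nn.Embedding(vocab_size, d_model, padding_idx)


class ToyDecoder(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, padding_idx: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model, padding_idx)


class ToyModel(nn.Module):
    # Declarative description of the sharing that used to be wired up in __init__.
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
    }

    def __init__(self, vocab_size: int = 16, d_model: int = 8, padding_idx: int = 0):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, d_model, padding_idx)
        self.encoder = ToyEncoder(vocab_size, d_model, padding_idx)
        self.decoder = ToyDecoder(vocab_size, d_model, padding_idx)
        # Stand-in for the tying step driven by _tied_weights_keys.
        self.encoder.embed_tokens.weight = self.shared.weight
        self.decoder.embed_tokens.weight = self.shared.weight


model = ToyModel()
assert model.encoder.embed_tokens.weight is model.shared.weight
assert model.decoder.embed_tokens.weight is model.shared.weight
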
@@ -510,7 +510,7 @@ class BioGptModel(BioGptPreTrainedModel):
     """
 )
 class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["output_projection.weight"]
+    _tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
 
     def __init__(self, config):
         super().__init__(config)
@@ -332,7 +332,7 @@ class BioGptModel(BioGptPreTrainedModel):
     """
 )
 class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["output_projection.weight"]
+    _tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
 
     def __init__(self, config):
         super().__init__(config)
@@ -628,6 +628,7 @@ class BitPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     _no_split_modules = ["BitEmbeddings"]
 
+    @torch.no_grad()
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
@@ -433,7 +433,7 @@ class BitNetModel(BitNetPreTrainedModel):
 
 @auto_docstring
 class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
     _tp_plan = None
     _pp_plan = None
 
@@ -114,7 +114,7 @@ class BitNetModel(LlamaModel):
 
 
 class BitNetForCausalLM(LlamaForCausalLM):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
     _tp_plan = None
     _pp_plan = None
 
@@ -438,19 +438,20 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
 
     _can_compile_fullgraph = True
 
+    @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
+            module.weight.fill_(1.0)
+            module.bias.zero_()
 
     @property
     def dummy_inputs(self):
@@ -474,7 +475,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BlenderbotConfig):
         super().__init__(config)
 
         self.dropout = config.dropout
@@ -485,12 +486,9 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = BlenderbotScaledWordEmbedding(
-                config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
-            )
+        self.embed_tokens = BlenderbotScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )
 
         self.embed_positions = BlenderbotLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -623,7 +621,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BlenderbotConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -631,12 +629,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = BlenderbotScaledWordEmbedding(
-                config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
-            )
+        self.embed_tokens = BlenderbotScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )
 
         self.embed_positions = BlenderbotLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -852,7 +847,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
 
 @auto_docstring
 class BlenderbotModel(BlenderbotPreTrainedModel):
-    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = {
+        "encoder.embed_tokens.weight": "shared.weight",
+        "decoder.embed_tokens.weight": "shared.weight",
+    }
 
     def __init__(self, config: BlenderbotConfig):
         super().__init__(config)
@@ -860,8 +858,8 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
         embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
         self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
-        self.encoder = BlenderbotEncoder(config, self.shared)
-        self.decoder = BlenderbotDecoder(config, self.shared)
+        self.encoder = BlenderbotEncoder(config)
+        self.decoder = BlenderbotDecoder(config)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1001,7 +999,9 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
 class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
-    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.shared.weight",
+    }
 
     def __init__(self, config: BlenderbotConfig):
         super().__init__(config)
@@ -1184,7 +1184,9 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
 
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
 class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.decoder.embed_tokens.weight",
+    }
 
     def __init__(self, config):
         config.is_decoder = True
@@ -431,19 +431,20 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
 
     _can_compile_fullgraph = True
 
+    @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
+            module.weight.fill_(1.0)
+            module.bias.zero_()
 
     @property
     def dummy_inputs(self):
@@ -467,7 +468,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BlenderbotSmallConfig):
         super().__init__(config)
 
         self.dropout = config.dropout
@@ -478,10 +479,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
 
         self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -612,7 +610,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BlenderbotSmallConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -620,10 +618,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
 
         self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -838,7 +833,10 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
 
 @auto_docstring
 class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
-    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = {
+        "encoder.embed_tokens.weight": "shared.weight",
+        "decoder.embed_tokens.weight": "shared.weight",
+    }
 
     def __init__(self, config: BlenderbotSmallConfig):
         super().__init__(config)
@@ -846,8 +844,8 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
         self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
 
-        self.encoder = BlenderbotSmallEncoder(config, self.shared)
-        self.decoder = BlenderbotSmallDecoder(config, self.shared)
+        self.encoder = BlenderbotSmallEncoder(config)
+        self.decoder = BlenderbotSmallDecoder(config)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -974,7 +972,9 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
 class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
-    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.shared.weight",
+    }
 
     def __init__(self, config: BlenderbotSmallConfig):
         super().__init__(config)
@@ -1144,7 +1144,9 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
 
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
 class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.decoder.embed_tokens.weight",
+    }
 
     def __init__(self, config):
         config.is_decoder = True
@@ -419,13 +419,14 @@ class BlipPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
     _skip_keys_device_placement = ["past_key_values"]
 
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         factor = self.config.initializer_range
         if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
-            module.weight.data.normal_(mean=0.0, std=factor)
+            module.weight.normal_(mean=0.0, std=factor)
             if hasattr(module, "bias") and module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
 
         if isinstance(module, BlipVisionEmbeddings):
             if hasattr(self.config, "vision_config"):
@@ -443,10 +444,10 @@ class BlipPreTrainedModel(PreTrainedModel):
             )
 
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 
 
 class BlipEncoder(nn.Module):
@@ -797,8 +798,11 @@ class BlipModel(BlipPreTrainedModel):
 )
 class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
     config: BlipConfig
-    _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
     main_input_name = "pixel_values"
+    _tied_weights_keys = {
+        "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
+        "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
+    }  # TODO @arthurzucker check why we need this when for other models, their subPreTrainedModel handle it themselves.
 
     def __init__(self, config: BlipConfig):
         super().__init__(config)
@@ -963,7 +967,10 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
 )
 class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
     config: BlipConfig
-    _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
+    _tied_weights_keys = {
+        "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
+        "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
+    }
 
     def __init__(self, config: BlipConfig):
         super().__init__(config)
@@ -971,7 +978,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
         self.vision_model = BlipVisionModel(config.vision_config)
 
         self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
-
         self.text_decoder = BlipTextLMHeadModel(config.text_config)
 
         self.decoder_pad_token_id = config.text_config.pad_token_id
@@ -473,16 +473,9 @@ class BlipTextLMPredictionHead(nn.Module):
 
         # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
 
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -511,15 +504,16 @@ class BlipTextPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bert"
     _no_split_modules = []
 
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 
 
 # Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
@@ -744,7 +738,10 @@ class BlipTextModel(BlipTextPreTrainedModel):
 
 # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
 class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+    _tied_weights_keys = {
+        "cls.predictions.decoder.bias": "cls.predictions.bias",
+        "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
+    }
 
     def __init__(self, config):
         super().__init__(config)
@@ -53,6 +53,10 @@ class BlipProcessor(ProcessorMixin):
            An instance of ['BertTokenizerFast`]. The tokenizer is a required input.
     """
 
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
+    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
+
     def __init__(self, image_processor, tokenizer, **kwargs):
         tokenizer.return_token_type_ids = False
         super().__init__(image_processor, tokenizer)
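
The processor hunks add `attributes`, `image_processor_class` and `tokenizer_class` as class attributes. The toy mixin below is only a rough, hypothetical illustration of what the `attributes` list is used for (binding positional components to names); it is not the real `ProcessorMixin`.

class ToyProcessorMixin:
    # Subclasses declare which components they bundle.
    attributes = []

    def __init__(self, *components):
        # Bind each positional component to the name declared in `attributes`.
        for name, component in zip(self.attributes, components):
            setattr(self, name, component)


class ToyBlipProcessor(ToyProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")


proc = ToyBlipProcessor("fake-image-processor", "fake-tokenizer")
assert proc.tokenizer == "fake-tokenizer"
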
@@ -263,9 +263,7 @@ class Blip2Config(PreTrainedConfig):
    ```"""
 
     model_type = "blip-2"
-    attribute_map = {
-        "image_token_id": "image_token_index",
-    }
+    attribute_map = {"image_token_id": "image_token_index", "tie_words_embeddings": "use_decoder_only_language_model"}
     sub_configs = {"text_config": AutoConfig, "qformer_config": Blip2QFormerConfig, "vision_config": Blip2VisionConfig}
 
     def __init__(
@@ -409,19 +409,20 @@ class Blip2PreTrainedModel(PreTrainedModel):
     ]
     _skip_keys_device_placement = "past_key_values"
 
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         factor = self.config.initializer_range
 
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data.normal_(mean=0.0, std=factor)
+            module.weight.normal_(mean=0.0, std=factor)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=factor)
+            module.weight.normal_(mean=0.0, std=factor)
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, Blip2VisionEmbeddings):
             nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
             nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
@@ -435,7 +436,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
                 Blip2ForImageTextRetrieval,
             ),
         ):
-            module.query_tokens.data.zero_()
+            module.query_tokens.zero_()
 
 
 # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
@@ -1049,10 +1050,6 @@ class Blip2Model(Blip2PreTrainedModel):
         else:
             language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
 
-        # Update _tied_weights_keys using the base model used.
-        if language_model._tied_weights_keys is not None:
-            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
-
         self.language_model = language_model
 
         # Initialize weights and apply final processing
@@ -1076,11 +1073,6 @@ class Blip2Model(Blip2PreTrainedModel):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def _tie_weights(self):
-        if not self.config.use_decoder_only_language_model:
-            self.language_model.encoder.embed_tokens = self.language_model.shared
-            self.language_model.decoder.embed_tokens = self.language_model.shared
-
     @filter_out_non_signature_kwargs()
     @auto_docstring
     def get_text_features(
@@ -1612,10 +1604,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
         else:
             language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
 
-        # Update _tied_weights_keys using the base model used.
-        if language_model._tied_weights_keys is not None:
-            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
-
         self.language_model = language_model
 
         # Initialize weights and apply final processing
@@ -1639,11 +1627,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.language_model.get_decoder()
 
-    def _tie_weights(self):
-        if not self.config.use_decoder_only_language_model:
-            self.language_model.encoder.embed_tokens = self.language_model.shared
-            self.language_model.decoder.embed_tokens = self.language_model.shared
-
     def _preprocess_accelerate(self):
         r"""
         Some pre-processing hacks to make the model `accelerate` compatible. Check
@@ -60,6 +60,10 @@ class Blip2Processor(ProcessorMixin):
             Number of tokens used by the Qformer as queries, should be same as in model's config.
     """
 
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
+    tokenizer_class = "AutoTokenizer"
+
     def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
         tokenizer.return_token_type_ids = False
         if not hasattr(tokenizer, "image_token"):
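
The `Blip2Config` hunk extends `attribute_map`, which lets an aliased config attribute name resolve to the stored one (for example `image_token_id` -> `image_token_index`). Below is a hand-rolled toy version of that aliasing, just to show the idea rather than the `PreTrainedConfig` internals.

class ToyConfig:
    # Alias -> real attribute, in the spirit of PreTrainedConfig.attribute_map.
    attribute_map = {"image_token_id": "image_token_index"}

    def __init__(self, image_token_index: int = 32000):
        self.image_token_index = image_token_index

    def __getattr__(self, name):
        # Only called when normal lookup fails, so real attributes are untouched.
        mapped = type(self).attribute_map.get(name)
        if mapped is not None:
            return getattr(self, mapped)
        raise AttributeError(name)


cfg = ToyConfig()
assert cfg.image_token_id == cfg.image_token_index == 32000
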
Some files were not shown because too many files have changed in this diff.