Compare commits

..

26 Commits

Author SHA1 Message Date
a9c0f9a62b Add prefix detection 2025-11-07 15:45:17 +00:00
041690ffd0 Add de-initialization mechanism 2025-11-07 10:34:34 +00:00
c85ee385b5 Added de-duplication 2025-11-06 11:51:36 +00:00
25ac7a0c2a Added hashing mechanism (wip) 2025-11-05 21:36:24 +00:00
f6e4cc6230 Removed dead code 2025-11-05 08:43:50 +00:00
3824e5d201 Replace _free_blocks with a proper object BlockManager 2025-11-04 14:22:23 +00:00
efde9d5427 Nit in example 2025-11-04 14:18:41 +00:00
53f953cec4 Fix a bug in the CB memory calculation 2025-11-04 13:58:23 +00:00
dd4e048e75 Reduce the number of benchmarks in the CI (#42008)
Changed how benchmark cfgs are chosen
2025-11-04 14:07:17 +01:00
6ff4fabd9d Correct syntax error in trainer.md (#42001)
A comma is missing between two parameters in the signature of the compute_loss function.
2025-11-04 12:36:54 +00:00
6d4450e341 Fix torch+deepspeed docker file (#41985)
* fix

* delete

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-11-04 10:41:22 +00:00
aee5c2384a DOC Fix typo in argument name: pseudoquant (#41994)
The correct argument name is pseudoquantization. Since there is no error
when passing a wrong argument name (which is arguably an anti-pattern),
this is difficult for users to debug.
2025-11-04 10:48:39 +01:00
5b6c209bc5 [kernels] change import time in KernelConfig (#42004)
* change import time

* style
2025-11-04 10:26:24 +01:00
258c76e4dc Fix run slow v2: empty report when there is only one model (#42002)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-11-04 06:46:21 +01:00
64397a8301 Fixed wrong padding value in OWLv2 (#41938)
* Update image_processing_owlv2_fast.py

fixed padding value

* fixed padding value

* Change padding constant value from 0.5 to 0.0

* Fixed missed padding value in modular_owlv2.py

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
2025-11-03 18:46:28 -05:00
cd309610c0 Integrate colqwen2.5 using colqwen2 modelling code (#40600)
* adding option for 2.5

* minor - arg in conversion script

* getting started on modelling.py

* minor - should've been using modular

* addressing comments + fixing datatype/device _get method

* minor

* committing suggestion

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>

* docs + first test

* ruff fix

* minor fix

* ruff fix

* model fix

* model fix

* fine-grained check, with a hardcoded score from the original Hf implementation.

* minor ruff

* update tests values with CI hardware

* adding 2.5 to conversion script

* Apply style fixes

---------

Co-authored-by: Sahil Kabir <sahilkabir@Sahils-MacBook-Pro.local>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-03 18:31:07 -05:00
dd8f231495 fix 3 failed test cases for video_llama_3 model on Intel XPU (#41931)
* fix 3 failed test cases for video_llama_3 model on Intel XPU

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* adjust format

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2025-11-03 18:18:20 +01:00
1619a3475f fix (CI): Refactor SSH runners (#41991)
* Change ssh runner type

* Add wait step to SSH runner workflow

* Rename wait step to wait2 in ssh-runner.yml

* Remove wait step from ssh-runner.yml

Removed the wait step from the SSH runner workflow.

* Update runner type for single GPU A10 instance

* Update SSH runner version to 1.90.3

* Add sha256sum to ssh-runner workflow

* Update runner type and remove unused steps
2025-11-03 18:16:32 +01:00
ff0f7d6498 More data in benchmarking (#41848)
* Reduce scope of cross-generate

* Rm generate_sall configs

* Workflow benchmarks more

* Prevent crash when FA is not installed
2025-11-03 18:05:26 +01:00
80305364e2 Move the Mi355 to regular docker (#41989)
* Move the Mi355 to regular docker

* Disable gfx950 compilation for FA on AMD
2025-11-03 16:41:06 +01:00
a623cda427 [kernels] Add Tests & CI for kernels (#41765)
* first commit

* add tests

* add kernel config

* add more tests

* add ci

* small fix

* change branch name

* update tests

* nit

* change test name

* revert jobs

* addressing review

* reenable all jobs

* address second review
2025-11-03 16:36:52 +01:00
7d5160bd7a Fix torchcodec version in quantization docker file (#41988)
check

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-11-03 15:27:47 +01:00
22e39dfb31 docs: add continuous batching page (#41847)
* docs: add continuous batching page

* docs(cb): add `generate_batch` example

* docs(cb): add `opentelemetry` and `serving` section

* feat: add `TODO` note about opentelemetry dependency

* docs(cb): add supported features

* docs(cb): add unsupported features

* docs(cb): add `ContinuousBatchingManager` example

* docs(cb): x reference CB in optimizing inference
2025-11-03 15:19:30 +01:00
63fbd50fb4 fix: dict[RopeParameters] to dict[str, RopeParameters] (#41963) 2025-11-03 14:09:27 +00:00
b433ec8b50 test tensor parallel: make tests for dense model more robust (#41968)
* make test forward and backward more robust

* refactor compile part of test tensor parallel

* linting

* pass rank around instead of calling it over and over

* Run slow v2 (#41914)

* Super

* Super

* Super

* Super

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Fix `detectron2` installation in docker files (#41975)

* detectron2 - part 1

* detectron2 - part 2

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Fix `autoawq[kernels]` installation in quantization docker file (#41978)

fix autoawq[kernels]

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* add support for saving encoder only so any parakeet model can be loaded for inference (#41969)

* add support for saving encoder only so any decoder model can be loaded

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* use convolution_bias

* convert modular

* convolution_bias in conversion script

---------

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>

---------

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
2025-11-03 13:56:26 +01:00
3c16c1ae43 Use indices as position_ids in modernebert (#41789)
* Use indices as position_ids in modernebert

* Move position_ids init to the branch
2025-11-03 12:10:24 +01:00
242 changed files with 3391 additions and 4113 deletions

View File

@ -52,7 +52,7 @@ jobs:
commit_id=$GITHUB_SHA
fi
commit_msg=$(git show -s --format=%s | cut -c1-70)
python3 benchmark_v2/run_benchmarks.py -b 32 -s 128 -n 256 --branch-name "$BRANCH_NAME" --commit-id "$commit_id" --commit-message "$commit_msg" --model-id "$MODEL_ID" --log-level INFO --push-result-to-dataset "$DATASET_ID"
python3 benchmark_v2/run_benchmarks.py -b 32 -s 128 -n 256 --level 2 --branch-name "$BRANCH_NAME" --commit-id "$commit_id" --commit-message "$commit_msg" --model-id "$MODEL_ID" --log-level INFO --push-result-to-dataset "$DATASET_ID"
env:
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
PUSH_TO_HUB_TOKEN: ${{ secrets.PUSH_TO_HUB_TOKEN }}

View File

@ -97,7 +97,7 @@ jobs:
latest-torch-deepspeed-docker:
name: "Latest PyTorch + DeepSpeed"
runs-on:
group: aws-g4dn-2xlarge-cache
group: aws-general-8-plus
steps:
-
name: Set up Docker Buildx

View File

@ -21,7 +21,7 @@ jobs:
job: run_models_gpu
slack_report_channel: "#amd-hf-ci"
runner_group: hfc-amd-mi355
docker: huggingface/testing-rocm7.0-preview
docker: huggingface/transformers-pytorch-amd-gpu
ci_event: Scheduled CI (AMD) - mi355
report_repo_id: hf-transformers-bot/transformers-ci-dummy
secrets: inherit
@ -33,7 +33,7 @@ jobs:
job: run_pipelines_torch_gpu
slack_report_channel: "#amd-hf-ci"
runner_group: hfc-amd-mi355
docker: huggingface/testing-rocm7.0-preview
docker: huggingface/transformers-pytorch-amd-gpu
ci_event: Scheduled CI (AMD) - mi355
report_repo_id: hf-transformers-bot/transformers-ci-dummy
secrets: inherit
@ -45,7 +45,7 @@ jobs:
job: run_examples_gpu
slack_report_channel: "#amd-hf-ci"
runner_group: hfc-amd-mi355
docker: huggingface/testing-rocm7.0-preview
docker: huggingface/transformers-pytorch-amd-gpu
ci_event: Scheduled CI (AMD) - mi355
report_repo_id: hf-transformers-bot/transformers-ci-dummy
secrets: inherit

View File

@ -118,3 +118,15 @@ jobs:
report_repo_id: hf-internal-testing/transformers_daily_ci
commit_sha: ${{ github.sha }}
secrets: inherit
kernels-ci:
name: Kernels CI
uses: ./.github/workflows/self-scheduled.yml
with:
job: run_kernels_gpu
slack_report_channel: "#transformers-ci-daily-kernels"
docker: huggingface/transformers-all-latest-gpu
ci_event: Daily CI
report_repo_id: hf-internal-testing/transformers_daily_ci
commit_sha: ${{ github.sha }}
secrets: inherit

View File

@ -102,8 +102,10 @@ jobs:
working-directory: /transformers/tests
run: |
if [ "${{ inputs.job }}" = "run_models_gpu" ]; then
echo "folder_slices=$(python3 ../utils/split_model_tests.py --subdirs '${{ inputs.subdirs }}' --num_splits ${{ env.NUM_SLICES }})" >> $GITHUB_OUTPUT
echo "slice_ids=$(python3 -c 'd = list(range(${{ env.NUM_SLICES }})); print(d)')" >> $GITHUB_OUTPUT
python3 ../utils/split_model_tests.py --subdirs '${{ inputs.subdirs }}' --num_splits ${{ env.NUM_SLICES }} > folder_slices.txt
echo "folder_slices=$(cat folder_slices.txt)" >> $GITHUB_OUTPUT
python3 -c "import ast; folder_slices = ast.literal_eval(open('folder_slices.txt').read()); open('slice_ids.txt', 'w').write(str(list(range(len(folder_slices)))))"
echo "slice_ids=$(cat slice_ids.txt)" >> $GITHUB_OUTPUT
elif [ "${{ inputs.job }}" = "run_trainer_and_fsdp_gpu" ]; then
echo "folder_slices=[['trainer'], ['fsdp']]" >> $GITHUB_OUTPUT
echo "slice_ids=[0, 1]" >> $GITHUB_OUTPUT
@ -336,7 +338,7 @@ jobs:
working-directory: ${{ inputs.working-directory-prefix }}/
run: |
python3 -m pip uninstall -y deepspeed
DS_DISABLE_NINJA=1 DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check
DS_DISABLE_NINJA=1 DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --no-build-isolation --config-settings="--build-option=build_ext" --config-settings="--build-option=-j8" --no-cache -v --disable-pip-version-check
# To avoid unknown test failures
- name: Pre build DeepSpeed *again* (for nightly & Past CI)
@ -346,7 +348,7 @@ jobs:
python3 -m pip uninstall -y deepspeed
rm -rf DeepSpeed
git clone https://github.com/deepspeedai/DeepSpeed && cd DeepSpeed && rm -rf build
DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check
DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install . --no-build-isolation --config-settings="--build-option=build_ext" --config-settings="--build-option=-j8" --no-cache -v --disable-pip-version-check
- name: NVIDIA-SMI
run: |
@ -475,6 +477,70 @@ jobs:
name: ${{ env.machine_type }}_run_quantization_torch_gpu_${{ env.matrix_folders }}_test_reports
path: /transformers/reports/${{ env.machine_type }}_run_quantization_torch_gpu_${{ matrix.folders }}_test_reports
run_kernels_gpu:
if: ${{ inputs.job == 'run_kernels_gpu' }}
name: Kernel tests
strategy:
fail-fast: false
matrix:
machine_type: [aws-g5-4xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
image: ${{ inputs.docker }}
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- name: Update clone
working-directory: /transformers
run: git fetch && git checkout ${{ inputs.commit_sha || github.sha }}
- name: Reinstall transformers in edit mode
working-directory: /transformers
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .[testing]
- name: Install kernels
working-directory: /transformers
run: python3 -m pip install -U kernels
- name: NVIDIA-SMI
run: nvidia-smi
- name: Environment
working-directory: /transformers
run: python3 utils/print_env.py
- name: Show installed libraries and their versions
working-directory: /transformers
run: pip freeze
- name: Set `machine_type` for report and artifact names
working-directory: /transformers
shell: bash
run: |
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
else
machine_type=${{ matrix.machine_type }}
fi
echo "machine_type=$machine_type" >> $GITHUB_ENV
- name: Run kernel tests on GPU
working-directory: /transformers
run: |
python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_kernels_gpu_test_reports tests/kernels/test_kernels.py
- name: Failure short reports
if: ${{ failure() }}
continue-on-error: true
run: cat /transformers/reports/${{ env.machine_type }}_run_kernels_gpu_test_reports/failures_short.txt
- name: "Test suite reports artifacts: ${{ env.machine_type }}_run_kernels_gpu_test_reports"
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: ${{ env.machine_type }}_run_kernels_gpu_test_reports
path: /transformers/reports/${{ env.machine_type }}_run_kernels_gpu_test_reports
run_extract_warnings:
# Let's only do this for the job `run_models_gpu` to simplify the (already complex) logic.
if: ${{ always() && inputs.job == 'run_models_gpu' }}
@ -527,6 +593,7 @@ jobs:
run_examples_gpu,
run_torch_cuda_extensions_gpu,
run_quantization_torch_gpu,
run_kernels_gpu,
run_extract_warnings
]
if: always() && !cancelled()

View File

@ -4,7 +4,7 @@ on:
workflow_dispatch:
inputs:
runner_type:
description: 'Type of runner to test (a10 or t4)'
description: 'Type of runner to test (a10)'
required: true
docker_image:
description: 'Name of the Docker image'
@ -36,14 +36,10 @@ jobs:
NUM_GPUS: ${{ github.event.inputs.num_gpus }}
RUNNER_TYPE: ${{ github.event.inputs.runner_type }}
run: |
if [[ "$NUM_GPUS" == "single" && "$RUNNER_TYPE" == "t4" ]]; then
echo "RUNNER=aws-g4dn-4xlarge-cache" >> $GITHUB_ENV
elif [[ "$NUM_GPUS" == "multi" && "$RUNNER_TYPE" == "t4" ]]; then
echo "RUNNER=aws-g4dn-12xlarge-cache" >> $GITHUB_ENV
elif [[ "$NUM_GPUS" == "single" && "$RUNNER_TYPE" == "a10" ]]; then
echo "RUNNER=aws-g5-4xlarge-cache" >> $GITHUB_ENV
if [[ "$NUM_GPUS" == "single" && "$RUNNER_TYPE" == "a10" ]]; then
echo "RUNNER=aws-g5-4xlarge-cache-ssh" >> $GITHUB_ENV
elif [[ "$NUM_GPUS" == "multi" && "$RUNNER_TYPE" == "a10" ]]; then
echo "RUNNER=aws-g5-12xlarge-cache" >> $GITHUB_ENV
echo "RUNNER=aws-g5-12xlarge-cache-ssh" >> $GITHUB_ENV
else
echo "RUNNER=" >> $GITHUB_ENV
fi
@ -61,8 +57,6 @@ jobs:
group: ${{ needs.get_runner.outputs.RUNNER }}
container:
image: ${{ github.event.inputs.docker_image }}
options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- name: Update clone
working-directory: /transformers
@ -106,7 +100,7 @@ jobs:
else
echo "SLACKCHANNEL=${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}" >> $GITHUB_ENV
fi
- name: Tailscale # In order to be able to SSH when a test fails
uses: huggingface/tailscale-action@main
with:

View File

@ -1,8 +1,11 @@
import hashlib
import itertools
import json
import logging
from typing import Any
from transformers.utils.import_utils import is_flash_attn_2_available
KERNELIZATION_AVAILABLE = False
try:
@ -18,6 +21,16 @@ logger = logging.getLogger(__name__)
class BenchmarkConfig:
"""Configuration for a single benchmark scenario."""
all_attn_implementations = [
("flash_attention_2", None),
("eager", None),
("sdpa", "math"),
("sdpa", "flash_attention"),
("flex_attention", None),
]
all_compiled_modes = [None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]
def __init__(
self,
warmup_iterations: int = 5,
@ -59,6 +72,13 @@ class BenchmarkConfig:
def check_validity(self, skip_validity_check: bool = False) -> None:
if skip_validity_check:
return
# Check FA is installed
if self.attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
logger.warning(
"Flash attention is not installed. Defaulting to SDPA w/ flash attention backend."
)
self.attn_implementation = "sdpa"
self.sdpa_backend = "flash_attention"
# Flash attention does not support compile mode, so we turn it off # FIXME: it would be better to support it
is_fa = self.attn_implementation == "flash_attention_2"
is_fa |= self.attn_implementation == "sdpa" and self.sdpa_backend == "flash_attention"
@ -127,88 +147,68 @@ class BenchmarkConfig:
)
def cross_generate_configs(
attn_impl_and_sdpa_backend: list[tuple[str, str | None]],
compiled_mode: list[str | None],
kernelized: list[bool],
warmup_iterations: int = 5,
measurement_iterations: int = 20,
batch_size: int = 1,
sequence_length: int = 128,
num_tokens_to_generate: int = 128,
gpu_monitoring: bool = True,
def adapt_configs(
configs: list[BenchmarkConfig],
warmup_iterations: int | list[int] = 5,
measurement_iterations: int | list[int] = 20,
batch_size: int | list[int] = 1,
sequence_length: int | list[int] = 128,
num_tokens_to_generate: int | list[int] = 128,
gpu_monitoring: bool | list[bool] = True,
) -> list[BenchmarkConfig]:
# Create kwargs common to all configs
kwargs = {
"warmup_iterations": warmup_iterations,
"measurement_iterations": measurement_iterations,
"batch_size": batch_size,
"sequence_length": sequence_length,
"num_tokens_to_generate": num_tokens_to_generate,
"gpu_monitoring": gpu_monitoring,
}
# Cross-generate all combinations of attn_implementation, compiled_mode, and kernelized
configs = []
for attn_implementation, sdpa_backend in list(dict.fromkeys(attn_impl_and_sdpa_backend)):
for cm in list(dict.fromkeys(compiled_mode)):
for kernelize_on in list(dict.fromkeys(kernelized)):
config = BenchmarkConfig(
attn_implementation=attn_implementation,
sdpa_backend=sdpa_backend,
compile_mode=cm,
kernelize=kernelize_on,
**kwargs,
)
configs.append(config)
return configs
def generate_all_configs(
warmup_iterations: int = 5,
measurement_iterations: int = 20,
batch_size: int = 1,
sequence_length: int = 128,
num_tokens_to_generate: int = 128,
gpu_monitoring: bool = True,
) -> list[BenchmarkConfig]:
all_attn_implementations = [
("flash_attention_2", None),
("eager", None),
("sdpa", "math"),
("sdpa", "flash_attention"),
("flex_attention", None),
]
return cross_generate_configs(
attn_impl_and_sdpa_backend=all_attn_implementations,
compiled_mode=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
kernelized=[False, KERNELIZATION_AVAILABLE],
warmup_iterations=warmup_iterations,
measurement_iterations=measurement_iterations,
batch_size=batch_size,
sequence_length=sequence_length,
num_tokens_to_generate=num_tokens_to_generate,
gpu_monitoring=gpu_monitoring,
parameters = (
x if isinstance(x, list) else [x]
for x in [
warmup_iterations,
measurement_iterations,
batch_size,
sequence_length,
num_tokens_to_generate,
gpu_monitoring,
]
)
iterator = itertools.product(*parameters)
adapted_configs = []
for warmup_iters, measurement_iters, bs, seqlen, ntok, monitor in iterator:
for config in configs:
config = config.to_dict()
config["warmup_iterations"] = warmup_iters
config["measurement_iterations"] = measurement_iters
config["batch_size"] = bs
config["sequence_length"] = seqlen
config["num_tokens_to_generate"] = ntok
config["gpu_monitoring"] = monitor
adapted_configs.append(BenchmarkConfig.from_dict(config))
return adapted_configs
def generate_main_configs(
warmup_iterations: int = 5,
measurement_iterations: int = 20,
batch_size: int = 1,
sequence_length: int = 128,
num_tokens_to_generate: int = 128,
) -> list[BenchmarkConfig]:
# Create kwargs common to all configs
kwargs = {
"warmup_iterations": warmup_iterations,
"measurement_iterations": measurement_iterations,
"batch_size": batch_size,
"sequence_length": sequence_length,
"num_tokens_to_generate": num_tokens_to_generate,
}
return [ # TODO: test max-autotune instead of default
BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", gpu_monitoring=False, **kwargs),
BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", gpu_monitoring=True, **kwargs),
BenchmarkConfig(attn_implementation="eager", compile_mode="default", gpu_monitoring=True, **kwargs),
BenchmarkConfig(attn_implementation="flash_attention_2", gpu_monitoring=True, **kwargs),
]
def get_config_by_level(level: int) -> list[BenchmarkConfig]:
configs = []
# Early return if level is 3 or above: we generate all combinations of configs, maybe even w/ all compile modes
if level >= 3:
for attn_implementation, sdpa_backend in BenchmarkConfig.all_attn_implementations:
# Usually there is not much to gain by compiling with other modes, but we allow it for level 4
compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
for cm in compile_modes:
for kernelize_on in [False, KERNELIZATION_AVAILABLE]:
configs.append(
BenchmarkConfig(
attn_implementation=attn_implementation,
sdpa_backend=sdpa_backend,
compile_mode=cm,
kernelize=kernelize_on,
)
)
return configs
# Otherwise, we add the configs for the given level
if level >= 0:
configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default"))
if level >= 1:
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
configs.append(BenchmarkConfig(attn_implementation="eager", compile_mode="default"))
if level >= 2:
configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
return configs
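Taken together, the refactor replaces `cross_generate_configs`/`generate_all_configs` with two helpers: `get_config_by_level` picks the attention/compile/kernelize combinations for a coverage level, and `adapt_configs` fans those configs out over the requested batch sizes, sequence lengths and token counts. A minimal sketch of how the two compose — the argument values below are illustrative, not taken from this diff:

```py
# Sketch of the new level-based config generation; argument values are illustrative only.
from framework.benchmark_config import adapt_configs, get_config_by_level

# Level 2 yields one config per attention implementation, plus kernelized variants.
configs = get_config_by_level(level=2)

# Expand the selected configs over every (batch_size, sequence_length, num_tokens) combination.
configs = adapt_configs(
    configs,
    warmup_iterations=5,
    measurement_iterations=20,
    batch_size=[1, 8],
    sequence_length=[128, 512],
    num_tokens_to_generate=256,
    gpu_monitoring=False,
)
print(f"{len(configs)} benchmark configs to run")
```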

View File

@ -23,7 +23,7 @@ import logging
import sys
import uuid
from framework.benchmark_config import BenchmarkConfig, generate_all_configs, generate_main_configs
from framework.benchmark_config import adapt_configs, get_config_by_level
from framework.benchmark_runner import BenchmarkRunner
@ -40,7 +40,14 @@ if __name__ == "__main__":
parser.add_argument("--sequence-length", "-s", type=int, nargs="+", help="Sequence length")
parser.add_argument("--num-tokens-to-generate", "-n", type=int, nargs="+", help="Number of tokens to generate")
parser.add_argument("--cross-generate", action="store_true", help="Cross-generate all combinations of configs")
parser.add_argument(
"--level",
type=int,
default=1,
help="Level of coverage for the benchmark. 0: only the main config, 1: a few important configs, 2: a config for"
" each attn implementation an option, 3: cross-generate all combinations of configs, 4: cross-generate all"
" combinations of configs w/ all compile modes",
)
parser.add_argument("--num-tokens-to-profile", "-p", type=int, default=0, help="Number of tokens to profile")
parser.add_argument("--branch-name", type=str, help="Git branch name")
@ -79,64 +86,24 @@ if __name__ == "__main__":
"At least one of the arguments --batch-size, --sequence-length, or --num-tokens-to-generate is required"
)
# If there is only one (batch_size, sequence_length, num_tokens_to_generate), we benchmark across configs
elif len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 1:
if args.cross_generate:
benchmark_configs = generate_all_configs(
warmup_iterations=args.warmup,
measurement_iterations=args.iterations,
batch_size=args.batch_size[0],
sequence_length=args.sequence_length[0],
num_tokens_to_generate=args.num_tokens_to_generate[0],
gpu_monitoring=not args.no_gpu_monitoring,
)
else:
benchmark_configs = generate_main_configs(
warmup_iterations=args.warmup,
measurement_iterations=args.iterations,
batch_size=args.batch_size[0],
sequence_length=args.sequence_length[0],
num_tokens_to_generate=args.num_tokens_to_generate[0],
)
# Otherwise, we benchmark across all combinations of dimensions
else:
main_config = generate_main_configs(
warmup_iterations=args.warmup,
measurement_iterations=args.iterations,
batch_size=args.batch_size[0],
sequence_length=args.sequence_length[0],
num_tokens_to_generate=args.num_tokens_to_generate[0],
)[0]
benchmark_configs = []
for num_tokens_to_generate in args.num_tokens_to_generate:
for sequence_length in args.sequence_length:
for batch_size in args.batch_size:
cfg_dict = main_config.to_dict()
cfg_dict["batch_size"] = batch_size
cfg_dict["sequence_length"] = sequence_length
cfg_dict["num_tokens_to_generate"] = num_tokens_to_generate
cfg_dict.pop("name")
benchmark_configs.append(BenchmarkConfig.from_dict(cfg_dict))
runner = BenchmarkRunner(
logger,
args.output_dir,
args.branch_name,
args.commit_id,
args.commit_message,
# Get the configs for the given coverage level
configs = get_config_by_level(args.level)
# Adapt the configs to the given arguments
configs = adapt_configs(
configs,
args.warmup,
args.iterations,
args.batch_size,
args.sequence_length,
args.num_tokens_to_generate,
not args.no_gpu_monitoring,
)
runner = BenchmarkRunner(logger, args.output_dir, args.branch_name, args.commit_id, args.commit_message)
timestamp, results = runner.run_benchmarks(
args.model_id,
benchmark_configs,
args.num_tokens_to_profile,
pretty_print_summary=True,
args.model_id, configs, args.num_tokens_to_profile, pretty_print_summary=True
)
dataset_id = args.push_result_to_dataset
if dataset_id is not None and len(results) > 0:
runner.push_results_to_hub(
dataset_id,
results,
timestamp,
)
runner.push_results_to_hub(dataset_id, results, timestamp)

View File

@ -39,7 +39,7 @@ RUN python3 -m pip install --no-cache-dir "torchcodec==0.5"
# Install flash attention from source. Tested with commit 6387433156558135a998d5568a9d74c1778666d8
RUN git clone https://github.com/ROCm/flash-attention/ -b tridao && \
cd flash-attention && \
GPU_ARCHS="gfx942;gfx950" python setup.py install
# GPU_ARCHS builds for MI300, MI325 and MI355
GPU_ARCHS="gfx942" python setup.py install
# GPU_ARCHS builds for MI300, MI325 but not MI355: we would need to add `;gfx950` but it takes too long to build.
RUN python3 -m pip install --no-cache-dir einops

View File

@ -21,7 +21,7 @@ RUN python3 -m pip install --no-cache-dir './transformers[deepspeed-testing]' 'p
# Install latest release PyTorch
# (PyTorch must be installed before pre-compiling any DeepSpeed c++/cuda ops.)
# (https://www.deepspeed.ai/tutorials/advanced-install/#pre-install-deepspeed-ops)
RUN python3 -m pip uninstall -y torch torchvision torchaudio && python3 -m pip install --no-cache-dir -U torch==$PYTORCH torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA
RUN python3 -m pip uninstall -y torch torchvision torchaudio torchcodec && python3 -m pip install --no-cache-dir -U torch==$PYTORCH torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate
@ -43,7 +43,7 @@ RUN python3 -m pip uninstall -y deepspeed
# This has to be run (again) inside the GPU VMs running the tests.
# The installation works here, but some tests fail, if we don't pre-build deepspeed again in the VMs running the tests.
# TODO: Find out why test fail.
RUN DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1
RUN DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --no-build-isolation --config-settings="--build-option=build_ext" --config-settings="--build-option=-j8" --no-cache -v --disable-pip-version-check 2>&1
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
RUN python3 -m pip uninstall -y kernels

View File

@ -24,7 +24,7 @@ RUN [ ${#PYTORCH} -gt 0 ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch';
RUN echo torch=$VERSION
# `torchvision` and `torchaudio` should be installed along with `torch`, especially for nightly build.
# Currently, let's just use their latest releases (when `torch` is installed with a release version)
RUN python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA
RUN python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate

View File

@ -119,6 +119,8 @@
title: Tools
- local: transformers_as_backend
title: Inference server backends
- local: continuous_batching
title: Continuous Batching
title: Inference
- isExpanded: false
sections:

View File

@ -0,0 +1,194 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Continuous Batching
Continuous Batching (CB) is an advanced technique to optimize the inference of transformer models by dynamically grouping multiple requests into batches. This approach maximizes GPU utilization and throughput, specifically for workloads with many variable-length inputs.
We are particularly interested in having Continuous Batching in transformers for the following use cases:
- Evaluation of models on large datasets with variable-length inputs
- Generating outputs for multiple sequences for GRPO policies
CB is what makes inference engines like vLLM or SGLang efficient. That being said, transformers does not aim to be a production-ready inference engine, but a complete framework for model development. For this reason, CB is available in `transformers serve`.
If you are not familiar with some of the core concepts CB is built upon, we invite you to read the associated blog post: [Continuous Batching: Efficient Inference for Large Language Models](https://huggingface.co/blog/continuous-batching). _broken link for now_
## API Reference
## Usage Examples
The main way to use CB in transformers is via the `generate_batch` method.
Unlike `generate`, CB takes already tokenized inputs, known as input IDs. Each sequence of input IDs is represented as a list of integers, in Python: `list[int]`. Since the inputs are already tokenized, you are responsible for tokenizing your prompts beforehand, as shown in the examples below.
For a more detailed example, please refer to: [examples/continuous_batching](./path/to/example)
### `generate_batch` example
We have created a `ContinuousMixin` that is inherited by `GenerationMixin` so that all autoregressive text models support CB.
This adds the `generate_batch` method to all models that inherit from `GenerationMixin`.
You can use it as follows:
```py
import datasets
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="sdpa_paged",
    device_map="cuda",  # if you need cuda
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")

# prepare a batch of inputs
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(100))  # keep a small subset of the dataset for the example
tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

generation_config = GenerationConfig(
    max_new_tokens=32,
    use_cuda_graph=False,  # Not supported for simple version
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    max_batch_tokens=512,  # max number of tokens in a batch, this is just a default value you should tune based on your hardware
)

batch_outputs = model.generate_batch(
    inputs=simple_batch_inputs,
    generation_config=generation_config,
)
for request_id, output in batch_outputs.items():
    generated_text = tokenizer.decode(output.generated_tokens, skip_special_tokens=True)
    print(f"Request {request_id} output: {generated_text}")
```
### `ContinuousBatchingManager` example
If you want more control over how requests are scheduled with CB, you can use the `ContinuousBatchingManager` class directly.
This is what we use in `transformers serve`, because requests arrive asynchronously and we can leverage the asynchronous nature of the CB process to make things more efficient.
Under the hood, the `ContinuousBatchingManager` creates a background thread that receives inputs from a Python `queue.Queue`, which it uses to pick the requests to batch in each forward pass.
Note that the manager is thread-safe!
```py
import datasets
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.generation.continuous_batching import RequestStatus
MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="sdpa_paged",
    device_map="cuda",  # if you need cuda
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")

# prepare a batch of inputs
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(100))  # keep a small subset of the dataset for the example
tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

generation_config = GenerationConfig(
    max_new_tokens=32,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    max_batch_tokens=512,
)

# initialize the manager, available method thanks to the `ContinuousMixin`
manager = model.init_continuous_batching(generation_config=generation_config)
# start the background thread
manager.start()

# this is for demonstration purposes only, in practice this is most useful to do concurrently
for i, input_ids in enumerate(simple_batch_inputs):
    request_id = manager.add_request(input_ids=input_ids, request_id=f"request_{i}")  # if you do not specify a request_id, one will be generated for you

# Can be done in another thread
for id, request in manager.get_result():
    generated_text = tokenizer.decode(request.generated_tokens, skip_special_tokens=True)
    print(f"Request {id} output: {generated_text}")

# you can also get results for a specific request id
result = manager.get_result(request_id="request_5")  # this is blocking and will wait for the result to be ready

# or get results for a request that is streaming
manager.add_request(
    input_ids=simple_batch_inputs[0],
    request_id="streaming_request",
    stream=True,
)
for chunk in manager.request_id_iter(request_id="streaming_request"):
    generated_text = tokenizer.decode(chunk.generated_tokens, skip_special_tokens=True)
    print(generated_text)
    # FIXME: stop iteration in `request_id_iter` when finished instead of doing it externally
    if chunk.status == RequestStatus.FINISHED:
        break

# stop the background thread before exiting the process
manager.stop()
```
## Supported & Unsupported Features
### Supported Features
- Dynamic scheduling of variable-length requests
- Chunked prefill
- Paged Attention Cache
- Sliding window attention
- Chat templates
### Unsupported Features
At the moment, the following features are not supported with CB. We plan to add support for them:
- Prefix caching
- Beam search
- Tool calling
Other features are unplanned, but we might consider adding them depending on community requests:
- MTP (multi token prediction)
- Medusa
## Performance Considerations
## Integration with Serving
You can use CB in `transformers serve` by passing the `--continuous-batching` flag when starting the server.
## Monitoring
We have added `opentelemetry` support to Continuous Batching to help you monitor its performance in production. To enable it, you need to install the `open-telemetry` extra when installing `transformers`:
```sh
# this installs `opentelemetry-api`, `opentelemetry-sdk` and `opentelemetry-exporter-otlp`
pip install transformers[open-telemetry]
```
This will enable traces and metrics collection in CB. You will then have to set up a backend to collect and visualize the traces and metrics.
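The spans and metrics themselves are emitted by transformers; exporting them relies on the standard OpenTelemetry SDK. Below is a minimal sketch of wiring an OTLP backend — the endpoint and service name are illustrative assumptions, not something this page prescribes:

```py
# Generic OpenTelemetry OTLP setup; the endpoint and service name are illustrative assumptions.
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

provider = TracerProvider(resource=Resource.create({"service.name": "transformers-cb"}))
provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4317", insecure=True)))
trace.set_tracer_provider(provider)
# Traces emitted by continuous batching are now exported to the collector listening on localhost:4317.
```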

View File

@ -393,3 +393,9 @@ model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", quantization_config=quant_config, device_map="auto"
)
```
## Continuous Batching
When serving LLMs for inference, you may have multiple requests arriving at different times. Continuous Batching (CB) is a technique that groups incoming requests into batches to maximize GPU utilization and throughput.
See the [Continuous Batching](./continuous_batching) guide for more details on how to use CB in transformers.

View File

@ -158,6 +158,24 @@ print("Retrieval scores (query x image):")
print(scores)
```
You can also use checkpoints for `ColQwen2.5` that are **compatible with the ColQwen2 architecture**. This version of the model uses [Qwen2_5_VL](./qwen2_5_vl) as the backbone.
```python
import torch
from transformers import ColQwen2ForRetrieval, ColQwen2Processor
from transformers.utils.import_utils import is_flash_attn_2_available
model_name = "Sahil-Kabir/colqwen2.5-v0.2-hf" # An existing compatible checkpoint
model = ColQwen2ForRetrieval.from_pretrained(
model_name,
dtype=torch.bfloat16,
device_map="auto",
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa"
)
processor = ColQwen2Processor.from_pretrained(model_name)
```
## Notes
- [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image.

View File

@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
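As a rough illustration of that batched matmul — the shapes, activation, and names below are assumptions for the sketch, not the actual Llama4 code:

```python
import torch
import torch.nn.functional as F

num_experts, hidden_size, expert_dim, tokens_per_expert = 4, 64, 32, 8
# Packed parameter: gate and up projections fused along the last dimension.
gate_up_proj = torch.nn.Parameter(torch.empty(num_experts, hidden_size, 2 * expert_dim).normal_(std=0.02))

# Hidden states grouped per expert: (num_experts, tokens_per_expert, hidden_size).
hidden_states = torch.randn(num_experts, tokens_per_expert, hidden_size)

# A single batched matmul computes the gate and up outputs for every expert at once.
gate_up = torch.bmm(hidden_states, gate_up_proj)  # (num_experts, tokens, 2 * expert_dim)
gate, up = gate_up.chunk(2, dim=-1)               # split back into the two projections
out = F.silu(gate) * up                           # (num_experts, tokens, expert_dim)
print(out.shape)
```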

View File

@ -40,7 +40,7 @@ You can choose between MXFP4 and NVFP4 with `FPQuantConfig(forward_dtype="mxfp4"
A **Blackwell-generation GPU is required** to run the kernels. Runtime support for FP-Quant is implemented through the [QuTLASS](https://github.com/IST-DASLab/qutlass) library and a lightweight PyTorch interface lib [`fp_quant`](https://github.com/IST-DASLab/FP-Quant/tree/master/inference_lib). We recommend installing the former **from source** and the latter with `pip install fp_quant`.
Users **without a Blackwell-generation GPU** , can use the method with `quantization_config=FPQuantConfig(pseudoquant=True)` without having to install [QuTLASS](https://github.com/IST-DASLab/qutlass). This would provide no speedups but would fully emulate the effect of quantization.
Users **without a Blackwell-generation GPU** , can use the method with `quantization_config=FPQuantConfig(pseudoquantization=True)` without having to install [QuTLASS](https://github.com/IST-DASLab/qutlass). This would provide no speedups but would fully emulate the effect of quantization.
> [!TIP]
> Find models pre-quantized with FP-Quant in the official ISTA-DASLab [collection](https://huggingface.co/collections/ISTA-DASLab/fp-quant-6877c186103a21d3a02568ee).

View File

@ -187,7 +187,7 @@ from torch import nn
from transformers import Trainer
class CustomTrainer(Trainer):
def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None):
labels = inputs.pop("labels")
# forward pass
outputs = model(**inputs)

View File

@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.

View File

@ -16,6 +16,7 @@ import argparse
import contextlib
import json
import os
import random
import time
from typing import Optional
@ -29,7 +30,6 @@ from transformers.generation import GenerationConfig
from transformers.generation.continuous_batching.requests import logger
# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
SLIDING_WINDOW = 0
MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
FORCE_MAX_LENGTH = False # should be False unless you are debugging sliding window features
@ -193,6 +193,8 @@ if __name__ == "__main__":
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
parser.add_argument("--log-level", type=str, default="INFO")
parser.add_argument("--output-file", type=str, default=None)
@ -242,7 +244,18 @@ if __name__ == "__main__":
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(args.samples))
simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
def random_prefix() -> str:
if not args.add_prefix:
return ""
prefixes = [
"Math and reasonning problems are very important to the world. This is a problem, and then you will find the answer.\n",
"We all know that reasonning can be taught by answering questions, often illustrated with examples. Here is one and its solution, hopefully you will enjoy it!\n",
"Reasonning a very good metric of intelligence, hence it is regularly trained and tested in both children and AI model like LLMs. This test can look like a math or a logical problem, a riddle or pattern detection task. For instance, this is one of those test. You will find it and the solution associated after. Here it goes:\n",
] # fmt: skip
return random.choice(prefixes)
random.seed(0)
simple_batch_inputs = [tokenizer(random_prefix() + item["question"])["input_ids"] for item in dataset]
# Prepare generation config
generation_config = GenerationConfig(

View File

@ -392,6 +392,7 @@ extras["torchhub"] = deps_list(
extras["benchmark"] = deps_list("optimum-benchmark")
# OpenTelemetry dependencies for metrics collection in continuous batching
# TODO: refactor this to split API and SDK; SDK and exporter should only be needed to run code that collects metrics whereas API is what people will need to instrument their code and handle exporter themselves
extras["open-telemetry"] = deps_list("opentelemetry-api") + ["opentelemetry-exporter-otlp", "opentelemetry-sdk"]
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py

View File

@ -876,7 +876,7 @@ class PreTrainedConfig(PushToHubMixin):
if hasattr(self, "quantization_config"):
serializable_config_dict["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
self.dict_dtype_to_str(serializable_config_dict)
@ -910,7 +910,7 @@ class PreTrainedConfig(PushToHubMixin):
if hasattr(self, "quantization_config"):
output["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
self.dict_dtype_to_str(output)

View File

@ -1,141 +0,0 @@
# coding=utf-8
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
from .utils import is_torch_available
if is_torch_available():
import torch
def _build_checkpoint_conversion_mapping():
mapping = {
"mixtral": [
WeightConverter(
source_keys=[
"block_sparse_moe.experts.*.w1.weight",
"block_sparse_moe.experts.*.w3.weight",
], # you give me a list of 2 keys, I collect a list of a list of tensors
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
operations=[
MergeModulelist(
dim=0
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
),
WeightConverter(
source_keys=[
"block_sparse_moe.experts.*.w2.weight",
],
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
operations=[
MergeModulelist(
dim=0
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
),
# WeightConverter(
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
# "self_attn.qkv_proj",
# operations=[Concatenate(dim=0)], # more like stack?
# ),
WeightConverter("*.block_sparse_moe.", "*.mlp."),
],
"qwen2_moe": [
WeightConverter(
source_keys=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_keys="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_keys=["mlp.experts.*.down_proj.weight"],
target_keys="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"legacy": [
WeightConverter(
source_keys="LayerNorm.gamma",
target_keys="LayerNorm.weight",
),
WeightConverter(
source_keys="LayerNorm.beta",
target_keys="LayerNorm.bias",
),
],
}
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
mapping["legacy"] += [
WeightConverter(
source_keys="weight_g",
target_keys="parametrizations.weight.original0",
),
WeightConverter(
source_keys="weight_v",
target_keys="parametrizations.weight.original1",
),
]
else:
mapping["legacy"] += [
WeightConverter(
source_keys="parametrizations.weight.original0",
target_keys="weight_g",
),
WeightConverter(
source_keys="parametrizations.weight.original1",
target_keys="weight_v",
),
]
mapping["phimoe"] = mapping["mixtral"].copy()
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
mapping["dot1"] = mapping["qwen2_moe"].copy()
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
mapping["jamba"] = mapping["qwen2_moe"].copy()
mapping["lfm2_moe"] = mapping["mixtral"].copy()
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
mapping["minimax"] = mapping["mixtral"].copy()
return mapping
_checkpoint_conversion_mapping_cache = None
def get_checkpoint_conversion_mapping():
global _checkpoint_conversion_mapping_cache
if _checkpoint_conversion_mapping_cache is None:
_checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
return _checkpoint_conversion_mapping_cache
def __getattr__(name):
if name == "_checkpoint_conversion_mapping":
return get_checkpoint_conversion_mapping()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -1,661 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core helpers for loading model checkpoints."""
from __future__ import annotations
import itertools
import os
import re
import threading
from abc import abstractmethod
from collections import defaultdict
from collections.abc import MutableMapping, MutableSet, Sequence
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Optional, Union
import torch
from torch.distributed.tensor import DTensor
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer
from .utils import logging
logger = logging.get_logger(__name__)
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
"""
Convert a glob with '*' into a regex *source* string. We don't use `glob.translate`
'*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing.
"""
star = r"(\d+)" if digits_only else r"(.+)"
return re.escape(glob).replace(r"\*", star)
def build_glob_alt(
globs: list[str],
) -> tuple[re.Pattern, dict[str, str]]:
r"""
Build one compiled regex alternation with a named group per glob. This allows to run a single
re.match and get the correct group name to finally get which pattern matched.
Returns (compiled_regex, name->glob map).
Example:
```py
>>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"])
>>> print(reg)
(re.compile(r'(?P<g0>.*mlp\.(\d+)\.w1)|(?P<g1>.*mlp\.(\d+)\.w2)', re.UNICODE),
>>> print(map_)
{'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'})
>>> match_ = reg.match("model.layers.0.mlp.0.w1.weight")
>>> print(match_.lastgroup)
'g0'
>>> print(map_[match_.lastgroup])
mlp.*.w1
```
"""
name_map: dict[str, str] = {}
parts: list[str] = []
prefix_src = r".*"
for i, g in enumerate(globs):
name = f"g{i}"
name_map[name] = g
pat_src = _glob_to_regex_src(g)
parts.append(f"(?P<{name}>{prefix_src}{pat_src})")
alt_src = "|".join(parts)
return re.compile(alt_src), name_map
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
"""
Match the key against the alternation; return the original glob string that matched.
"""
m = alt.match(key)
if not m:
return None
return name_map.get(m.lastgroup)
class ConversionOps:
"""Base class for weight conversion operations."""
# Reusable staging/scratch buffer to avoid reallocations.
_buffer: Optional[torch.Tensor] = None
# The inverse operation class, will be used when saving the checkpoint
reverse_op: type[ConversionOps]
def _ensure_buffer(
self,
required_shape: torch.Size,
*,
dtype: torch.dtype,
device: torch.device,
growth_factor: float = 1.5,
) -> torch.Tensor:
"""Ensure a pre-allocated buffer large enough for ``required_shape`` exists."""
required_elems = 1
for dim in required_shape:
required_elems *= int(dim)
need_new = (
self._buffer is None
or self._buffer.dtype != dtype
or self._buffer.device != device
or self._buffer.numel() < required_elems
)
if need_new:
capacity = max(required_elems, int(required_elems * growth_factor))
self._buffer = torch.empty(capacity, dtype=dtype, device=device)
return self._buffer[:required_elems].view(required_shape)
def clear_cache(self) -> None:
"""Free any cached buffers."""
self._buffer = None
@abstractmethod
def convert(
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs
) -> torch.Tensor:
raise NotImplementedError
class Chunk(ConversionOps):
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
reverse_op: type[ConversionOps]
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
if chunks is None and sizes is None:
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
if chunks is not None and chunks <= 0:
raise ValueError("`chunks` must be a strictly positive integer.")
self.dim = dim
self.chunks = chunks
self.sizes = list(sizes) if sizes is not None else None
self.reverse_op = Concatenate
def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
if not isinstance(value, torch.Tensor):
raise TypeError("Chunk expects a torch.Tensor as input.")
if self.sizes is not None:
return list(torch.split(value, self.sizes, dim=self.dim))
return list(torch.chunk(value, self.chunks, dim=self.dim))
class Concatenate(ConversionOps):
"""Concatenate tensors along `dim` using a reusable buffer."""
reverse_op: type[ConversionOps]
def __init__(self, dim: int = 0):
self.dim = dim
self.reverse_op = Chunk
@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
if isinstance(value[0], list):
value = [v[0] for v in value]
tensors = value
if not tensors:
raise ValueError("Fuse requires at least one tensor to concatenate.")
out_shape = list(tensors[0].shape)
out_shape[self.dim] = sum([t.size(self.dim) for t in tensors])
with torch.no_grad(): # we use staging buffers
out = self._ensure_buffer(torch.Size(out_shape), dtype=tensors[0].dtype, device=tensors[0].device)
torch.cat(tuple(tensors), dim=self.dim, out=out)
# offset = 0
# for tensor in tensors:
# index = [slice(None)] * tensor.ndim
# index[self.dim] = slice(offset, offset + tensor.shape[self.dim])
# out[tuple(index)].copy_(tensor, non_blocking=tensor.is_cuda)
# offset += tensor.shape[self.dim]
return out.clone() # need to say I can overwrite this storage now
class MergeModulelist(Concatenate):
"""
Merge a list of tensors into a single tensor along the first dimension.
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
"""
def __init__(self, dim: int = 0):
super().__init__(dim=dim)
self.reverse_op = SplitModulelist
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
merged = []
with torch.no_grad(): # we use staging buffers
for group in value:
if not isinstance(group, Sequence) or len(group) == 0:
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
group = [k for k in group if k.ndim]
out_shape = list(group[0].shape)
out_shape.insert(self.dim, len(group))
out = self._ensure_buffer(torch.Size(out_shape), dtype=group[0].dtype, device=group[0].device)
torch.stack(tuple(group), dim=self.dim, out=out)
# for off, tensor in enumerate(group):
# out[off].copy_(tensor, non_blocking=tensor.is_cuda)
# torch.as_tensor(numpy.stack(batch))
merged.append(out.clone()) # TODO have a single staging tensor here as well!
return merged
class SplitModulelist(ConversionOps):
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
self.sizes = [list(sub) for sub in sizes]
self.dim = dim
self.reverse_op = MergeModulelist
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
if not isinstance(value, Sequence):
raise TypeError("SplitModulelist expects a sequence of tensors.")
if len(value) != len(self.sizes):
raise ValueError("Number of tensors does not match the provided split specifications.")
result: list[list[torch.Tensor]] = []
for tensor, split_sizes in zip(value, self.sizes):
if not isinstance(tensor, torch.Tensor):
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
splits = torch.split(tensor, split_sizes, dim=self.dim)
result.append(list(splits))
return result
class Cast(ConversionOps):
"""
Casts the tensor to a given dtype
"""
def __init__(self, dtype):
self.dtype = dtype
def convert(self, value, *args, **kwargs):
out = [
[x.to(self.dtype) for x in inner] if isinstance(inner, list) else inner.to(self.dtype) for inner in value
]
return out
class PermuteForRope(ConversionOps):
"""
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
"""
def __init__(self):
pass
def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
dim1, dim2 = tensor.shape
n_heads = getattr(self.config, "num_attention_heads", 1)
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
return tensor
def convert(
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config
) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]:
self.config = config
out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value]
return out
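# Illustrative sketch, not part of the original file: the RoPE permutation reshapes a 2D
# projection weight per attention head and swaps the interleaved halves back into split
# sin/cos order. The config object below is a hypothetical stand-in exposing only
# num_attention_heads.
def _example_permute_for_rope():
    class _DummyConfig:
        num_attention_heads = 4
    weight = torch.randn(32, 16)  # first dim must be divisible by 2 * num_attention_heads
    (permuted,) = PermuteForRope().convert([weight], _DummyConfig())
    assert permuted.shape == weight.shape
    return permuted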
@dataclass(slots=True)
class WeightConverter:
r"""
A weight convert that acts on a pattern of source keys.
The keys need to be collected based on the target keys.
With wild card, glob patterns are matched, so you have to be detailed with what to match. If you match:
`model.layers.*.experts.*` -> it will act on all of them
{"model.layers.*.experts.*": []}
but
`experts.*.mlp` will be layer specific.
{"model.layers.1.experts.*": [], }
- source_keys: str | list[str] (wildcards '*' match digits)
- target_keys: str | list[str] | None
- distributed_operation / operations / quantization_operations are ALWAYS lists.
"""
source_keys: Union[str, list[str]]
target_keys: Optional[Union[str, list[str]]] = None
operations: list[ConversionOps] = field(default_factory=list, repr=False)
distributed_operation: Optional[TensorParallelLayer] = None
quantization_operation: Optional[ConversionOps] = None
def __post_init__(self):
if not isinstance(self.source_keys, list):
self.source_keys = [self.source_keys]
targets_were_none = False
if not isinstance(self.target_keys, list):
if self.target_keys is None:
self.target_keys = list(self.source_keys)
targets_were_none = True
else:
self.target_keys = [self.target_keys]
if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2:
raise ValueError(
f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one."
)
for pattern in self.source_keys:
if any(ch in pattern for ch in set("^$+?{}[]|()")):
raise AssertionError(f"'{pattern}' is not glob")
for pattern in self.target_keys:
if any(ch in pattern for ch in set("^$+?{}[]|()")):
raise AssertionError(f"'{pattern}' is not glob")
@dataclass(slots=True)
class ConversionEntry:
weight_converter: WeightConverter
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
PER_FILE_LIMIT = 4 # concurrent reads per file
def _materialize_copy(x):
# PyTorch: this runs in C and releases the GIL; good for threads.
return x[...]
def spawn_materialize(thread_pool, _file_semaphore, file_id, t) -> Future:
sem = _file_semaphore[file_id]
def _job():
with sem:
return _materialize_copy(t)
return thread_pool.submit(_job)
def spawn_tp_materialize(thread_pool, _file_semaphore, file_id, t, sharding_method, tensor_idx) -> Future:
sem = _file_semaphore[file_id]
def _job():
with sem:
return sharding_method.shard_tensor(t, tensor_idx=tensor_idx)[0]
return thread_pool.submit(_job)
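# Illustrative sketch, not part of the original file: the global pool is shared across all
# checkpoint shards while each per-file semaphore caps concurrent reads of a single shard
# at PER_FILE_LIMIT. The shard name used as file_id below is hypothetical.
def _example_throttled_materialize():
    pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
    semaphores = defaultdict(lambda: threading.Semaphore(PER_FILE_LIMIT))
    tensors = [torch.randn(2, 2) for _ in range(8)]
    futures = [spawn_materialize(pool, semaphores, "model-00001-of-00002.safetensors", t) for t in tensors]
    results = [f.result() for f in futures]  # at most PER_FILE_LIMIT copies run concurrently
    pool.shutdown(wait=True)
    return results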
def dot_natural_key(s: str):
parts = s.split(".")
for i, p in enumerate(parts):
# whole-segment digits -> int; otherwise leave as str
if p.isdigit():
parts[i] = int(p)
return parts
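# Illustrative sketch, not part of the original file: numeric segments sort as integers,
# so "layers.2" comes before "layers.10" instead of after it lexicographically.
def _example_dot_natural_key():
    keys = ["model.layers.10.weight", "model.layers.2.weight"]
    assert sorted(keys, key=dot_natural_key) == ["model.layers.2.weight", "model.layers.10.weight"]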
@contextmanager
def log_to_misc(
layer_name: str,
misc: MutableMapping[str, str],
extras: Any = None,
op: Union[list[ConversionOps], ConversionOps, None] = None,
):
# A simple helper to handle errors with contextual messages.
try:
yield
except Exception as e:
def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
if curr_op is None:
return None
if isinstance(curr_op, (list, tuple, set)):
names = [o.__class__.__name__ for o in curr_op if o is not None]
if not names:
return None
return ", ".join(names)
return curr_op.__class__.__name__
op_name = _format_op_name(op)
if isinstance(extras, tuple) and len(extras) == 2:
values, target_keys = extras
descriptor = f"{op_name} " if op_name else ""
misc[layer_name] = (
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
)
elif isinstance(extras, str):
suffix = f" via {op_name}" if op_name else ""
misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
elif extras is None and op_name:
misc[layer_name] = f"{op_name}: {e}"
else:
misc[layer_name] = f"{extras} |Error: {e}"
raise SkipLayer()
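# Illustrative sketch, not part of the original file: any exception raised inside the
# context manager is recorded in `misc` under the layer name and re-raised as SkipLayer,
# letting the caller skip just that layer. The layer and parameter names are made up.
def _example_log_to_misc():
    misc = {}
    try:
        with log_to_misc("model.layers.0.mlp", misc, "hypothetical.param"):
            raise RuntimeError("shape mismatch")
    except SkipLayer:
        pass
    return misc  # {"model.layers.0.mlp": "shape mismatch\nError when processing parameter hypothetical.param"}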
def set_param_for_module(
model: torch.nn.Module,
layer_name: str,
param_value: torch.Tensor,
meta_model_state_dict: MutableMapping[str, Any],
empty_param: torch.Tensor,
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
missing_keys: MutableSet[str],
misc: MutableMapping[str, Any],
distributed_operation: Optional[TensorParallelLayer],
):
with log_to_misc(layer_name, misc, layer_name):
module_path, _, param_name = layer_name.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model
param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
ref = meta_model_state_dict.get(layer_name, empty_param)
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
if not isinstance(param_value, torch.nn.Parameter):
if distributed_operation is not None and use_dtensor:
param_value = DTensor.from_local(
param_value,
distributed_operation.device_mesh,
distributed_operation.shard,
run_check=False,
shape=ref.size(),
stride=ref.stride(),
)
else:
pass # TODO: for "local" tensors, will this trigger a mismatch?
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
if ref is not None and ref.shape != param_value.shape:
mismatch_keys.add((layer_name, param_value.shape, ref.shape))
missing_keys.discard(layer_name)
setattr(module_obj, param_name, param_value)
class SkipLayer(Exception):
"""Control-flow sentinel: abort processing of the current layer only."""
pass
def convert_and_load_state_dict_in_model(
model,
state_dict,
weight_mapping,
tp_plan,
quantizer,
dtype=None,
device_map=None,
dtype_plan=None,
device_mesh=None,
profile: bool = False,
):
"""
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
collecting tensors per *layer instance* (the concrete indices captured from '*').
"""
from .modeling_utils import str_to_torch_dtype
prefix = model.base_model_prefix
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
device_map = device_map or {} # {exact_target_key: device}
dtype_plan = dtype_plan or {} # {glob_pattern: dtype}
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
meta_model_state_dict = model.state_dict()
missing_keys = set(meta_model_state_dict.keys())
misc = {}
mismatch_keys = set()
unexpected_keys = set()
# Global thread pool + per-file semaphores: each file allows at most PER_FILE_LIMIT concurrent reads. Should this depend on tensor shape?
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
_file_semaphore = defaultdict(lambda: threading.Semaphore(PER_FILE_LIMIT))
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
# 1. Create the conversion entries
by_conversion_pattern: dict[str, ConversionEntry] = {}
for original_key, (file_id, tensor) in state_dict:
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
if matched_pattern is not None:
converter = source_to_target[matched_pattern] # TODO make sure it's the ref
sub_with_extractor = partial(re.sub, _glob_to_regex_src(matched_pattern), string=original_key)
entry_key = "|".join(converter.target_keys)
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
converter_key = sub_with_extractor(matched_pattern)
else:
converter = WeightConverter(original_key)
converter_key = entry_key = target_key = original_key
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
new_target_key = []
for t in target_key.split("|"): # let's correct the keys
if t.startswith(prefix) and meta_model_state_dict.get(t.replace(f"{prefix}.", "")) is not None:
t = t.replace(f"{prefix}.", "")
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
t = f"{prefix}.{t}"
new_target_key.append(t)
target_key = "|".join(new_target_key)
for t in target_key.split("|"):
empty_param = meta_model_state_dict.get(t)
if empty_param is None:
unexpected_keys.add(t)
continue
if quantizer is not None and quantizer.param_needs_quantization(model, t):
if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
from .integrations.finegrained_fp8 import Fp8Quantize
converter.quantization_operation = Fp8Quantize() # TODO support other methods
else:
raise ValueError("This quantization method is gonna be supported SOOOON")
else:
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
if matched_dtype_pattern is not None:
_dtype = dtype_plan[matched_dtype_pattern]
else:
_dtype = dtype
tensor_dtype = (
tensor.dtype if isinstance(tensor, torch.Tensor) else str_to_torch_dtype[tensor.get_dtype()]
)
if _dtype != tensor_dtype and _dtype is not None:
converter.operations.append(Cast(_dtype)) # can this be slow as well?
first_target_key = target_key.split("|")[0]
future = None
if device_mesh:
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
empty_param = meta_model_state_dict.get(first_target_key)
if getattr(converter, "distributed_operation", {}) is None:
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
converter.distributed_operation = tp_layer(
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
)
# VERY IMPORTANT: this tells us whether we have already collected tensors for this key or not.
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
future = spawn_tp_materialize(
thread_pool,
_file_semaphore,
file_id,
tensor,
converter.distributed_operation,
shard_index,
)
if future is None: # If not TP, async materialize the tensors. TODO handle disk offload?
future = spawn_materialize(thread_pool, _file_semaphore, file_id, tensor)
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
# 2. Actually convert the ckpt
inverse_converters = {}
keys = list(by_conversion_pattern.keys())
total_layers = sum(len(by_conversion_pattern[key].collected_tensors) for key in keys)
progress_bar = logging.tqdm(total=total_layers, desc="Converting weights", leave=False) if total_layers else None
for key in keys[::-1]: # reversed so that simple keys are processed first
group = by_conversion_pattern.pop(key)
converter = group.weight_converter
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
concrete_target_keys = layer_name.split("|")
try:
if bool(set(concrete_target_keys) - unexpected_keys):
with log_to_misc(layer_name, misc):
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
for op in operations:
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
values = op.convert(values, model.config)
values = [values] if not isinstance(values, list) else values
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
realized_value = {
k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
}
for k in list(realized_value.keys()).copy():
if op := converter.quantization_operation:
with log_to_misc(layer_name, misc, op=op):
realized_value.update(
op.convert({k: realized_value.pop(k)}, quant_config=quantizer.quantization_config)
)
if progress_bar is not None:
progress_bar.set_postfix_str(layer_name, refresh=False)
progress_bar.update()
for k, output_value in realized_value.items():
for src in converter.source_keys: # record what should happen to k when we encounter it at save time
inverse_converters[k] = {src: converter}
set_param_for_module(
model,
k,
output_value,
meta_model_state_dict,
empty_param,
mismatch_keys,
missing_keys,
misc,
converter.distributed_operation,
)
except SkipLayer:
continue
del group
for op in operations:
op.clear_cache()
if progress_bar is not None:
progress_bar.close()
model.inverse_converters = inverse_converters
thread_pool.shutdown(wait=True)
return missing_keys, unexpected_keys, mismatch_keys, misc
# TODO this is not done yet!
def revert_weight_conversion(model, state_dict):
mapping = getattr(model, "", {}) # IDK why but setting this will fail all llava.
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
original_state_dict = {}
for key, value in state_dict.items():
for pattern, inverse_converter in reverse_key_mapping:
# TODO FIXME you name it
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
return state_dict

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from math import floor, gcd, sqrt
from typing import Optional
@ -21,8 +20,8 @@ import torch
from ...configuration_utils import PreTrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...utils.metrics import attach_tracer, traced
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import get_device_and_memory_breakdown, logger
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import RequestState, get_device_and_memory_breakdown, logger
def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
@ -32,7 +31,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
- All groups have the same number of layers
For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
We would get four groups: [0, 3], [1, 2], [4,5] and [6,7].
"""
# If the config has no layer_type attribute, it means all layers are the same attention type
layer_types = getattr(config, "layer_types", None)
@ -173,10 +172,12 @@ class PagedAttentionCache:
page_size = self.head_dim * self.num_key_value_heads
if "flash" in self.config._attn_implementation:
num_attention_masks = 1 # only used to compute the default meme args
else:
num_attention_masks = 0 # only used to compute the default memory footprint args
elif "sliding_attention" in group_types:
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
num_attention_masks = 2 if "sliding_attention" in group_types else 1
num_attention_masks = 2
else:
num_attention_masks = 1
memory_handler = PagedAttentionMemoryHandler(
block_size=self.block_size,
@ -216,7 +217,6 @@ class PagedAttentionCache:
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
# Block management data structures
self._free_blocks = deque(range(num_blocks))
self.group_cache_managers: list[CacheAllocator] = []
for i, group_type in enumerate(group_types):
if group_type == "full_attention":
@ -227,13 +227,18 @@ class PagedAttentionCache:
raise ValueError(f"Invalid group type: {group_type}")
self.group_cache_managers.append(cm)
# We only use prefix sharing if the whole model has only full attention layers
self.use_prefix_sharing = (group_types == ["full_attention"])
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
self.blocks_to_complete: dict[str, int] = {}
@traced
def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
def allocate_blocks(self, n_blocks: int, state: RequestState) -> int:
"""Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
max_allocated = 0
for cm in self.group_cache_managers:
allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
allocated = cm.allocate_blocks(n_blocks, state.request_id, self._block_manager)
if allocated is None:
return None
max_allocated = max(max_allocated, allocated)
@ -244,11 +249,11 @@ class PagedAttentionCache:
"""Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
by the cache managers."""
for cm in self.group_cache_managers:
cm.free_blocks(request_id, self._free_blocks)
cm.free_blocks(request_id, self._block_manager)
def get_num_free_blocks(self) -> int:
"""Get the current number of unallocated blocks available for new requests."""
return len(self._free_blocks)
return self._block_manager.num_free_blocks
@traced
def extend_read_indices(
@ -335,6 +340,38 @@ class PagedAttentionCache:
# Return the new KV values
return key_states_with_cache, value_states_with_cache
def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
current_hash = None
allocated_blocks = []
for b in range(len(prompt_ids) // self.block_size):
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
current_hash = self._block_manager.compute_hash(current_hash, tokens)
block_id = self._block_manager._hash_to_id.get(current_hash)
if block_id is not None:
allocated_blocks.append(block_id)
self._block_manager.increase_ref_count(block_id)
else:
break
# If we found a matching prefix, we reference the blocks in the request
if allocated_blocks:
logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
cm = self.group_cache_managers[0]
cm.block_table[request_id] = allocated_blocks
return len(allocated_blocks) * self.block_size
def mark_blocks_as_completed(self, state: RequestState) -> None:
"""Marks the blocks that have been computed in the forward pass as such. If prefix sharing is off, this is a
no-op."""
num_completed_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
if num_completed_blocks == 0:
return None
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
self._block_manager.mark_blocks_as_computed(
num_completed_blocks=num_completed_blocks,
allocated_blocks=cm.block_table[state.request_id],
prompt_ids=(state.full_prompt_ids + state.static_outputs),
)
# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
@ -469,6 +506,8 @@ class PagedAttentionMemoryHandler:
2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
])
If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
"""
cache_memory = self.get_available_memory(max_memory_percent)
logger.info(f"Cache memory: {cache_memory}")
@ -480,11 +519,16 @@ class PagedAttentionMemoryHandler:
c = -cache_memory
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
# Compute discriminant and greatest solution
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
# If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
if self.num_attention_masks == 0:
greatest_solution = -c / b
# Otherwise, we solve the quadratic equation
else:
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
if greatest_solution < 0:
raise ValueError(f"Greatest solution is negative: {greatest_solution = }")

View File

@ -14,29 +14,191 @@
# limitations under the License.
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
from math import ceil
from typing import Optional
from typing import Optional, TypeVar
from .requests import logger
T = TypeVar('T')
def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
index = len(xs) - 1
for x in xs[::-1]:
yield index, x
index -= 1
class Block:
"""A class to represent a block in the hash table of the block manager. We say that a block is completed when the KV
cache it points to is fully computed, otherwise it is partial. A block can have a parent, which is the block that
came before in the sequence. Once a block is computed, it is given a hash, which takes into account the token ids
of the block and its parent's hash."""
def __init__(self, id_: int, parent_id: int | None) -> None:
self.id: int = id_
self.parent_id: int | None = parent_id
self.hash: int | None = None
self.ref_count: int = 1
def __repr__(self) -> str:
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
@property
def is_complete(self) -> bool:
return self.hash is not None
class BlockManager:
"""A class to manage the number of free blocks and block re-use."""
def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
"""Initializes the block manager with a given number of blocks (num_blocks)"""
self.num_blocks = num_blocks
self.block_size = block_size
self._uninit_block_ids = deque(range(num_blocks))
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
self._use_prefix_sharing = use_prefix_sharing
# TODO: handle de-allocation for those structures
self._hash_to_id: dict[int, int] = {}
self._id_to_block: dict[int, Block] = {}
# NOTE: one of those may be redundant
# TODO: handle the case where the last block of a finished request is not complete
@property
def num_free_blocks(self) -> int:
"""Returns the number of free blocks left."""
return len(self._uninit_block_ids) + len(self._init_block_ids)
def is_enough_free_blocks(self, n_blocks: int) -> bool:
# Exit early if there are enough uninitialized blocks
if len(self._uninit_block_ids) >= n_blocks:
return True
# Exit early if even after uninitializing all initialized blocks, there are not enough free blocks
blocks_to_uninitialize = n_blocks - len(self._uninit_block_ids)
if len(self._init_block_ids) < blocks_to_uninitialize:
return False
# Uninitialize the required amount of blocks
for _ in range(blocks_to_uninitialize):
id_to_uninitialize = self._init_block_ids.popitem()[0]
block = self._id_to_block[id_to_uninitialize]
self._hash_to_id.pop(block.hash)
self._uninit_block_ids.append(id_to_uninitialize)
return True
def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
"""Returns a free block and mark it as used by removing it from the free blocks queue."""
if not self.is_enough_free_blocks(n_blocks):
return None
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
if self._use_prefix_sharing:
for block_id in allocated_block_ids:
block = Block(block_id, last_block_id)
self._id_to_block[block_id] = block # TODO: we could store only partial blocks here, and keep the parent referenced as a hash once the block is complete
last_block_id = block_id
# In both cases, we return the allocated block ids
return allocated_block_ids
def increase_ref_count(self, block_id: int) -> None:
"""Increases the reference count of a block."""
block = self._id_to_block[block_id]
block.ref_count += 1
if block.ref_count == 1:
self._init_block_ids.pop(block_id)
def decrease_ref_count(self, block_id: int) -> None:
"""Decreases the reference count of a block."""
block = self._id_to_block[block_id]
block.ref_count -= 1
if block.ref_count == 0:
if block.is_complete:
self._init_block_ids[block_id] = None
else:
self._id_to_block.pop(block_id)
self._uninit_block_ids.append(block_id)
def free_blocks(self, blocks: list[int]) -> None:
"""Marks a list of blocks as free. If there is no prefix sharing, we simply add them to the uninitialized blocks
queue. Otherwise, we mark them as initalized but they can be freed in no uninitialized blocks are lefts."""
if self._use_prefix_sharing:
for block_id in blocks:
self.decrease_ref_count(block_id)
else:
self._uninit_block_ids.extend(blocks)
def mark_blocks_as_computed(
self,
num_completed_blocks: int,
allocated_blocks: list[int],
prompt_ids: list[int]
) -> None:
# Look for the first complete block, starting from the last block
parent_hash = None
incomplete_blocks: list[Block] = []
for i, block_id in reverse_enumerate(allocated_blocks):
block = self._id_to_block[block_id]
if block.is_complete:
parent_hash = block.hash
break
incomplete_blocks.append((i, block))
# Now go through the incomplete blocks and update them
new_parent_id = None
while incomplete_blocks:
i, block = incomplete_blocks.pop()
# If the parent id has been updated, we apply the change
if new_parent_id is not None:
block.parent_id = new_parent_id
new_parent_id = None
# If we have set the hash for all complete blocks, we can stop
if num_completed_blocks == 0:
break
# Otherwise, we compute the hash
num_completed_blocks -= 1
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
block.hash = self.compute_hash(parent_hash, tokens)
existing_block_id = self._hash_to_id.get(block.hash)
# If the block hash is already in the hash to id mapping, we reference the existing block instead
if existing_block_id is not None:
logger.debug(f"Found existing block {existing_block_id} for block {block.id}")
allocated_blocks[i] = existing_block_id
self._id_to_block[existing_block_id].ref_count += 1
new_parent_id = existing_block_id
self.free_blocks([block.id])
# Otherwise, we add the completed block to the hash table
else:
self._hash_to_id[block.hash] = block.id
# Update loop variables
parent_hash = block.hash
def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
return hash((parent_hash, tuple(tokens)))
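# Illustrative sketch, not part of the original diff: each block hash chains the parent
# hash with the block's token ids, so two requests sharing a prompt prefix produce the
# same hash sequence and can reuse the same cache blocks.
def _example_prefix_hashes(block_size: int = 4) -> list[int]:
    manager = BlockManager(num_blocks=8, block_size=block_size, use_prefix_sharing=True)
    prompt = [1, 2, 3, 4, 5, 6, 7, 8]
    hashes, parent = [], None
    for b in range(len(prompt) // block_size):
        parent = manager.compute_hash(parent, prompt[b * block_size : (b + 1) * block_size])
        hashes.append(parent)
    return hashes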
class CacheAllocator(ABC):
"""Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""
_index: int
_block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
@abstractmethod
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocates n_blocks for a given request_id. Returns the num of blocks allocated if successful and None
otherwise."""
def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
"""Frees all blocks associated with a request_id."""
if request_id in self._block_table:
blocks_to_free = self._block_table.pop(request_id)
free_blocks.extend(blocks_to_free)
if request_id in self.block_table:
blocks_to_free = self.block_table.pop(request_id)
block_manager.free_blocks(blocks_to_free)
else:
logger.warning(
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
@ -54,7 +216,6 @@ class CacheAllocator(ABC):
def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
class FullAttentionCacheAllocator(CacheAllocator):
"""Cache manager for a group of full attention layers."""
@ -66,23 +227,29 @@ class FullAttentionCacheAllocator(CacheAllocator):
"""
self._index = index
self.block_size = block_size
self._block_table = {}
self.block_table = {}
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
otherwise. For group of full attention layers, we always allocate the number of requested blocks."""
if len(free_blocks) < n_blocks:
# Make sure the request_id is in the block table and get the first block id
if request_id not in self.block_table:
self.block_table[request_id] = [] # TODO: check the impact of making this a deque
last_block_id = None
else:
last_block_id = self.block_table[request_id][-1]
# Actual allocation, return early if failed
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
if allocated_blocks is None:
return None
if request_id not in self._block_table:
self._block_table[request_id] = []
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
self.block_table[request_id].extend(allocated_blocks)
return n_blocks
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
"""Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
first write the new cache to the cache tensor and then read the entire cache from the beginning to the end."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self._block_table.get(request_id)
block_table = self.block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Compute the physical indices
@ -97,7 +264,7 @@ class FullAttentionCacheAllocator(CacheAllocator):
def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
"""Returns the physical indices for writing to the cache. For a group of full attention layers, we write the new
cache as a continuation of the existing cache for the same request."""
block_table = self._block_table.get(request_id)
block_table = self.block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Compute the physical indices
@ -129,25 +296,26 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
self.block_size = block_size
self.sliding_window = sliding_window
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
self._block_table = {}
self.block_table = {}
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
otherwise. For group of sliding window attention layers, we only allocate up to the point where we can fit an
entire sliding window in the cache tensor."""
if request_id not in self._block_table:
self._block_table[request_id] = []
if request_id not in self.block_table:
self.block_table[request_id] = []
# Early return if we are already at the max number of blocks per request
already_allocated = len(self._block_table[request_id])
already_allocated = len(self.block_table[request_id])
if already_allocated == self._max_blocks_per_request:
return 0
# Compute actual number of blocks to allocate
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
actual_n_blocks = after_allocation - already_allocated
# Classic allocation
if len(free_blocks) < actual_n_blocks:
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
if allocated_blocks is None:
return None
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
self.block_table[request_id].extend(allocated_blocks)
return actual_n_blocks
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
@ -157,7 +325,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
sliding_window - 1 cache page and then manually add the new key / values states after. Hence the -1 indices
which indicate where to store the new key or values indices."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self._block_table.get(request_id)
block_table = self.block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Apply sliding window
@ -178,7 +346,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
sliding window attention layers, we write the new cache in rolling-buffer kind of way: if we reach the end of
the allocated physical cache, we start writing from the beginning of the physical cache again."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self._block_table.get(request_id)
block_table = self.block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Apply sliding window

View File

@ -21,7 +21,7 @@ from functools import partial
from itertools import count
from math import ceil
from time import perf_counter
from typing import Optional, Union
from typing import Optional
import torch
from torch import nn
@ -446,10 +446,7 @@ class ContinuousBatchProcessor:
cumulative_seqlens_q = [0]
logits_indices = []
if isinstance(self.cumulative_seqlens_k, dict):
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
else:
cumulative_seqlens_k = [0]
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
read_index = [[] for _ in range(self.cache.num_groups)]
write_index = [[] for _ in range(self.cache.num_groups)]
@ -498,10 +495,7 @@ class ContinuousBatchProcessor:
self.metrics.record_kv_cache_memory_metrics(self.cache)
if logger.isEnabledFor(logging.DEBUG):
if isinstance(self.cumulative_seqlens_k, dict):
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
else:
ck = cumulative_seqlens_k[-1]
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
logger.debug(
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
@ -517,7 +511,7 @@ class ContinuousBatchProcessor:
read_index: list[list[int]],
write_index: list[list[int]],
cumulative_seqlens_q: list[int],
cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
cumulative_seqlens_k: dict[str, list[int]],
logits_indices: list[int],
) -> None:
"""Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
@ -561,9 +555,7 @@ class ContinuousBatchProcessor:
@traced
def _maybe_send_output(self, state: RequestState) -> None:
"""Send output to the queue based on streaming mode and request state."""
if state.streaming:
self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED:
if state.streaming or state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output())
@traced
@ -571,17 +563,30 @@ class ContinuousBatchProcessor:
"""Update request states based on generated tokens."""
out_tokens = self._sync()
for i, state in enumerate(self.requests_in_batch):
# If the request has no remaining prompt ids, it means prefill has already or just finished
if len(state.remaining_prompt_ids) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
state.prompt_ids = [token]
if state.update_with_token(token):
# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
# We mark the completed blocks as such
self.cache.mark_blocks_as_completed(state)
if is_finished:
self.metrics.record_request_completion(state.created_time, state.request_id)
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
self._maybe_send_output(state)
# Otherwise, the request is still prefilling, but the prefill has been split
elif state.status == RequestStatus.PREFILLING_SPLIT:
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
# NOTE: it is unclear whether this branch can ever be reached; raise an error to surface it if it does.
else:
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")
# We error out if the cache is full
if self.cache.get_num_free_blocks() == 0:
raise ValueError("No more free blocks")
@ -799,7 +804,6 @@ class ContinuousBatchingManager:
logger.warning("Manager thread is already running.")
return
self._result_queue = queue.Queue()
self._generation_thread = threading.Thread(target=self._run_generation_loop)
self._generation_thread.start()
@ -919,6 +923,7 @@ class ContinuousBatchingManager:
if result is not None:
yield result
# FIXME: stop iteration when request status is finished?
def request_id_iter(self, request_id: str) -> Generator[GenerationOutput]:
"""Iterate over results matching a specific request id as they become available."""
request_cancelled = False
@ -930,20 +935,6 @@ class ContinuousBatchingManager:
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)
@traced
def warmup(self, batch_processor: ContinuousBatchProcessor) -> None:
stream = torch.cuda.Stream(device=self.model.device)
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
# Warmup the model with a dummy forward pass
self._generation_step(batch_processor)
torch.cuda.current_stream().wait_stream(stream)
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, stream=stream):
self._generation_step(batch_processor)
@traced
# @torch.compile
def _generation_step(self) -> None:
"""Perform a single generation step. This is cuda graphed"""
self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)

View File

@ -105,10 +105,10 @@ class RequestState:
error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
"""
# Required fields
# Required fields # TODO: come up with better names; prompt_ids and related fields may be redundant
request_id: str
full_prompt_ids: Optional[list[int]] = None # Full initial prompt
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated)
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed
remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
static_outputs: list[int] = field(default_factory=list) # Generated tokens
allocated_blocks: int = 0 # Number of blocks allocated to the request
@ -153,7 +153,7 @@ class RequestState:
# TODO: this logic seems one token off, check it out
@traced
def update_with_token(self, token_id: int) -> bool:
def update_and_check_completion(self, token_id: int) -> bool:
"""Update the request with a newly generated token and check for completion.
Args:

View File

@ -104,7 +104,7 @@ class Scheduler(ABC):
)
@traced
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
def _allocate_blocks_if_needed(self, state: RequestState) -> bool:
"""Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
accommodate the next tokens. It calculates how many blocks are needed based on the request's current
cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
@ -113,10 +113,11 @@ class Scheduler(ABC):
# 1. we check that the occupancy is less than the requested length
# 2. we allocate enough blocks to cover the requested length
current_len = state.current_len()
len_next_tokens = len(state.prompt_ids)
occupancy = state.allocated_blocks * self.cache.block_size - current_len
if occupancy < len_next_tokens or state.allocated_blocks == 0:
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
allocated = self.cache.allocate_blocks(blocks_needed, state)
if allocated is None:
return False
state.allocated_blocks += allocated
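# Illustrative numbers, not part of the original diff: with block_size=32, 3 blocks already
# allocated and current_len=90, occupancy = 3 * 32 - 90 = 6 free slots. Scheduling 20 new
# tokens then needs blocks_needed = ((20 - 6 + 1) // 32) + 1 = 1 extra block.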
@ -125,11 +126,25 @@ class Scheduler(ABC):
@traced(span_name="prepare_request")
def _prepare_request_for_processing(
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
):
) -> None:
"""Prepares a request for processing in the current batch."""
request_tokens = (
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
)
# If prefix sharing is enabled, we look for a prefix match and split the request if found
if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
prefill_length = self.cache.search_prefix_match(state.request_id, state.prompt_ids)
if prefill_length > 0:
self.active_requests[state.request_id] = state
state.remaining_prompt_ids = state.prompt_ids[prefill_length:]
state.prompt_ids = state.prompt_ids[prefill_length:]
request_ids_to_remove_from_waiting.add(state.request_id)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
# If the request has a split prefill, the tokens to process are the remaining prompt ids
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
request_tokens = state.remaining_prompt_ids
# Otherwise, the tokens to process are the prompt ids, which are the full prompt or the last predicted tokens
else:
request_tokens = state.prompt_ids
if len(request_tokens) < token_budget:
# Can process the entire prompt/remainder
if state.status == RequestStatus.PENDING:
@ -152,6 +167,7 @@ class Scheduler(ABC):
state.prompt_ids = request_tokens[:token_budget]
# TODO: further common-ize the two classes
@attach_tracer()
class FIFOScheduler(Scheduler):
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
@ -195,30 +211,31 @@ class FIFOScheduler(Scheduler):
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
break
continue
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)
_add_to_scheduled_requests(state)
# Add the request to the scheduled requests
scheduled_requests.append(state)
# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
computed_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = computed_blocks
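# Illustrative numbers, not part of the original diff: with block_size=32, current_len=90
# (90 % 32 = 26 tokens already sit in a partial block) and request_len=20, the forward pass
# reaches 26 + 20 = 46 tokens past the last completed block boundary, i.e. 46 // 32 = 1
# newly completed block.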
@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
_remove_from_waiting_requests(state)
# Remove the request from the waiting queue and mark it as removed
req_id = state.request_id
was_waiting = self.waiting_requests.pop(req_id, None) is not None
if was_waiting:
request_ids_to_remove_from_waiting.add(req_id)
# Early exit of the loop if we have no token budget left
if token_budget == 0:
break
@ -249,6 +266,7 @@ class PrefillFirstScheduler(Scheduler):
elif state.status == RequestStatus.DECODING:
second_priority_states.append(state)
# Add waiting requests to second priority
for req_id in self.waiting_requests_order:
second_priority_states.append(self.waiting_requests[req_id])
@ -259,30 +277,31 @@ class PrefillFirstScheduler(Scheduler):
for state in candidates:
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
break
continue
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)
_add_to_scheduled_requests(state)
# Add the request to the scheduled requests
scheduled_requests.append(state)
# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
computed_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = computed_blocks
@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
_remove_from_waiting_requests(state)
# Remove the request from the waiting queue and mark it as removed
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
# Early exit of the loop if we have no token budget left
if token_budget == 0:
break

View File

@ -1635,12 +1635,7 @@ class GenerationMixin(ContinuousMixin):
# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
for key, value in model_kwargs.items():
if (
value is not None
and key not in model_args
and key not in TransformersKwargs.__optional_keys__
and key != "debug_io"
):
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
unused_model_args.append(key)
if unused_model_args:

View File

@ -512,8 +512,10 @@ def accelerate_disk_offload(
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
):
disk_only_shard_files = []
if disk_offload_folder is not None:
@ -532,13 +534,19 @@ def accelerate_disk_offload(
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
# Fix the weight map keys according to the key mapping
weight_map = {
key_renaming_mapping[k]: v
for k, v in sharded_metadata["weight_map"].items()
if k in key_renaming_mapping
}
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": name,
"weight_name": reverse_key_renaming_mapping[name],
"dtype": str_dtype,
}
for name, file in weight_map.items()

View File

@ -1,4 +1,5 @@
import inspect
from copy import deepcopy
from inspect import signature
from ..utils import (
@ -23,6 +24,7 @@ if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import find_tied_parameters
logger = logging.get_logger(__name__)
@ -149,6 +151,52 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
return model
def get_keys_to_not_convert(model):
r"""
A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function returns a list of the keys of the modules to not convert to
int8.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model and tie the weights, then
# check if it contains tied weights
tied_model = deepcopy(model) # this has 0 cost since it is done inside the `init_empty_weights` context manager
tied_model.tie_weights()
tied_params = find_tied_parameters(tied_model)
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0
# If there are no tied weights, we want to keep the lm_head (output embedding) in full precision
if not has_tied_params:
output_emb = model.get_output_embeddings()
if output_emb is not None:
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
return list_last_module
# otherwise (tied weights, or no output embedding defined), simply keep the last module in full precision
list_modules = list(model.named_parameters())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)
# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
"""

View File

@ -13,11 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Optional
from ..core_model_loading import ConversionOps
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
@ -33,18 +30,6 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)
try:
_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
_FP8_IS_INT = False
except AttributeError:
_FP8_DTYPE = torch.int8
_FP8_MIN, _FP8_MAX = -127, 127
_FP8_IS_INT = True
logger.warning_once(
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
)
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@ -347,12 +332,6 @@ class FP8Linear(nn.Linear):
if self.weight.element_size() > 1:
return F.linear(input, self.weight, self.bias)
else:
if isinstance(self.weight, torch.distributed.tensor.DTensor):
weight = self.weight._local_tensor.contiguous()
scale_inv = self.weight_scale_inv._local_tensor.contiguous()
else:
weight = self.weight.contiguous()
scale_inv = self.weight_scale_inv.contiguous()
# Context manager used to switch among the available accelerators
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
@ -360,9 +339,9 @@ class FP8Linear(nn.Linear):
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
weight,
self.weight,
scale,
scale_inv,
self.weight_scale_inv,
self.block_size,
output_dtype=input.dtype,
)
@ -371,124 +350,9 @@ class FP8Linear(nn.Linear):
torch_accelerator_module.synchronize()
if self.bias is not None:
output = output + self.bias
output = torch.nan_to_num(output, nan=0.0)
return output.to(dtype=input.dtype)
def _ceil_div(a, b):
return (a + b - 1) // b
class FP8Expert(nn.Module):
dtype = torch.float8_e4m3fn
def __init__(self, config, block_size, device):
super().__init__()
from ..activations import ACT2FN
self.block_size = block_size
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
self.gate_up_proj = nn.Parameter(
torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
)
self.down_proj = nn.Parameter(
torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
)
# Create inverse scale tiles only when using 1-byte types (fp8)
if self.gate_up_proj.element_size() == 1:
bo, bi = self.block_size
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
gu_scale_o = _ceil_div(Wg_out, bo)
gu_scale_i = _ceil_div(Wg_in, bi)
self.gate_up_proj_scales_inv = nn.Parameter(
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
)
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
dp_scale_o = _ceil_div(Wd_out, bo)
dp_scale_i = _ceil_div(Wd_in, bi)
self.down_proj_scales_inv = nn.Parameter(
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
)
else:
# Match FP8Linear behavior when not using 1-byte weights
self.register_parameter("gate_up_proj_scale_inv", None)
self.register_parameter("down_proj_scale_inv", None)
# (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
self.register_parameter("gate_up_bias", None)
self.register_parameter("down_bias", None)
# Activation used in the MLP (same as your config / ACT2FN)
# Keep a handle here; actual usage happens in forward of your MoE block
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states.index_select(0, token_idx)
gate, up = self.linear(
current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx]
).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = self.linear(
current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx]
)
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
if weight.element_size() > 1:
return F.linear(input, weight, None)
else:
# Context manager used to switch among the available accelerators
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
with torch_accelerator_module.device(input.device):
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
weight,
scale,
weight_scale_inv,
self.block_size,
output_dtype=input.dtype,
)
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
torch_accelerator_module.synchronize()
return output.to(dtype=input.dtype)
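For context on the tile layout above: each (block_o, block_i) tile of a weight gets one float32 inverse scale, which is why gate_up_proj_scales_inv has shape (num_experts, ceil(Wg_out/bo), ceil(Wg_in/bi)). Below is a minimal standalone sketch of that per-tile scheme, assuming a torch build with float8_e4m3fn support; the helper name is illustrative and not part of this module.

import torch

def quantize_blockwise(weight: torch.Tensor, block=(128, 128)):
    # One scale per (bm, bn) tile; we return the inverse scale, mirroring *_scales_inv.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    bm, bn = block
    rows, cols = weight.shape
    assert rows % bm == 0 and cols % bn == 0, "weight must tile exactly"
    tiles = weight.float().reshape(rows // bm, bm, cols // bn, bn)
    max_abs = tiles.abs().amax(dim=(1, 3)).clamp(min=1e-12)
    scale = fp8_max / max_abs
    q = (tiles * scale[:, None, :, None]).clamp(-fp8_max, fp8_max)
    return q.to(torch.float8_e4m3fn).reshape(rows, cols), (1.0 / scale)

w = torch.randn(256, 512)
q, scales_inv = quantize_blockwise(w)
# Dequantize by broadcasting the per-tile inverse scales back over the tiles.
w_hat = (q.float().reshape(2, 128, 4, 128) * scales_inv[:, None, :, None]).reshape(256, 512)
print((w - w_hat).abs().max())  # small reconstruction error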
# TODO: we do need this, but it should not be recursive.
def _replace_with_fp8_linear(
model,
tp_plan=None,
@ -497,48 +361,40 @@ def _replace_with_fp8_linear(
quantization_config=None,
has_been_replaced=False,
):
iterator = list(model.named_parameters()).copy()
for name, empty_tensor in iterator:
current_key_name = name
name = name.rsplit(".", 1)[0] if "." in name else name
module = model.get_submodule(name)
"""Replace Linear layers with FP8Linear."""
if current_key_name is None:
current_key_name = []
current_key_name_str = re.sub(r"\d+", "*", current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
if (
("gate_up_proj" in current_key_name or "down_proj" in current_key_name)
and "experts" in current_key_name
): # Experts!
in_features = empty_tensor.size(-2)
out_features = empty_tensor.size(-1)
model.set_submodule(
name,
FP8Expert(
config=model.config,
block_size=quantization_config.weight_block_size,
device=empty_tensor.device,
),
)
for name, module in model.named_children():
current_key_name.append(name)
elif isinstance(module, nn.Linear):
in_features = module.in_features
out_features = module.out_features
model.set_submodule(
name,
FP8Linear(
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
),
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
current_key_name_str = ".".join(current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
model._modules[name] = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
)
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fp8_linear(
module,
tp_plan,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
return model, has_been_replaced
@ -549,7 +405,7 @@ def replace_with_fp8_linear(
quantization_config=None,
):
"""Helper function to replace model layers with FP8 versions."""
modules_to_not_convert += ["lm_head"]
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
@ -568,133 +424,3 @@ def replace_with_fp8_linear(
)
return model
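The replacement helper above walks the module tree and swaps eligible nn.Linear layers in place while honoring modules_to_not_convert. A minimal sketch of that traversal pattern, with nn.Identity standing in for FP8Linear purely for illustration:

import torch.nn as nn

def swap_linears(model: nn.Module, modules_to_not_convert=("lm_head",)):
    # Recursively replace nn.Linear children that are not on the skip list.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            setattr(model, name, nn.Identity())  # FP8Linear(...) in the real code
        else:
            swap_linears(module, modules_to_not_convert)
    return model

toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
print(swap_linears(toy))  # both Linear layers replaced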
class QuantizationOp(ConversionOps):
"""Base class for quantization operations."""
pass
class Fp8Quantize(QuantizationOp):
"""
A quantization operation that turns a weight into two tensors: the quantized weight and its scale.
"""
reverse_op: type[ConversionOps]
def __init__(self, block_size: Optional[tuple[int, int]] = None):
self.block_size = block_size
self.reverse_op = Fp8Dequantize
def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]:
# Unpack single key/value (value may be wrapped in a list)
target_keys, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value
# Resolve block size (support dict-like or attr-like quant_config)
block_size = None
if quant_config is not None:
if isinstance(quant_config, dict):
block_size = quant_config.get("weight_block_size")
else:
block_size = getattr(quant_config, "weight_block_size", None)
if block_size is None:
block_size = (value.shape[-2], value.shape[-1])
block_m, block_n = block_size
rows, cols = value.shape[-2], value.shape[-1]
# Enforce exact tiling: both dimensions must be divisible by the block size
if rows % block_m != 0 or cols % block_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
)
# Leading dims can be empty (2D) or include num_experts/... (3D+)
leading_shape = value.shape[:-2]
rows_tiles = rows // block_m
cols_tiles = cols // block_n
original_shape = value.shape
value_fp32 = value.to(torch.float32)
# Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
# Per-tile max-abs over the block dims
# dims: block_m is at -3, block_n is at -1 after the reshape
max_abs = reshaped.abs().amax(dim=(-3, -1))
safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
# Per-tile scale (we store the inverse scale, matching FP8Linear's weight_scale_inv)
scales = _FP8_MAX / safe_max_abs
scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
# Broadcast scales back over the block dims and quantize
# max_abs/scales shape: (..., rows_tiles, cols_tiles)
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
scaled = reshaped * scales_broadcast
if _FP8_IS_INT:
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
else:
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
quantized = quantized.reshape(original_shape)
inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
if target_keys.endswith("weight"):
scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
else:
scale_key = target_keys + "_scales_inv"
# Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
return {
target_keys: quantized,
scale_key: inv_scales,
}
class Fp8Dequantize(QuantizationOp):
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
def __init__(self, block_size: Optional[tuple[int, int]] = None):
self.block_size = block_size
self.reverse_op = Fp8Quantize
def convert(
self,
value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]],
*,
context: dict[str, Any],
) -> torch.Tensor:
if isinstance(value, dict):
tensors = list(value.values())
else:
tensors = list(value) if isinstance(value, Sequence) else [value]
if len(tensors) != 2:
raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
quantized, scales = tensors
if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
raise TypeError("Fp8Dequantize expects tensors as inputs.")
quantized_fp32 = quantized.to(torch.float32)
rows, cols = quantized_fp32.shape[-2:]
block_size = self.block_size
if block_size is None:
quant_config = context.get("quantization_config")
block_size = getattr(quant_config, "weight_block_size", None)
if block_size is None:
block_size = (rows, cols)
block_m, block_n = block_size
if rows % block_m != 0 or cols % block_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
)
reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
dequantized = reshaped * expanded_scales
return dequantized.reshape(quantized_fp32.shape)
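A hedged usage sketch of the two ops above, assuming it runs alongside this module (so Fp8Quantize, Fp8Dequantize and the module-level _FP8_* constants are in scope); the key name and shapes are illustrative only.

import torch

# Round trip with a (2, 2) block size on a toy 4x4 weight.
quantize = Fp8Quantize(block_size=(2, 2))
packed = quantize.convert(
    {"mlp.down_proj.weight": torch.randn(4, 4)},
    quant_config={"weight_block_size": (2, 2)},
)
# `packed` holds the fp8 weight under "mlp.down_proj.weight" and the per-tile
# inverse scales under "mlp.down_proj.weight_scale_inv".
dequantize = Fp8Dequantize(block_size=(2, 2))
restored = dequantize.convert(list(packed.values()), context={})
print(restored.shape)  # torch.Size([4, 4]), approximately the original weight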

View File

@ -51,10 +51,13 @@ try:
)
},
"RMSNorm": {
"cuda": LayerRepository(
repo_id="kernels-community/liger_kernels",
layer_name="LigerRMSNorm",
),
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/liger_kernels",
layer_name="LigerRMSNorm",
# revision="pure-layer-test",
),
},
"rocm": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/liger_kernels",

View File

@ -236,7 +236,7 @@ class PeftAdapterMixin:
**adapter_kwargs,
)
peft_config.inference_mode = not is_trainable
# TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
# Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)

View File

@ -18,7 +18,6 @@ import operator
import os
import re
from functools import partial, reduce
from typing import Optional
import torch
import torch.distributed as dist
@ -307,7 +306,7 @@ def repack_weights(
return final_ordered_tensor
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None):
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
"""
Generalized tensor sharding across a multi-dimensional device mesh.
Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
@ -359,57 +358,32 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Opt
rank (int): Global rank of the current process/device.
dim (int): Dimension along which to shard the tensor.
"""
param_dim = empty_param.ndim
param_dim = empty_param.dim()
if dim < 0:
dim = param_dim + dim
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
# Flatten the mesh to get the total number of devices
mesh_shape = device_mesh.shape
world_size = reduce(operator.mul, mesh_shape)
if dim < 0:
dim = param_dim + dim
if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
dim = 0
elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
dim = 0
shard_size = math.ceil(empty_param.size(dim) / world_size)
start = rank * shard_size
end = min(start + shard_size, empty_param.size(dim))
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
if rank >= world_size:
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
# We have the full tensor here, not just one shard of it. In that case we assume the weight was saved
# properly; with TP, a colwise layer should not take this path but be declared packed_colwise instead,
# to signal that it needs to read from a packed tensor (this also covers the nn.ModuleList case).
# Here we handle potential chunking / layer splits.
# The only "hard" case is when q, k, v are collected and merged into qkv: we still shard along dim=0,
# so nothing changes. The remaining special case is when the empty param has 3 dims and the shard dim
# is 0 -> we place the whole tensor on a specific device (using the input tensor_idx).
dimensions = param.get_shape()
shard_size = math.ceil(empty_param.shape[dim] / world_size)
start = rank * shard_size
if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
# special case we don't "shard" just send this entire tensor to the correct rank.
if start <= tensor_idx < end:
# this tensor does need to be materialized on this device:
return param[:]
else:
return torch.empty([], dtype=torch.int64, device=rank)
slice_indices = [slice(None)] * len(param.get_shape())
if start < param.get_shape()[dim]:
# Construct slicing index dynamically
end = min(start + shard_size, empty_param.shape[dim])
slice_indices = [slice(None)] * param_dim
if start < empty_param.shape[dim]:
slice_indices[dim] = slice(start, end)
param = param[tuple(slice_indices)]
if isinstance(param, list): # TODO handle the modulelist case!
param = [p[:] for p in param]
return param
return param[tuple(slice_indices)]
dimensions = list(param.shape)
dimensions[dim] = 0
return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory....
return torch.empty(tuple(dimensions), dtype=torch.int64)
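The shard boundaries above come down to a ceil-division of the sharded dimension over the flattened mesh size. A minimal standalone sketch of that rank-local slicing (names are illustrative, not the library's API):

import math
import torch

def local_shard(param: torch.Tensor, rank: int, world_size: int, dim: int = 0):
    # Each rank takes a contiguous chunk of ceil(size / world_size) entries along `dim`.
    shard_size = math.ceil(param.size(dim) / world_size)
    start = rank * shard_size
    end = min(start + shard_size, param.size(dim))
    index = [slice(None)] * param.ndim
    index[dim] = slice(start, end)
    return param[tuple(index)]

weight = torch.arange(10).reshape(5, 2)
print([local_shard(weight, r, 2, dim=0).shape for r in range(2)])
# rank 0 gets rows 0..2, rank 1 gets rows 3..4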
def distribute_module(
@ -436,19 +410,6 @@ class TensorParallelLayer:
"""
use_dtensor = True
device_mesh = None
rank = None
# Used to compare the shape of the original tensor
empty_param = None
# Used to init the corresponding DTensor
shard = None
def __init__(self, device_mesh=None, rank=None, empty_param=None):
self.rank = rank
self.device_mesh = device_mesh
self.empty_param = empty_param
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
@ -478,12 +439,12 @@ class GatherParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
**kwargs,
):
super().__init__(**kwargs)
super().__init__()
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = output_layouts
self.desired_input_layouts = (Replicate(),)
@ -504,21 +465,6 @@ class GatherParallel(TensorParallelLayer):
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
return outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
shard = [Replicate()]
parameter = param[...]
self.shard = shard
return parameter, shard
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
distribute_module(
module,
@ -547,23 +493,6 @@ class IsolatedParallel(TensorParallelLayer):
# TODO: figure out dynamo support for instance method and switch this to instance method
return outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
mesh = device_mesh or self.device_mesh
parameter = param[...]
if mesh is not None:
parameter = parameter / mesh.size()
self.shard = None
return parameter, None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
param = param[...].to(param_casting_dtype)
if to_contiguous:
@ -586,8 +515,8 @@ class ReplicateParallel(TensorParallelLayer):
This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
"""
def __init__(self, use_dtensor=True, use_local_output=True, **kwargs):
super().__init__(**kwargs)
def __init__(self, *, use_dtensor=True, use_local_output=True):
super().__init__()
self.input_layouts = (Replicate(),)
self.output_layouts = (Replicate(),)
self.desired_input_layouts = (Replicate(),)
@ -608,33 +537,12 @@ class ReplicateParallel(TensorParallelLayer):
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...]
shard = [Replicate()]
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
parameter, shard = self.shard_tensor(
param,
param_type=param_type,
param_casting_dtype=param_casting_dtype,
to_contiguous=to_contiguous,
rank=rank,
device_mesh=device_mesh,
)
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return parameter
param = param[...].to(param_casting_dtype)
if to_contiguous:
param = param.contiguous()
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
return param
class ColwiseParallel(TensorParallelLayer):
@ -644,13 +552,13 @@ class ColwiseParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
use_dtensor=True,
**kwargs,
):
super().__init__(**kwargs)
super().__init__()
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = (output_layouts or Shard(-1),)
self.desired_input_layouts = (Replicate(),)
@ -670,24 +578,17 @@ class ColwiseParallel(TensorParallelLayer):
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor
def shard_tensor(self, param, param_type=None, tensor_idx=None):
device_mesh = self.device_mesh
empty_param = self.empty_param
rank = self.rank
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh)
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
@ -707,26 +608,6 @@ class ColwiseParallel(TensorParallelLayer):
class PackedColwiseParallel(ColwiseParallel):
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
return get_packed_weights(param, empty_param, device_mesh, rank, -2), [Shard(-2)]
def create_nn_parameter(
self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh
):
return nn.Parameter(param, requires_grad=param.is_floating_point())
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@ -761,40 +642,18 @@ class RowwiseParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
use_dtensor=True,
**kwargs,
):
super().__init__(**kwargs)
super().__init__()
self.input_layouts = (input_layouts or Shard(-1),)
self.output_layouts = (output_layouts or Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
if param_type == "bias":
shard = [Replicate()]
parameter = param[:]
else:
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
shard = [Shard(-1)]
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@ -866,21 +725,6 @@ class RowwiseParallel(TensorParallelLayer):
class PackedRowwiseParallel(RowwiseParallel):
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@ -939,8 +783,8 @@ class SequenceParallel(TensorParallelLayer):
to ensure that they are replicated.
"""
def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
super().__init__(**kwargs)
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
super().__init__()
self.input_layouts = (Replicate(),)
self.desired_input_layouts = (Shard(1),)
self.output_layouts = (Replicate(),)
@ -949,21 +793,6 @@ class SequenceParallel(TensorParallelLayer):
self.sequence_sharding = (Shard(sequence_dim),)
self.use_local_output = use_local_output
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...]
shard = [Replicate()]
self.shard = shard
return parameter, shard
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
@ -998,34 +827,10 @@ class GroupedGemmParallel(TensorParallelLayer):
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __init__(self):
super().__init__()
self.use_dtensor = False
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
empty_param = self.empty_param
ep_rank = self.rank
device_mesh = self.device_mesh
global_num_experts = empty_param.shape[0]
if global_num_experts % device_mesh.size() != 0:
raise ValueError(
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
)
local_num_experts = global_num_experts // device_mesh.size()
parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts]
self.shard = None
return parameter, None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
ep_rank = rank
global_num_experts = empty_param.shape[0]
@ -1046,8 +851,8 @@ class RouterParallel(TensorParallelLayer):
"""
def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
self.args = args
self.kwargs = kwargs
self.use_dtensor = False
@staticmethod
@ -1112,20 +917,6 @@ class RouterParallel(TensorParallelLayer):
) # masking class for one hot
return router_scores, router_indices
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...]
self.shard = None
return parameter, None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# TODO: i'd like for this to be the default
param = param[...].to(param_casting_dtype)
@ -1268,9 +1059,6 @@ def shard_and_distribute_module(
if current_shard_plan is not None:
try:
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
tp_layer.empty_param = empty_param
tp_layer.device_mesh = device_mesh
tp_layer.rank = rank
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)

View File

@ -23,11 +23,10 @@ import json
import os
import re
import sys
import time
import warnings
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable, Sequence
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from enum import Enum
@ -46,17 +45,17 @@ from torch.distributions import constraints
from torch.utils.checkpoint import checkpoint
from .configuration_utils import PreTrainedConfig
from .conversion_mapping import get_checkpoint_conversion_mapping
from .core_model_loading import WeightConverter, convert_and_load_state_dict_in_model, revert_weight_conversion
from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
from .integrations.accelerate import (
_get_device_map,
accelerate_disk_offload,
accelerate_dispatch,
check_and_set_device_map,
expand_device_map,
find_tied_parameters,
init_empty_weights,
)
from .integrations.deepspeed import _load_state_dict_into_zero3_model
@ -123,7 +122,6 @@ from .utils.import_utils import (
is_sagemaker_mp_enabled,
is_tracing,
)
from .utils.loading_report import log_state_dict_report
from .utils.quantization_config import QuantizationMethod
@ -132,6 +130,7 @@ if is_accelerate_available():
from accelerate.utils import (
extract_model_from_parallel,
offload_weight,
save_offload_index,
)
from accelerate.utils.modeling import get_state_dict_from_offload
@ -697,6 +696,82 @@ def _load_state_dict_into_meta_model(
return disk_offload_index
def load_shard_file(args):
(
shard_file,
state_dict,
disk_only_shard_files,
is_quantized,
device_map,
hf_quantizer,
key_renaming_mapping,
weights_only,
model,
reverse_key_renaming_mapping,
disk_offload_folder,
disk_offload_index,
device_mesh,
) = args
# Skip the load for shards that only contain disk-offloaded weights
if shard_file in disk_only_shard_files:
return [], disk_offload_index
map_location = "cpu"
if shard_file.endswith(".safetensors") and not (is_deepspeed_zero3_enabled() and not is_quantized):
map_location = "meta"
# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
error_msgs = []
if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
# Skip it with fsdp on ranks other than 0
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index = _load_state_dict_into_meta_model(
model,
state_dict,
shard_file,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
hf_quantizer=hf_quantizer,
device_mesh=device_mesh,
)
return error_msgs, disk_offload_index
def load_shard_files_with_threadpool(args_list):
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
# Do not spawn any more workers than needed
num_workers = min(len(args_list), num_workers)
logger.info(f"Loading model weights in parallel with {num_workers} workers...")
error_msgs = []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
futures = [executor.submit(load_shard_file, arg) for arg in args_list]
for future in as_completed(futures):
_error_msgs, disk_offload_index = future.result()
error_msgs += _error_msgs
pbar.update(1)
return error_msgs, disk_offload_index
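A small sketch of the bounded-worker pattern used above, with a toy task standing in for load_shard_file; the environment variable name matches the one read above, while the shard names are illustrative.

import os
from concurrent.futures import ThreadPoolExecutor, as_completed

def fake_load(shard_file):
    # Stand-in for load_shard_file: returns (error_msgs, disk_offload_index).
    return [], None

shards = ["model-00001.safetensors", "model-00002.safetensors"]
# Never spawn more workers than there are shards.
num_workers = min(len(shards), int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")))
error_msgs = []
with ThreadPoolExecutor(max_workers=num_workers) as pool:
    for future in as_completed([pool.submit(fake_load, s) for s in shards]):
        _errors, _offload_index = future.result()
        error_msgs += _errors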
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
@ -1099,6 +1174,104 @@ def _get_dtype(
return config, dtype, dtype_orig
def _find_missing_and_unexpected_keys(
model: "PreTrainedModel",
original_checkpoint_keys: list[str],
checkpoint_keys: list[str],
loading_base_model_from_task_state_dict: bool,
hf_quantizer: Optional[HfQuantizer],
) -> tuple[list[str], list[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
"""
prefix = model.base_model_prefix
# Compute expected keys, i.e. keys that the full model expects
expected_keys = list(model.state_dict().keys())
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)
# Adjust prefix of the keys to make them match loaded keys before removing them
missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))
unexpected_keys = set(checkpoint_keys) - set(expected_keys)
# If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys
if loading_base_model_from_task_state_dict:
task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")]
unexpected_keys.update(task_specific_keys)
# Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but
# may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway
model_buffers = {n for n, _ in model.named_buffers()}
unexpected_keys = sorted(unexpected_keys - model_buffers)
tied_params = find_tied_parameters(model)
for group in tied_params:
missing_in_group = [k for k in missing_keys if k in group]
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
missing_keys = [k for k in missing_keys if k not in missing_in_group]
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
return missing_keys, unexpected_keys
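The core of the helper above is plain set arithmetic between the model's expected keys and the checkpoint keys. A toy illustration with hypothetical key names:

expected_keys = {"embed_tokens.weight", "layers.0.mlp.up_proj.weight", "lm_head.weight"}
checkpoint_keys = {"embed_tokens.weight", "layers.0.mlp.up_proj.weight", "pooler.dense.weight"}

missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))      # ["lm_head.weight"]
unexpected_keys = sorted(set(checkpoint_keys) - set(expected_keys))   # ["pooler.dense.weight"]
print(missing_keys, unexpected_keys)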
def _find_mismatched_keys(
model: "PreTrainedModel",
state_dict: Optional[dict],
checkpoint_files: Optional[list[str]],
ignore_mismatched_sizes: bool,
keys_to_rename_mapping: dict[str, str],
is_quantized: bool,
weights_only: bool,
) -> tuple[list[str], list[tuple[int, int]]]:
"""
Find potential shape mismatches between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes`
is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking
every parameter in advance, as shape mismatches are extremely rare in practice. If we do want to ignore them, however, we
need to check in advance, because we need to know which parameters have to be moved back from meta to cpu and initialized
correctly. Indeed, since our model initialization takes place at the module level, and not the weight level, in the
case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform
this check on each state dict at loading time (after the first loaded checkpoint, there is no way to initialize only the
mismatched weights, if any, without also overwriting the previously loaded weights, because the whole module would be
initialized, not only the weights that are mismatched).
"""
# An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function
# if there is no mismatch (which is almost always the case)
if not ignore_mismatched_sizes:
return [], []
if state_dict is not None:
checkpoint_files = [""]
model_state_dict = model.state_dict()
mismatched_keys = []
mismatched_shapes = []
for shard_file in checkpoint_files:
# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
)
# Fix the key names
new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
for key, tensor in new_state_dict.items():
if key in model_state_dict and tensor.shape != model_state_dict[key].shape:
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights.
if not (
is_quantized and tensor.shape[-1] == 1 and tensor.numel() * 2 == model_state_dict[key].numel()
):
mismatched_keys.append(key)
mismatched_shapes.append((tensor.shape, model_state_dict[key].shape))
return mismatched_keys, mismatched_shapes
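The size-mismatch exemption above relies on the fact that two 4-bit values share one 8-bit container. A toy check of that arithmetic (the shapes are hypothetical):

import torch

packed = torch.zeros(2048, 1, dtype=torch.uint8)  # 2048 bytes = 4096 packed 4-bit values
full = torch.zeros(64, 64)                        # 4096 dequantized parameters
# Matches the exemption above: last dim is 1 and numel * 2 equals the model parameter's numel.
assert packed.shape[-1] == 1 and packed.numel() * 2 == full.numel()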
class PipelineParallel(Enum):
inputs = 0
outputs = 1
@ -1504,8 +1677,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
_keep_in_fp32_modules_strict = None
dtype_plan: Optional[dict[str, torch.dtype]] = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
_keys_to_ignore_on_load_missing = None
@ -1670,18 +1841,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
self.name_or_path = config.name_or_path
self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
# Overwrite the class attribute to make it an instance attribute, so models like
# `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
# when a different component (e.g. language_model) is used.
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
self.dtype_plan = {}
if isinstance(self._keep_in_fp32_modules, list):
self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
if isinstance(self._keep_in_fp32_modules_strict, list):
self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
self._no_split_modules = self._no_split_modules or []
_CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
@ -1697,6 +1861,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# Make sure the modules correctly exist if the flag is active
if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
unique_module_names = set()
# Get all unique module names in the module graph, without the prefixes
for param in all_parameters:
unique_module_names.update(
[name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
)
# Check that every module in the keep_in_fp32 list is part of the module graph
if self._keep_in_fp32_modules is not None:
for module in self._keep_in_fp32_modules:
if module not in unique_module_names:
raise ValueError(
f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
f" {self.__class__.__name__}"
)
if self._keep_in_fp32_modules_strict is not None:
for module in self._keep_in_fp32_modules_strict:
if module not in unique_module_names:
raise ValueError(
f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
f" {self.__class__.__name__}"
)
self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
@ -2443,41 +2632,34 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
`nn.Parameter`, this method should also be overridden in order to initialize it correctly.
"""
if hasattr(self.config, "initializer_range"):
std = self.config.initializer_range or 0.02
std = self.config.initializer_range
else:
# 0.02 is the standard default value across the library
std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
try:
if isinstance(
module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)
):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.Parameter):
module.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
# between modelings (because they are prefixed with the model name)
elif (
isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
or "LayerNorm" in module.__class__.__name__
or "RMSNorm" in module.__class__.__name__
):
# Norms can exist without weights (in which case they are None from torch primitives)
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
except Exception as e:
logger.warning(f"Failed to init: {str(e)}")
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
# between modelings (because they are prefixed with the model name)
elif (
isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
or "LayerNorm" in module.__class__.__name__
or "RMSNorm" in module.__class__.__name__
):
# Norms can exist without weights (in which case they are None from torch primitives)
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
def _initialize_weights(self, module):
"""
@ -2512,12 +2694,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
else:
module.smart_apply(fn)
fn(self)
if not isinstance(self, nn.Parameter):
for name, param in self.named_parameters(recurse=False):
if param is None:
continue
fn(param)
return self
torch.nn.Module.smart_apply = smart_apply
@ -3281,7 +3457,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
variant: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
save_peft_format: bool = True,
save_original_format: bool = False,
**kwargs,
):
"""
@ -3330,10 +3505,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
keys of the state dict of adapters need to be prepended with `base_model.model`. Advanced users can
disable this behaviour by setting `save_peft_format` to `False`.
save_original_format (`bool`, *optional*, defaults to `False`):
For backward compatibility with previous versions of `transformers`, you can save the checkpoint with
its reverse mapping. The reverse mapping needs to exist even if the model was loaded from a
non-legacy checkpoint.
kwargs (`dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@ -3473,18 +3644,24 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()
if (
any(
allowed_name in class_name.__name__.lower()
for class_name in self.__class__.__mro__[:-1]
for allowed_name in VLMS
)
or save_original_format
if any(
allowed_name in class_name.__name__.lower()
for class_name in self.__class__.__mro__[:-1]
for allowed_name in VLMS
):
# MEGA BIG TODO HERE: self._conversion_ops needs to be used to save the final ckpt
# using what was loaded. Actually self._conversion_ops won't work on its own, because we need the
# mapping even if the files are not legacy -> in which case no conversion happened
state_dict = revert_weight_conversion(self, state_dict)
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
original_state_dict = {}
for key, value in state_dict.items():
for pattern, replacement in reverse_key_mapping.items():
replacement = replacement.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
@ -3652,8 +3829,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough. # TODO: we should parallelize this; otherwise we are just waiting
# too long before scheduling the next write when it is in a different file
# joyfulness), but for now this is enough.
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
else:
save_function(shard, os.path.join(save_directory, shard_file))
@ -4219,13 +4395,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
config, quantization_config, dtype, device_map, weights_only, user_agent
)
weight_conversions: Optional[list[WeightConverter]] = None
model_type = getattr(config, "model_type", None)
if model_type is not None:
weight_conversions = get_checkpoint_conversion_mapping().get(model_type)
if weight_conversions is None:
weight_conversions = get_checkpoint_conversion_mapping()["legacy"]
if gguf_file:
if hf_quantizer is not None:
raise ValueError(
@ -4281,6 +4450,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# Let's make sure we don't run the init function of buffer modules
model = cls(config, *model_args, **model_kwargs)
# Potentially upcast some modules to avoid losing precision
model.upcast_modules_in_fp32(hf_quantizer, dtype)
# Make sure to tie the weights correctly
model.tie_weights()
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
@ -4288,7 +4462,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
hf_quantizer.preprocess_model(
model=model,
device_map=device_map,
keep_in_fp32_modules=model._keep_in_fp32_modules, # TODO prob no longer needed?
keep_in_fp32_modules=model._keep_in_fp32_modules,
config=config,
checkpoint_files=checkpoint_files,
use_kernels=use_kernels,
@ -4320,15 +4494,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
device_mesh=device_mesh,
key_mapping=key_mapping,
weights_only=weights_only,
weight_mapping=weight_conversions,
)
model.tie_weights() # make sure token embedding weights are still tied if needed
model.eval() # Set model in evaluation mode to deactivate DropOut modules by default
model.set_use_kernels(use_kernels, kernel_config)
# If it is a model with generation capabilities, attempt to load generation files (generation config,
# custom generate function)
if model.can_generate() and hasattr(model, "adjust_generation_fn") and trust_remote_code:
if model.can_generate() and hasattr(model, "adjust_generation_fn"):
model.adjust_generation_fn(
generation_config,
from_auto_class,
@ -4339,16 +4513,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
**kwargs,
)
# for device_map="auto" : dispatch model with hooks on all devices if necessary
# for device_map="auto" : dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harm performances).
if device_map is not None and device_mesh is None:
accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
if hf_quantizer is not None:
model.hf_quantizer = hf_quantizer
hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed
hf_quantizer.postprocess_model(model, config=config) # usually a no-op
if _adapter_model_path is not None:
adapter_kwargs["key_mapping"] = weight_conversions # TODO: Dynamic weight loader for adapters
adapter_kwargs["key_mapping"] = key_mapping # TODO: Dynamic weight loader for adapters
model.load_adapter(
_adapter_model_path,
adapter_name=adapter_name,
@ -4366,6 +4541,107 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return model, loading_info
return model
@staticmethod
def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
# This rename is logged.
if key.endswith("LayerNorm.beta"):
return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
if key.endswith("LayerNorm.gamma"):
return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
# Rename weight norm parametrizations to match changes across torch versions.
# Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
# This rename is not logged.
if hasattr(nn.utils.parametrizations, "weight_norm"):
if key.endswith("weight_g"):
return key.replace("weight_g", "parametrizations.weight.original0"), True
if key.endswith("weight_v"):
return key.replace("weight_v", "parametrizations.weight.original1"), True
else:
if key.endswith("parametrizations.weight.original0"):
return key.replace("parametrizations.weight.original0", "weight_g"), True
if key.endswith("parametrizations.weight.original1"):
return key.replace("parametrizations.weight.original1", "weight_v"), True
return key, False
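A quick illustration of the legacy renames handled above, applied to hypothetical checkpoint keys:

legacy_keys = [
    "bert.encoder.layer.0.attention.output.LayerNorm.gamma",
    "bert.encoder.layer.0.attention.output.LayerNorm.beta",
]
renamed = [
    key.replace("LayerNorm.gamma", "LayerNorm.weight").replace("LayerNorm.beta", "LayerNorm.bias")
    for key in legacy_keys
]
print(renamed)  # gamma -> weight, beta -> bias, as _fix_state_dict_key_on_load does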
def _get_key_renaming_mapping(
self,
checkpoint_keys: list[str],
key_mapping: Optional[dict[str, str]] = None,
loading_base_model_from_task_state_dict: bool = False,
loading_task_model_from_base_state_dict: bool = False,
):
"""
Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model
that we are loading expects. This is the single entry point for key renaming that will be used during
loading.
Log if any parameters have been renamed.
"""
prefix = self.base_model_prefix
_prefix = f"{prefix}."
if loading_task_model_from_base_state_dict:
task_specific_expected_keys, base_model_keys = [], []
for key in self.state_dict():
if key.startswith(_prefix):
base_model_keys.append(key[len(_prefix) :])
else:
task_specific_expected_keys.append(key)
renamed_keys = {}
key_renaming_mapping = {}
for key in checkpoint_keys:
# Class specific rename
new_key, has_changed = self._fix_state_dict_key_on_load(key)
# Optionally map the key according to `key_mapping`
if key_mapping is not None:
for pattern, replacement in key_mapping.items():
new_key, n_replace = re.subn(pattern, replacement, new_key)
# Early exit of the loop
if n_replace > 0:
has_changed = True
break
# In this case, we need to add the prefix to the keys, to match them to the expected keys
if loading_task_model_from_base_state_dict:
# small sanity check: if we find a key that is only part of the task-specific keys, we raise
# (if it's also part of the base model, we do not raise and assume it comes from there)
if new_key in task_specific_expected_keys and new_key not in base_model_keys:
raise ValueError(
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"properly saved?"
)
new_key = ".".join([prefix, new_key])
# In this case we need to remove the prefix from the key to match them to the expected keys, and use
# only the keys starting with the prefix
elif loading_base_model_from_task_state_dict:
if not new_key.startswith(_prefix):
continue
new_key = new_key[len(_prefix) :]
key_renaming_mapping[key] = new_key
# track gamma/beta rename for logging
if has_changed:
if key.endswith("LayerNorm.gamma"):
renamed_keys["LayerNorm.gamma"] = (key, new_key)
elif key.endswith("LayerNorm.beta"):
renamed_keys["LayerNorm.beta"] = (key, new_key)
if renamed_keys:
warning_msg = f"A pretrained model of type `{self.__class__.__name__}` "
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.values():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
return key_renaming_mapping
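The optional key_mapping above is applied with re.subn, and the first matching pattern wins (hence the early break). A toy example with a hypothetical mapping:

import re

key_mapping = {r"^model\.": "transformer."}  # hypothetical mapping
key = "model.layers.0.self_attn.q_proj.weight"
for pattern, replacement in key_mapping.items():
    new_key, n_replace = re.subn(pattern, replacement, key)
    if n_replace > 0:
        break
print(new_key)  # transformer.layers.0.self_attn.q_proj.weight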
@staticmethod
def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
"""
@ -4397,16 +4673,97 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
key_mapping: Optional[dict[str, str]] = None,
weights_only: bool = True,
weight_mapping: Optional[Sequence[WeightConverter]] = None,
):
# TODO: we should only be calling hf_quantizer.skip_placement or something like that
is_quantized = hf_quantizer is not None
is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}
# The model's definition arriving here is final (TP hooks added, quantized layers replaced)
# Get all the keys of the state dicts that we have to initialize the model with
if sharded_metadata is not None:
original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
elif state_dict is not None:
original_checkpoint_keys = list(state_dict.keys())
else:
original_checkpoint_keys = list(
load_state_dict(checkpoint_files[0], map_location="meta", weights_only=weights_only).keys()
)
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
prefix = model.base_model_prefix
has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False
expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module
# Find the key names that the model expects from the serialized keys
key_renaming_mapping = model._get_key_renaming_mapping(
original_checkpoint_keys,
key_mapping,
loading_base_model_from_task_state_dict,
loading_task_model_from_base_state_dict,
)
checkpoint_keys = list(key_renaming_mapping.values())
# Find missing and unexpected keys from the state dict
missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
model, original_checkpoint_keys, checkpoint_keys, loading_base_model_from_task_state_dict, hf_quantizer
)
# Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
# same way as missing keys)
mismatched_keys, mismatched_shapes = _find_mismatched_keys(
model,
state_dict,
checkpoint_files,
ignore_mismatched_sizes,
key_renaming_mapping,
is_quantized,
weights_only,
)
# We need to update both the mapping and the list of checkpoint keys to remove the mismatched and unexpected ones
key_renaming_mapping = {
k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys and v not in unexpected_keys
}
checkpoint_keys = list(key_renaming_mapping.values())
# Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
# loading the weights as they are not in the loaded state dict)
model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, dtype, hf_quantizer)
# correctly initialize the missing (and potentially mismatched) keys
model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized)
# Get reverse key mapping
reverse_key_renaming_mapping = {v: k for k, v in key_renaming_mapping.items()}
is_offloaded_safetensors = False
# This offload index if for params explicitly on the "disk" in the device_map
disk_offload_index = None
disk_only_shard_files = []
# Prepare parameters offloading if needed
if device_map is not None and "disk" in device_map.values():
disk_offload_index, disk_only_shard_files, is_offloaded_safetensors = accelerate_disk_offload(
disk_offload_folder,
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
)
# To be able to iterate, even if we don't use it if the state_dict is already provided
elif state_dict is not None:
checkpoint_files = [""]
# Compute expected model keys
expected_keys = list(model.state_dict().keys())
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)
if logger.level >= logging.WARNING:
verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
@ -4415,84 +4772,46 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
# Now we read all the files to get a pointer on each physical weights
merged_state_dict = {}
all_pointer = set()
if device_map is None:
device_map = {"": "cpu"}
keys = sorted(device_map.keys(), key=len, reverse=True)
tp_plan = getattr(model, "_tp_plan", None)
error_msgs = []
misc = {}
if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
else:
if checkpoint_files is not None:
pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")")
if sharded_metadata is None:
k_v_iterator = dict.fromkeys(
safe_open(checkpoint_files[0], framework="pt").keys(), "model.safetensors"
).items()
else:
k_v_iterator = sharded_metadata["weight_map"].items()
for k, v in k_v_iterator:
match = pattern.match(k)
if match and match.group(1) != "":
device = device_map[match.group(1)]
else:
device = device_map.get("", "cpu")
if isinstance(device, torch.device):
device = device.index # safetensors only
if device == "disk":
device = "cpu" # we read to cpu to then write to disk
file_pointer = safe_open(
os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device
)
all_pointer.add(file_pointer)
merged_state_dict[k] = (v, file_pointer.get_slice(k)) # don't materialize yet
elif state_dict is not None:
merged_state_dict = {k: ("", v) for k, v in state_dict.items()}
else:
raise ValueError("Neither a state dict nor checkpoint files were found.")
start = time.perf_counter()
missing_keys, unexpected_keys, mismatched_keys, misc = convert_and_load_state_dict_in_model(
model,
merged_state_dict,
weight_mapping,
tp_plan,
hf_quantizer,
dtype,
# Prepare arguments in a form compatible with both serial and parallel shard loading
args_list = [
(
shard_file,
state_dict,
disk_only_shard_files,
is_quantized,
device_map,
model.dtype_plan,
device_mesh=device_mesh,
hf_quantizer,
key_renaming_mapping,
weights_only,
model,
reverse_key_renaming_mapping,
disk_offload_folder,
disk_offload_index,
device_mesh,
)
end = time.perf_counter()
for shard_file in checkpoint_files
]
for k in all_pointer: # finally close all opened file pointers
k.__exit__(None, None, None)
error_msgs = []
new_state_dict = model.state_dict()
if (
os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
and not is_deepspeed_zero3_enabled()
):
_error_msgs, disk_offload_index = load_shard_files_with_threadpool(args_list)
error_msgs += _error_msgs
else:
if len(args_list) > 1:
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
# ---------------- POST PROCESS ----------------
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
prefix = model.base_model_prefix
has_prefix_module = any(s.startswith(prefix) for s in new_state_dict.keys()) if len(prefix) > 0 else False
expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
for args in args_list:
_error_msgs, disk_offload_index = load_shard_file(args)
error_msgs += _error_msgs
# TODO: the last thing to do here is to tie the weights once, and only once (depending on whether
# they are missing and on the tying flag).
# Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
# loading the weights as they are not in the loaded state dict)
miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer)
# correctly initialize the missing (and potentially mismatched) keys
model._initialize_missing_keys(miss_and_mismatched, is_quantized)
# Save offloaded index if needed
if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors:
save_offload_index(disk_offload_index, disk_offload_folder)
disk_offload_index = None
# Post-processing for tensor parallelism
if device_mesh is not None:
@ -4500,7 +4819,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
tp_device = list(device_map.values())[0]
# This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
# not part of the state_dict (persistent=False)
for buffer in model.buffers(): # TODO: to avoid this, the buffer could be added to the ckpt
for buffer in model.buffers():
if buffer.device != tp_device:
buffer.data = buffer.to(tp_device)
@ -4527,24 +4846,52 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
device_mesh,
)
# Remove tied weight keys, etc.
# Remove potential model-specific exceptions from the warnings
missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(
missing_keys, unexpected_keys, loading_task_model_from_base_state_dict, model
missing_keys, unexpected_keys, loading_task_model_from_base_state_dict
)
logger.warning(f"Loading the checkpoint files into the model took {end - start}")
log_state_dict_report(
model=model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
logger=logger,
error_msgs=error_msgs,
unexpected_keys=unexpected_keys,
missing_keys=missing_keys,
mismatched_keys=mismatched_keys,
mismatched_shapes=mismatched_keys,
misc=misc,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
disk_offload_index = None
# TODO: separate this into another function: it's not core logic
# All potential warnings/infos
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info
warner(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {update_key_name(unexpected_keys)}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {update_key_name(missing_keys)}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes)
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
@ -4753,6 +5100,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
value = torch.empty_like(param, dtype=dtype, device="cpu")
if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
_load_parameter_into_model(self, key, value)
else:
hf_quantizer.create_quantized_param(self, value, key, "cpu")
def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None:
"""Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
@ -4802,23 +5151,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
self.initialize_weights()
def _adjust_missing_and_unexpected_keys(
self, missing_keys: set[str], unexpected_keys: set[str], loading_task_model_from_base_state_dict: bool, model
) -> tuple[set[str], set[str]]:
self, missing_keys: list[str], unexpected_keys: list[str], loading_task_model_from_base_state_dict: bool
) -> tuple[list[str], list[str]]:
"""Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
raising unneeded warnings/errors.
"""
# Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
# Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
# (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
# `_keys_to_ignore_on_load_unexpected` as it touches many models -> we add it manually to the existing patterns
has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer, _ in self.named_buffers())
additional_unexpected_patterns = [r"rotary_emb\.inv_freq"] if has_inv_freq_buffers else []
tied_param_names = "|".join(model._tied_weights_keys or [])
if tied_param_names:
model.tie_weights()
if model.config.tie_word_embeddings:
for k in missing_keys.copy():
if re.match(tied_param_names, k):
missing_keys.discard(k)
missing_patterns = self._keys_to_ignore_on_load_missing or []
unexpected_patterns = (self._keys_to_ignore_on_load_unexpected or []) + additional_unexpected_patterns
@ -4830,17 +5172,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# Clean-up missing keys
if ignore_missing_regex is not None:
missing_keys = {key for key in missing_keys if ignore_missing_regex.search(key) is None}
missing_keys = [key for key in missing_keys if ignore_missing_regex.search(key) is None]
# Clean-up unexpected keys
if ignore_unexpected_regex is not None:
unexpected_keys = {key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None}
unexpected_keys = [key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None]
# Note: only the unexpected keys should remove the added prefix here, to correctly display the original name
# in the warnings. For missing keys, we should show the prefix in the warning as it's part of the final model
if loading_task_model_from_base_state_dict:
_prefix = f"{self.base_model_prefix}."
unexpected_keys = {k.removeprefix(_prefix) for k in unexpected_keys}
unexpected_keys = [k.removeprefix(_prefix) for k in unexpected_keys]
return missing_keys, unexpected_keys
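A toy sketch of this clean-up (keys and patterns are made up): keys matching the model's ignore patterns are dropped, then the base-model prefix is stripped from the unexpected keys so warnings show the checkpoint's original names.
import re

unexpected_keys = ["model.layers.0.rotary_emb.inv_freq", "model.pooler.dense.weight"]

ignore_unexpected_regex = re.compile(r"rotary_emb\.inv_freq")
unexpected_keys = [k for k in unexpected_keys if ignore_unexpected_regex.search(k) is None]
unexpected_keys = [k.removeprefix("model.") for k in unexpected_keys]
print(unexpected_keys)  # ['pooler.dense.weight']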
@ -4877,6 +5219,35 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
def eval(self):
return self.train(False)
def upcast_modules_in_fp32(self, hf_quantizer: HfQuantizer | None, dtype: torch.dtype) -> None:
"""
Upcast modules defined in `_keep_in_fp32_modules` and `_keep_in_fp32_modules_strict` in fp32, if
`dtype` is different than fp32.
"""
# If the dtype is already fp32, we can skip
if dtype == torch.float32:
return
keep_in_fp32_modules = []
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
# in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
if self._keep_in_fp32_modules is not None and (
dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
keep_in_fp32_modules.extend(self._keep_in_fp32_modules)
if self._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16):
keep_in_fp32_modules.extend(self._keep_in_fp32_modules_strict)
if len(keep_in_fp32_modules) > 0:
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
for name, param in self.named_parameters():
if keep_in_fp32_regex.search(name):
# param = param.to(torch.float32) does not work here, as it would only rebind the local variable
param.data = param.data.to(torch.float32)
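A short sketch of the name matching used above (module names are made up): each entry must match a whole dotted component, so "norm" matches "model.norm.weight" but not "model.layernorm.weight".
import re

keep_in_fp32_modules = ["norm", "lm_head"]
keep_in_fp32_regex = re.compile("|".join(rf"((^|\.){m}($|\.))" for m in keep_in_fp32_modules))

for name in ["model.norm.weight", "model.layernorm.weight", "lm_head.weight"]:
    print(name, bool(keep_in_fp32_regex.search(name)))
# model.norm.weight True, model.layernorm.weight False, lm_head.weight True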
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:

View File

@ -136,7 +136,7 @@ class ArceeConfig(PreTrainedConfig):
bos_token_id: Optional[int] = 128000,
eos_token_id: Optional[int] = 128001,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
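The annotation change repeated across the configs below is a typing fix: dict takes a key type and a value type, so dict[RopeParameters] was ill-formed. A minimal sketch (RopeParameters here is a stand-in for the real mapping, and the dict keys are hypothetical):
from typing import Optional

RopeParameters = dict  # placeholder for transformers' RopeParameters

def make_config(
    rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
) -> dict:
    return {"rope_parameters": rope_parameters}

make_config(rope_parameters={"rope_type": "default", "rope_theta": 10000.0})
make_config(rope_parameters={"full_attention": {"rope_type": "default", "rope_theta": 10000.0}})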

View File

@ -137,7 +137,7 @@ class ArceeConfig(LlamaConfig):
bos_token_id: Optional[int] = 128000,
eos_token_id: Optional[int] = 128001,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,

View File

@ -134,7 +134,7 @@ class AriaTextConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,

View File

@ -117,7 +117,7 @@ class BitNetConfig(PreTrainedConfig):
tie_word_embeddings: Optional[bool] = False,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
**kwargs,
):
self.vocab_size = vocab_size

View File

@ -44,7 +44,7 @@ class BltLocalEncoderConfig(PreTrainedConfig):
rms_norm_eps: Optional[float] = 1e-5,
dropout: Optional[float] = 0.0,
max_position_embeddings: Optional[int] = 24576,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
hidden_act: Optional[str] = "silu",
intermediate_size: Optional[int] = 2816,
initializer_range: Optional[float] = 0.02,
@ -99,7 +99,7 @@ class BltLocalDecoderConfig(PreTrainedConfig):
rms_norm_eps: Optional[float] = 1e-5,
dropout: Optional[float] = 0.0,
max_position_embeddings: Optional[int] = 24576,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
hidden_act: Optional[str] = "silu",
intermediate_size: Optional[int] = 2816,
initializer_range: Optional[float] = 0.02,
@ -150,7 +150,7 @@ class BltGlobalTransformerConfig(PreTrainedConfig):
rms_norm_eps: Optional[float] = 1e-5,
dropout: Optional[float] = 0.0,
max_position_embeddings: Optional[int] = 4096,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
hidden_act: Optional[str] = "silu",
intermediate_size: Optional[int] = 5632,
initializer_range: Optional[float] = 0.02,
@ -231,7 +231,7 @@ class BltPatcherConfig(PreTrainedConfig):
rms_norm_eps: Optional[float] = 1e-5,
dropout: Optional[float] = 0.0,
intermediate_size: Optional[int] = 2048,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
initializer_range: Optional[float] = 0.02,
**kwargs,
):
@ -356,7 +356,7 @@ class BltConfig(PreTrainedConfig):
global_config: Optional[dict] = None,
tie_word_embeddings: Optional[bool] = False,
initializer_range: Optional[float] = 0.02,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
**kwargs,
):
# Basic model configuration

View File

@ -203,7 +203,7 @@ class ChameleonConfig(PreTrainedConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
model_parallel_size: Optional[int] = 1,

View File

@ -139,7 +139,7 @@ class CohereConfig(PreTrainedConfig):
bos_token_id: Optional[int] = 5,
eos_token_id: Optional[int] = 255001,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
use_qk_norm: Optional[bool] = False,

View File

@ -138,7 +138,7 @@ class Cohere2Config(PreTrainedConfig):
bos_token_id: Optional[int] = 5,
eos_token_id: Optional[int] = 255001,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
sliding_window: Optional[int] = 4096,

View File

@ -162,7 +162,7 @@ class Cohere2Config(PreTrainedConfig):
bos_token_id: Optional[int] = 5,
eos_token_id: Optional[int] = 255001,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
sliding_window: Optional[int] = 4096,

View File

@ -39,9 +39,10 @@ from typing import Any, Optional
import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from safetensors import safe_open
from transformers import AutoConfig
from transformers import AutoConfig, AutoModel
from transformers.models.colqwen2 import ColQwen2ForRetrieval
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
from transformers.utils import logging
@ -69,7 +70,7 @@ def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> d
original_state_dict[key] = f.get_tensor(key)
# Some weights are tied, so `lm_head.weight` is not saved. Let's clone it to load the state dict.
if "lm_head.weight" not in original_state_dict:
if "lm_head.weight" not in original_state_dict and "model.embed_tokens.weight" in original_state_dict:
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
return original_state_dict
@ -124,7 +125,21 @@ def convert_colqwen2_weights_to_hf(
config.is_composition = False
# Load the untrained model
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
vlm_name_or_path = getattr(config.vlm_config, "_name_or_path", None)
if vlm_name_or_path and "2.5" in str(vlm_name_or_path):
print(
"Detected colqwen2.5 adapters in vlm_config; loading base model %s and merging PEFT weights."
% vlm_name_or_path
)
base_model = AutoModel.from_pretrained(
vlm_name_or_path,
device_map="cpu",
trust_remote_code=True,
)
peft_model = PeftModel.from_pretrained(base_model, model_id)
model = peft_model.merge_and_unload()
else:
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
print("Created model with new config and randomly initialized weights")
# NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
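A hedged sketch of the adapter-merge path added above (repo ids are placeholders and peft must be installed): load the frozen base model, attach the PEFT adapters, then bake them into the weights.
from peft import PeftModel
from transformers import AutoModel

base = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", device_map="cpu")  # placeholder id
merged = PeftModel.from_pretrained(base, "org/colqwen2.5-adapters").merge_and_unload()  # placeholder id
merged.save_pretrained("./colqwen2.5-merged")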
@ -201,6 +216,7 @@ if __name__ == "__main__":
help="Name or path of the original VLM backbone model",
default=None,
)
args = parser.parse_args()
convert_colqwen2_weights_to_hf(

View File

@ -172,7 +172,6 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
image_mask = (
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)

View File

@ -359,7 +359,6 @@ class ColQwen2ForRetrieval(ColPaliForRetrieval):
inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
image_mask = (
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)

View File

@ -122,7 +122,7 @@ class CsmDepthDecoderConfig(PreTrainedConfig):
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
@ -291,7 +291,7 @@ class CsmConfig(PreTrainedConfig):
eos_token_id: Optional[int] = None,
audio_token_id: Optional[int] = 128002,
audio_eos_token_id: Optional[int] = 128003,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,

View File

@ -189,7 +189,7 @@ class DbrxConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
initializer_range: Optional[float] = 0.02,
output_router_logits: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
**kwargs: Any,
):
if attn_config is None:

View File

@ -394,11 +394,10 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
if isinstance(module, PreTrainedModel):
for name, p in module.named_parameters():
if "c_proj" in name and "weight" in name:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
for name, p in module.named_parameters():
if "c_proj" in name and "weight" in name:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
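A toy sketch of the scaled initialization kept above: residual-projection weights ("c_proj") get std = initializer_range / sqrt(2 * n_layer), since every transformer block adds two residual contributions (attention and MLP).
import math
import torch
from torch import nn

initializer_range, n_layer = 0.02, 12
c_proj = nn.Linear(768, 768)
c_proj.weight.data.normal_(mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
print(float(c_proj.weight.std()))  # ≈ 0.0041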
class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):

View File

@ -127,7 +127,8 @@ class DeepseekV2Config(PreTrainedConfig):
"layers.*.self_attn.q_b_proj": "colwise",
"layers.*.self_attn.kv_b_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_up_proj": "colwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
@ -153,7 +154,7 @@ class DeepseekV2Config(PreTrainedConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,

View File

@ -42,43 +42,37 @@ from ...utils.generic import check_model_inputs
from .configuration_deepseek_v2 import DeepseekV2Config
class DeepseekV2Experts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class DeepseekV2Experts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.n_routed_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(config.n_routed_experts):
self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
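A self-contained toy sketch of the dispatch pattern used by the ModuleList-of-experts forward above (toy sizes and random routing, with plain Linear layers standing in for the MLP experts; not the model's code):
import torch
from torch import nn

num_tokens, hidden, num_experts, top_k = 6, 4, 3, 2
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))
hidden_states = torch.randn(num_tokens, hidden)
top_k_index = torch.randint(0, num_experts, (num_tokens, top_k))
top_k_weights = torch.rand(num_tokens, top_k).softmax(dim=-1)

final_hidden_states = torch.zeros_like(hidden_states)
# expert_mask[e, j, t] == 1 iff expert e is token t's j-th choice
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:                 # only experts that received tokens
    e = int(expert_idx)
    idx, top_x = torch.where(expert_mask[e])  # (k-slot, token) pairs for this expert
    weighted = experts[e](hidden_states[top_x]) * top_k_weights[top_x, idx, None]
    final_hidden_states.index_add_(0, top_x, weighted.to(hidden_states.dtype))
print(final_hidden_states.shape)  # torch.Size([6, 4])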
@ -117,7 +111,6 @@ class DeepseekV2Moe(nn.Module):
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
topk_weight = topk_weight * self.routed_scaling_factor
topk_weight = torch.zeros_like(router_logits).scatter_(1, topk_idx, topk_weight)
return topk_idx, topk_weight
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@ -142,7 +142,8 @@ class DeepseekV2Config(LlamaConfig):
"layers.*.self_attn.q_b_proj": "colwise",
"layers.*.self_attn.kv_b_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_up_proj": "colwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
@ -166,7 +167,7 @@ class DeepseekV2Config(LlamaConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
@ -223,10 +224,12 @@ def apply_rotary_emb(
return xq_out, xk_out
class DeepseekV2Experts(Qwen2MoeExperts):
class DeepseekV2Experts(Qwen2MoeExperts, nn.ModuleList):
def __init__(self, config):
super().__init__(config)
nn.ModuleList.__init__(self)
self.num_experts = config.n_routed_experts
for _ in range(config.n_routed_experts):
self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size))
class DeepseekV2Moe(nn.Module):
@ -264,7 +267,6 @@ class DeepseekV2Moe(nn.Module):
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
topk_weight = topk_weight * self.routed_scaling_factor
topk_weight = torch.zeros_like(router_logits).scatter_(1, topk_idx, topk_weight)
return topk_idx, topk_weight
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@ -186,7 +186,7 @@ class DeepseekV3Config(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
pretraining_tp: Optional[int] = 1,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
rope_interleave: Optional[bool] = True,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,

View File

@ -149,43 +149,37 @@ class DeepseekV3TopkRouter(nn.Module):
return router_logits
class DeepseekV3NaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class DeepseekV3NaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states

View File

@ -102,10 +102,12 @@ class DeepseekV3TopkRouter(nn.Module):
return router_logits
class DeepseekV3NaiveMoe(MixtralExperts):
class DeepseekV3NaiveMoe(MixtralExperts, nn.ModuleList):
def __init__(self, config):
super().__init__(config)
nn.ModuleList.__init__(self)
self.num_experts = config.num_local_experts
for _ in range(self.num_experts):
self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size))
class DeepseekV3MoE(nn.Module):

View File

@ -118,7 +118,7 @@ class DiffLlamaConfig(PreTrainedConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
lambda_std_dev: Optional[float] = 0.1,

View File

@ -148,7 +148,7 @@ class DogeConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
max_position_embeddings: Optional[int] = 2048,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
num_attention_heads: Optional[int] = 8,
num_key_value_heads: Optional[int] = None,
attention_bias: Optional[bool] = False,

View File

@ -176,7 +176,7 @@ class DogeConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
max_position_embeddings: Optional[int] = 2048,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
num_attention_heads: Optional[int] = 8,
num_key_value_heads: Optional[int] = None,
attention_bias: Optional[bool] = False,

View File

@ -159,7 +159,7 @@ class Dots1Config(PreTrainedConfig):
rms_norm_eps: Optional[float] = 1e-6,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
routed_scaling_factor: Optional[float] = 1.0,

View File

@ -305,43 +305,37 @@ class Dots1TopkRouter(nn.Module):
return router_logits
class Dots1NaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Dots1NaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Dots1MLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states

View File

@ -125,7 +125,7 @@ class Ernie4_5Config(PreTrainedConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
use_bias: Optional[bool] = False,
head_dim: Optional[int] = 128,
**kwargs,

View File

@ -161,7 +161,7 @@ class Ernie4_5_MoeConfig(PreTrainedConfig):
rms_norm_eps: Optional[float] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
use_bias: Optional[bool] = False,
moe_intermediate_size: Optional[int] = 1536,
moe_k: Optional[int] = 6,

View File

@ -315,95 +315,62 @@ class Ernie4_5_MoeStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()
class Ernie4_5_MoeExperts(nn.Module):
class Ernie4_5_MoeExperts(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.num_experts = config.moe_num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.use_bias = config.use_bias
self.act_fn = ACT2FN[config.hidden_act]
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
if self.use_bias:
self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim))
self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
else:
self.gate_up_proj_bias = None
self.down_proj_bias = None
for _ in range(self.num_experts):
self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))
def forward(
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
if selected_experts.numel() == 0:
return final_hidden_states
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = int(expert_idx.item())
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
gate_inputs = F.linear(
current_state,
self.gate_up_proj[expert_idx],
None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
)
gate, up = gate_inputs.chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = F.linear(
current_hidden_states,
self.down_proj[expert_idx],
None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
)
current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Ernie4_5_MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.weight = nn.Parameter(torch.zeros(config.moe_num_experts, config.hidden_size, dtype=torch.float32))
self.moe_statics = Ernie4_5_MoeStatics(config)
self.top_k = config.moe_k
self.norm_min = config.moe_norm_min
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = F.linear(hidden_states.float(), self.weight)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return routing_weights, selected_experts
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.num_experts = config.moe_num_experts
self.top_k = config.moe_k
self.gate = Ernie4_5_MoeTopKRouter(config)
self.norm_min = config.moe_norm_min
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.moe_statics = Ernie4_5_MoeStatics(config)
self.experts = Ernie4_5_MoeExperts(config)
self.shared_experts = None
if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def route_tokens_to_experts(self, hidden_states):
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = self.gate(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
@ -411,14 +378,14 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
routing_weights, selected_experts = self.gate(hidden_states)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
if self.shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, self.hidden_dim)
return final_hidden_states.to(hidden_states.dtype)
return final_hidden_states
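A toy sketch of the float32 routing above (sizes made up, and the moe_statics correction is simplified to a plain bias tensor): experts are selected on bias-corrected scores, but the gathered weights come from the uncorrected probabilities and are clamp-normalized.
import torch
import torch.nn.functional as F

num_tokens, hidden, num_experts, top_k, norm_min = 4, 8, 6, 2, 1e-12
gate = torch.nn.Linear(hidden, num_experts, bias=False, dtype=torch.float32)
e_score_correction_bias = torch.zeros(num_experts)
hidden_states = torch.randn(num_tokens, hidden)

router_logits = gate(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(routing_weights + e_score_correction_bias, top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(routing_weights.sum(dim=-1, keepdim=True), min=norm_min)
print(selected_experts.shape, routing_weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])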
class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer):
@ -487,11 +454,11 @@ class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.gate", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}
_keep_in_fp32_modules_strict = ["gate.weight", "moe_statics"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]

View File

@ -19,7 +19,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast
@ -97,95 +96,62 @@ class Ernie4_5_MoeStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()
class Ernie4_5_MoeExperts(nn.Module):
class Ernie4_5_MoeExperts(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.num_experts = config.moe_num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.use_bias = config.use_bias
self.act_fn = ACT2FN[config.hidden_act]
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
if self.use_bias:
self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim))
self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
else:
self.gate_up_proj_bias = None
self.down_proj_bias = None
for _ in range(self.num_experts):
self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))
def forward(
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
if selected_experts.numel() == 0:
return final_hidden_states
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = int(expert_idx.item())
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
gate_inputs = F.linear(
current_state,
self.gate_up_proj[expert_idx],
None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
)
gate, up = gate_inputs.chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = F.linear(
current_hidden_states,
self.down_proj[expert_idx],
None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
)
current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class Ernie4_5_MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.weight = nn.Parameter(torch.zeros(config.hidden_size, config.moe_num_experts, dtype=torch.float32))
self.moe_statics = Ernie4_5_MoeStatics(config)
self.top_k = config.moe_k
self.norm_min = config.moe_norm_min
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = F.linear(hidden_states.float(), self.weight)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return routing_weights, selected_experts
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.num_experts = config.moe_num_experts
self.top_k = config.moe_k
self.gate = Ernie4_5_MoeTopKRouter(config)
self.norm_min = config.moe_norm_min
self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.moe_statics = Ernie4_5_MoeStatics(config)
self.experts = Ernie4_5_MoeExperts(config)
self.shared_experts = None
if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
def route_tokens_to_experts(self, hidden_states):
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = self.gate(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
@ -193,7 +159,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
routing_weights, selected_experts = self.gate(hidden_states)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
if self.shared_experts is not None:
@ -227,11 +193,11 @@ class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer):
class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
config: Ernie4_5_MoeConfig
_no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_keep_in_fp32_modules_strict = ["router"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}

View File

@ -203,7 +203,7 @@ class EvollaConfig(PreTrainedConfig):
hidden_act: Optional[str] = "silu", # llama activation function
max_position_embeddings: Optional[int] = 8192, # llama rope max length
rms_norm_eps: Optional[float] = 1e-05,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,

View File

@ -143,7 +143,7 @@ class Exaone4Config(PreTrainedConfig):
bos_token_id: Optional[int] = 0,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_dropout: Optional[float] = 0.0,
sliding_window: Optional[int] = 4096,
sliding_window_pattern: Optional[int] = 4,

View File

@ -176,7 +176,7 @@ class Exaone4Config(PreTrainedConfig):
bos_token_id: Optional[int] = 0,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_dropout: Optional[float] = 0.0,
sliding_window: Optional[int] = 4096,
sliding_window_pattern: Optional[int] = 4,

View File

@ -128,7 +128,7 @@ class FalconConfig(PreTrainedConfig):
parallel_attn: Optional[bool] = True,
bias: Optional[bool] = False,
max_position_embeddings: Optional[int] = 2048,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
bos_token_id: Optional[int] = 11,
eos_token_id: Optional[int] = 11,
ffn_hidden_size: Optional[int] = None,

View File

@ -164,7 +164,7 @@ class FalconH1Config(PreTrainedConfig):
mamba_norm_before_gate: Optional[bool] = True,
mamba_rms_norm: Optional[bool] = False,
projectors_bias: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
lm_head_multiplier: Optional[float] = 1.0,
embedding_multiplier: Optional[float] = 1.0,
mlp_multipliers: Optional[int] = None,

View File

@ -1196,20 +1196,19 @@ class FalconH1PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Module):
for name, param in module.named_parameters(recurse=True):
if not param.requires_grad:
continue
if "layernorm" in name.lower() and "weight" in name:
# LayerNorm weights usually initialized to 1
param.data.fill_(1.0)
elif "bias" in name:
param.data.zero_()
else:
try:
param.data.normal_(mean=0.0, std=std)
except Exception as e:
print(f"Skipping init for {name} due to error: {e}")
for name, param in module.named_parameters(recurse=True):
if not param.requires_grad:
continue
if "layernorm" in name.lower() and "weight" in name:
# LayerNorm weights usually initialized to 1
param.data.fill_(1.0)
elif "bias" in name:
param.data.zero_()
else:
try:
param.data.normal_(mean=0.0, std=std)
except Exception as e:
print(f"Skipping init for {name} due to error: {e}")
def compute_mup_vector(config):

View File

@ -922,20 +922,19 @@ class FalconH1PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Module):
for name, param in module.named_parameters(recurse=True):
if not param.requires_grad:
continue
if "layernorm" in name.lower() and "weight" in name:
# LayerNorm weights usually initialized to 1
param.data.fill_(1.0)
elif "bias" in name:
param.data.zero_()
else:
try:
param.data.normal_(mean=0.0, std=std)
except Exception as e:
print(f"Skipping init for {name} due to error: {e}")
for name, param in module.named_parameters(recurse=True):
if not param.requires_grad:
continue
if "layernorm" in name.lower() and "weight" in name:
# LayerNorm weights usually initialized to 1
param.data.fill_(1.0)
elif "bias" in name:
param.data.zero_()
else:
try:
param.data.normal_(mean=0.0, std=std)
except Exception as e:
print(f"Skipping init for {name} due to error: {e}")
def compute_mup_vector(config):

View File

@ -109,7 +109,6 @@ class FlexOlmoConfig(PreTrainedConfig):
model_type = "flex_olmo"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_local_experts": "num_experts"}
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
@ -142,7 +141,7 @@ class FlexOlmoConfig(PreTrainedConfig):
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = 100257,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
num_experts_per_tok: Optional[int] = 5,

View File

@ -23,7 +23,6 @@ from collections.abc import Callable
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...activations import ACT2FN
@ -292,77 +291,64 @@ class FlexOlmoAttention(nn.Module):
return attn_output, attn_weights
class FlexOlmoExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class FlexOlmoExperts(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config: FlexOlmoConfig):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
class FlexOlmoTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
for _ in range(config.num_experts):
self.append(FlexOlmoMLP(config))
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
class FlexOlmoSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = FlexOlmoTopKRouter(config)
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
self.experts = FlexOlmoExperts(config)
def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
top_k_weights = top_k_weights.to(hidden_states.dtype)
return top_k_index, top_k_weights
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
top_k_weights, top_k_index = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
batch_size, sequence_length, hidden_dim
)
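
End to end, the sparse block is gate logits, then top-k routing, then expert dispatch, then a reshape back to (batch, seq, hidden). A compact mock-up of that flow with invented modules and sizes; a plain double loop stands in for the expert-mask gather used in the diff:

    # Mock of the sparse-MoE forward flow (invented modules/shapes; illustration only).
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    batch, seq, hidden, num_experts, top_k = 2, 3, 16, 4, 2
    gate = nn.Linear(hidden, num_experts, bias=False)
    experts = nn.ModuleList(
        nn.Sequential(nn.Linear(hidden, 32), nn.SiLU(), nn.Linear(32, hidden))
        for _ in range(num_experts)
    )

    x = torch.randn(batch, seq, hidden)
    flat = x.view(-1, hidden)                                  # (batch*seq, hidden)
    weights = F.softmax(gate(flat).float(), dim=-1)
    top_w, top_i = torch.topk(weights, top_k, dim=-1)
    top_w = (top_w / top_w.sum(dim=-1, keepdim=True)).to(flat.dtype)

    out = torch.zeros_like(flat)
    for k in range(top_k):
        for e in range(num_experts):
            sel = top_i[:, k] == e
            if sel.any():
                out[sel] += top_w[sel, k, None] * experts[e](flat[sel])
    print(out.view(batch, seq, hidden).shape)                  # torch.Size([2, 3, 16])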

View File

@ -152,7 +152,7 @@ class FlexOlmoConfig(OlmoeConfig):
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = 100257,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
num_experts_per_tok: Optional[int] = 5,
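
The annotation change repeated across these config signatures, dict[RopeParameters] to dict[str, RopeParameters], matters because dict takes two type parameters, key and value; the single-argument form leaves the key type unspecified and is flagged by static type checkers. A sketch with a stand-in RopeParameters definition (not the library's real one) and an invented key name:

    # Stand-in types to show why dict[str, RopeParameters] is the well-formed annotation.
    from typing import Optional, TypedDict

    class RopeParameters(TypedDict, total=False):   # stand-in, not the transformers definition
        rope_type: str
        rope_theta: float

    # Either a single RopeParameters dict, or a mapping from (assumed) layer-type names to one each.
    rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = {
        "default": {"rope_type": "default", "rope_theta": 10000.0},
    }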

View File

@ -118,7 +118,7 @@ class FuyuConfig(PreTrainedConfig):
layer_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
qk_layernorm: Optional[bool] = True,
hidden_dropout: Optional[float] = 0.0,
attention_dropout: Optional[float] = 0.0,

View File

@ -131,7 +131,7 @@ class GemmaConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
bos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
use_bidirectional_attention: Optional[bool] = None,

View File

@ -158,7 +158,7 @@ class GemmaConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
bos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
use_bidirectional_attention: Optional[bool] = None,

View File

@ -142,7 +142,7 @@ class Gemma2Config(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
bos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
query_pre_attn_scalar: Optional[int] = 256,

View File

@ -171,7 +171,7 @@ class Gemma2Config(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
bos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
query_pre_attn_scalar: Optional[int] = 256,

View File

@ -177,7 +177,7 @@ class Gemma3nTextConfig(PreTrainedConfig):
pad_token_id: int = 0,
eos_token_id: int = 1,
bos_token_id: int = 2,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: bool = False,
attention_dropout: float = 0.0,
sliding_window: int = 512,

View File

@ -187,7 +187,7 @@ class Gemma3nTextConfig(Gemma2Config, PreTrainedConfig):
pad_token_id: int = 0,
eos_token_id: int = 1,
bos_token_id: int = 2,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: bool = False,
attention_dropout: float = 0.0,
sliding_window: int = 512,

View File

@ -121,7 +121,7 @@ class GlmConfig(PreTrainedConfig):
rms_norm_eps: Optional[float] = 0.00000015625,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
pad_token_id: Optional[int] = 151329,
eos_token_id: Optional[list[int]] = [151329, 151336, 151338],
bos_token_id: Optional[int] = None,

View File

@ -122,7 +122,7 @@ class Glm4Config(PreTrainedConfig):
rms_norm_eps: Optional[float] = 0.00000015625,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
pad_token_id: Optional[int] = 151329,
eos_token_id: Optional[list[int]] = [151329, 151336, 151338],
bos_token_id: Optional[int] = None,

View File

@ -152,7 +152,7 @@ class Glm4MoeConfig(PreTrainedConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
moe_intermediate_size: Optional[int] = 1408,

View File

@ -330,43 +330,37 @@ class Glm4MoeRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Glm4MoeNaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Glm4MoeNaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
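
The two Glm4MoeNaiveMoe variants in this hunk hold the same expert weights in different containers: one stacked 3D tensor indexed by expert id versus one Linear per expert in an nn.ModuleList. A tiny equivalence check with invented sizes:

    # Same expert weights, two layouts: stacked 3D parameter vs. per-expert nn.Linear.
    import torch
    import torch.nn as nn

    num_experts, hidden, inter = 3, 8, 16
    stacked = nn.Parameter(torch.randn(num_experts, inter, hidden))       # 3D layout
    per_expert = nn.ModuleList(nn.Linear(hidden, inter, bias=False) for _ in range(num_experts))
    for e, lin in enumerate(per_expert):
        lin.weight.data.copy_(stacked[e])                                 # copy the same numbers across

    x = torch.randn(4, hidden)
    e = 1
    print(torch.allclose(nn.functional.linear(x, stacked[e]), per_expert[e](x)))   # True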

View File

@ -166,7 +166,7 @@ class Glm4MoeConfig(PreTrainedConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
moe_intermediate_size: Optional[int] = 1408,

View File

@ -220,7 +220,7 @@ class Glm4vTextConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
image_token_id: Optional[int] = None,
video_token_id: Optional[int] = None,
**kwargs,

View File

@ -257,7 +257,7 @@ class Glm4vTextConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
image_token_id: Optional[int] = None,
video_token_id: Optional[int] = None,
**kwargs,

View File

@ -242,7 +242,7 @@ class Glm4vMoeTextConfig(PreTrainedConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = True,
attention_dropout: Optional[float] = 0.0,
moe_intermediate_size: Optional[int] = 1408,

View File

@ -351,43 +351,37 @@ class Glm4vMoeTextTopkRouter(nn.Module):
return router_logits
class Glm4vMoeTextNaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Glm4vMoeTextNaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size))
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
top_k_index: (batch_size * sequence_length, top_k)
top_k_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
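
In both variants, each expert's weighted output is merged back into the flat token buffer with index_add_, which sums contributions when an index repeats. A minimal illustration with made-up numbers:

    # index_add_ scatter-add: duplicate indices accumulate (invented numbers).
    import torch

    final = torch.zeros(5, 4)                    # (num_tokens, hidden)
    token_idx = torch.tensor([0, 2, 2])          # rows this expert contributes to (2 repeats)
    expert_out = torch.ones(3, 4)                # weighted outputs for those rows
    final.index_add_(0, token_idx, expert_out)
    print(final[:, 0])                           # tensor([1., 0., 2., 0., 0.])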

View File

@ -183,7 +183,7 @@ class Glm4vMoeTextConfig(Glm4MoeConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = True,
attention_dropout: Optional[float] = 0.0,
moe_intermediate_size: Optional[int] = 1408,

View File

@ -503,11 +503,10 @@ class GPT2PreTrainedModel(PreTrainedModel):
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
if isinstance(module, PreTrainedModel):
for name, p in module.named_parameters():
if name == "c_proj.weight":
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
for name, p in module.named_parameters():
if name == "c_proj.weight":
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
@dataclass
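
This GPT-2 hunk keeps the "Special Scaled Initialization" but drops the isinstance(module, PreTrainedModel) guard around it. The scaling divides the usual std by sqrt(2 * n_layer) because every block adds two residual branches. A quick numeric sketch with invented config values:

    # The scaled-init formula from the hunk above, with invented config values.
    import math
    import torch

    initializer_range, n_layer = 0.02, 12
    std = initializer_range / math.sqrt(2 * n_layer)   # two residual additions per block

    w = torch.empty(768, 768).normal_(mean=0.0, std=std)
    print(round(std, 4), round(w.std().item(), 4))     # both ~= 0.0041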

View File

@ -131,7 +131,7 @@ class GPTNeoXConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
use_parallel_residual: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = True,
**kwargs,
):

View File

@ -100,7 +100,7 @@ class GPTNeoXJapaneseConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
bos_token_id: Optional[int] = 31996,
eos_token_id: Optional[int] = 31999,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_dropout: Optional[float] = 0.1,
hidden_dropout: Optional[float] = 0.0,
**kwargs,

View File

@ -71,10 +71,10 @@ class GptOssExperts(nn.Module):
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size))
self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
self.alpha = 1.702
self.limit = 7.0
@ -146,8 +146,8 @@ class GptOssTopKRouter(nn.Module):
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
self.bias = nn.Parameter(torch.zeros(self.num_experts))
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
self.bias = nn.Parameter(torch.empty(self.num_experts))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)

View File

@ -69,10 +69,10 @@ class GptOssExperts(nn.Module):
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size))
self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
self.alpha = 1.702
self.limit = 7.0
@ -144,8 +144,8 @@ class GptOssTopKRouter(nn.Module):
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
self.bias = nn.Parameter(torch.zeros(self.num_experts))
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
self.bias = nn.Parameter(torch.empty(self.num_experts))
def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
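
The zeros-to-empty switch in the GptOss hunks changes only how the parameter storage is allocated: torch.empty skips the zero fill and leaves the memory undefined, which is fine as long as something, presumably the weight initialisation or a checkpoint load, writes real values before use. A small illustration:

    # torch.zeros vs torch.empty for parameters that are initialised later (illustrative).
    import torch
    import torch.nn as nn

    zeroed = nn.Parameter(torch.zeros(4, 8))     # deterministic zeros written at allocation
    uninit = nn.Parameter(torch.empty(4, 8))     # arbitrary memory contents until initialised

    print(zeroed.detach().abs().sum().item())    # 0.0
    uninit.data.normal_(mean=0.0, std=0.02)      # must be filled explicitly before use
    print(round(uninit.detach().std().item(), 3))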

View File

@ -141,7 +141,7 @@ class GraniteConfig(PreTrainedConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,

Some files were not shown because too many files have changed in this diff.