Mirror of https://github.com/huggingface/transformers.git (synced 2025-11-12 17:30:55 +08:00)

Compare commits: check-refa ... cb-prefix- (26 commits)

| SHA1 |
|---|
| a9c0f9a62b |
| 041690ffd0 |
| c85ee385b5 |
| 25ac7a0c2a |
| f6e4cc6230 |
| 3824e5d201 |
| efde9d5427 |
| 53f953cec4 |
| dd4e048e75 |
| 6ff4fabd9d |
| 6d4450e341 |
| aee5c2384a |
| 5b6c209bc5 |
| 258c76e4dc |
| 64397a8301 |
| cd309610c0 |
| dd8f231495 |
| 1619a3475f |
| ff0f7d6498 |
| 80305364e2 |
| a623cda427 |
| 7d5160bd7a |
| 22e39dfb31 |
| 63fbd50fb4 |
| b433ec8b50 |
| 3c16c1ae43 |
.github/workflows/benchmark.yml (vendored, 2 changed lines)
@ -52,7 +52,7 @@ jobs:
|
||||
commit_id=$GITHUB_SHA
|
||||
fi
|
||||
commit_msg=$(git show -s --format=%s | cut -c1-70)
|
||||
python3 benchmark_v2/run_benchmarks.py -b 32 -s 128 -n 256 --branch-name "$BRANCH_NAME" --commit-id "$commit_id" --commit-message "$commit_msg" --model-id "$MODEL_ID" --log-level INFO --push-result-to-dataset "$DATASET_ID"
|
||||
python3 benchmark_v2/run_benchmarks.py -b 32 -s 128 -n 256 --level 2 --branch-name "$BRANCH_NAME" --commit-id "$commit_id" --commit-message "$commit_msg" --model-id "$MODEL_ID" --log-level INFO --push-result-to-dataset "$DATASET_ID"
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
|
||||
PUSH_TO_HUB_TOKEN: ${{ secrets.PUSH_TO_HUB_TOKEN }}
|
||||
|
||||
.github/workflows/build-docker-images.yml (vendored, 2 changed lines)
@ -97,7 +97,7 @@ jobs:
|
||||
latest-torch-deepspeed-docker:
|
||||
name: "Latest PyTorch + DeepSpeed"
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge-cache
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
-
|
||||
name: Set up Docker Buildx
|
||||
|
||||
@ -21,7 +21,7 @@ jobs:
|
||||
job: run_models_gpu
|
||||
slack_report_channel: "#amd-hf-ci"
|
||||
runner_group: hfc-amd-mi355
|
||||
docker: huggingface/testing-rocm7.0-preview
|
||||
docker: huggingface/transformers-pytorch-amd-gpu
|
||||
ci_event: Scheduled CI (AMD) - mi355
|
||||
report_repo_id: hf-transformers-bot/transformers-ci-dummy
|
||||
secrets: inherit
|
||||
@ -33,7 +33,7 @@ jobs:
|
||||
job: run_pipelines_torch_gpu
|
||||
slack_report_channel: "#amd-hf-ci"
|
||||
runner_group: hfc-amd-mi355
|
||||
docker: huggingface/testing-rocm7.0-preview
|
||||
docker: huggingface/transformers-pytorch-amd-gpu
|
||||
ci_event: Scheduled CI (AMD) - mi355
|
||||
report_repo_id: hf-transformers-bot/transformers-ci-dummy
|
||||
secrets: inherit
|
||||
@ -45,7 +45,7 @@ jobs:
|
||||
job: run_examples_gpu
|
||||
slack_report_channel: "#amd-hf-ci"
|
||||
runner_group: hfc-amd-mi355
|
||||
docker: huggingface/testing-rocm7.0-preview
|
||||
docker: huggingface/transformers-pytorch-amd-gpu
|
||||
ci_event: Scheduled CI (AMD) - mi355
|
||||
report_repo_id: hf-transformers-bot/transformers-ci-dummy
|
||||
secrets: inherit
|
||||
|
||||
.github/workflows/self-scheduled-caller.yml (vendored, 12 changed lines)
@ -118,3 +118,15 @@ jobs:
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
|
||||
kernels-ci:
|
||||
name: Kernels CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
with:
|
||||
job: run_kernels_gpu
|
||||
slack_report_channel: "#transformers-ci-daily-kernels"
|
||||
docker: huggingface/transformers-all-latest-gpu
|
||||
ci_event: Daily CI
|
||||
report_repo_id: hf-internal-testing/transformers_daily_ci
|
||||
commit_sha: ${{ github.sha }}
|
||||
secrets: inherit
|
||||
.github/workflows/self-scheduled.yml (vendored, 75 changed lines)
@ -102,8 +102,10 @@ jobs:
|
||||
working-directory: /transformers/tests
|
||||
run: |
|
||||
if [ "${{ inputs.job }}" = "run_models_gpu" ]; then
|
||||
echo "folder_slices=$(python3 ../utils/split_model_tests.py --subdirs '${{ inputs.subdirs }}' --num_splits ${{ env.NUM_SLICES }})" >> $GITHUB_OUTPUT
|
||||
echo "slice_ids=$(python3 -c 'd = list(range(${{ env.NUM_SLICES }})); print(d)')" >> $GITHUB_OUTPUT
|
||||
python3 ../utils/split_model_tests.py --subdirs '${{ inputs.subdirs }}' --num_splits ${{ env.NUM_SLICES }} > folder_slices.txt
|
||||
echo "folder_slices=$(cat folder_slices.txt)" >> $GITHUB_OUTPUT
|
||||
python3 -c "import ast; folder_slices = ast.literal_eval(open('folder_slices.txt').read()); open('slice_ids.txt', 'w').write(str(list(range(len(folder_slices)))))"
|
||||
echo "slice_ids=$(cat slice_ids.txt)" >> $GITHUB_OUTPUT
|
||||
elif [ "${{ inputs.job }}" = "run_trainer_and_fsdp_gpu" ]; then
|
||||
echo "folder_slices=[['trainer'], ['fsdp']]" >> $GITHUB_OUTPUT
|
||||
echo "slice_ids=[0, 1]" >> $GITHUB_OUTPUT
|
||||
@ -336,7 +338,7 @@ jobs:
|
||||
working-directory: ${{ inputs.working-directory-prefix }}/
|
||||
run: |
|
||||
python3 -m pip uninstall -y deepspeed
|
||||
DS_DISABLE_NINJA=1 DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check
|
||||
DS_DISABLE_NINJA=1 DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --no-build-isolation --config-settings="--build-option=build_ext" --config-settings="--build-option=-j8" --no-cache -v --disable-pip-version-check
|
||||
|
||||
# To avoid unknown test failures
|
||||
- name: Pre build DeepSpeed *again* (for nightly & Past CI)
|
||||
@ -346,7 +348,7 @@ jobs:
|
||||
python3 -m pip uninstall -y deepspeed
|
||||
rm -rf DeepSpeed
|
||||
git clone https://github.com/deepspeedai/DeepSpeed && cd DeepSpeed && rm -rf build
|
||||
DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check
|
||||
DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install . --no-build-isolation --config-settings="--build-option=build_ext" --config-settings="--build-option=-j8" --no-cache -v --disable-pip-version-check
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
@ -475,6 +477,70 @@ jobs:
|
||||
name: ${{ env.machine_type }}_run_quantization_torch_gpu_${{ env.matrix_folders }}_test_reports
|
||||
path: /transformers/reports/${{ env.machine_type }}_run_quantization_torch_gpu_${{ matrix.folders }}_test_reports
|
||||
|
||||
run_kernels_gpu:
|
||||
if: ${{ inputs.job == 'run_kernels_gpu' }}
|
||||
name: Kernel tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
image: ${{ inputs.docker }}
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
steps:
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
run: git fetch && git checkout ${{ inputs.commit_sha || github.sha }}
|
||||
|
||||
- name: Reinstall transformers in edit mode
|
||||
working-directory: /transformers
|
||||
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .[testing]
|
||||
|
||||
- name: Install kernels
|
||||
working-directory: /transformers
|
||||
run: python3 -m pip install -U kernels
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: nvidia-smi
|
||||
|
||||
- name: Environment
|
||||
working-directory: /transformers
|
||||
run: python3 utils/print_env.py
|
||||
|
||||
- name: Show installed libraries and their versions
|
||||
working-directory: /transformers
|
||||
run: pip freeze
|
||||
|
||||
- name: Set `machine_type` for report and artifact names
|
||||
working-directory: /transformers
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
fi
|
||||
echo "machine_type=$machine_type" >> $GITHUB_ENV
|
||||
|
||||
- name: Run kernel tests on GPU
|
||||
working-directory: /transformers
|
||||
run: |
|
||||
python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_kernels_gpu_test_reports tests/kernels/test_kernels.py
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
continue-on-error: true
|
||||
run: cat /transformers/reports/${{ env.machine_type }}_run_kernels_gpu_test_reports/failures_short.txt
|
||||
|
||||
- name: "Test suite reports artifacts: ${{ env.machine_type }}_run_kernels_gpu_test_reports"
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ env.machine_type }}_run_kernels_gpu_test_reports
|
||||
path: /transformers/reports/${{ env.machine_type }}_run_kernels_gpu_test_reports
|
||||
|
||||
run_extract_warnings:
|
||||
# Let's only do this for the job `run_models_gpu` to simplify the (already complex) logic.
|
||||
if: ${{ always() && inputs.job == 'run_models_gpu' }}
|
||||
@ -527,6 +593,7 @@ jobs:
|
||||
run_examples_gpu,
|
||||
run_torch_cuda_extensions_gpu,
|
||||
run_quantization_torch_gpu,
|
||||
run_kernels_gpu,
|
||||
run_extract_warnings
|
||||
]
|
||||
if: always() && !cancelled()
|
||||
|
||||
.github/workflows/ssh-runner.yml (vendored, 16 changed lines)
@ -4,7 +4,7 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
runner_type:
|
||||
description: 'Type of runner to test (a10 or t4)'
|
||||
description: 'Type of runner to test (a10)'
|
||||
required: true
|
||||
docker_image:
|
||||
description: 'Name of the Docker image'
|
||||
@ -36,14 +36,10 @@ jobs:
|
||||
NUM_GPUS: ${{ github.event.inputs.num_gpus }}
|
||||
RUNNER_TYPE: ${{ github.event.inputs.runner_type }}
|
||||
run: |
|
||||
if [[ "$NUM_GPUS" == "single" && "$RUNNER_TYPE" == "t4" ]]; then
|
||||
echo "RUNNER=aws-g4dn-4xlarge-cache" >> $GITHUB_ENV
|
||||
elif [[ "$NUM_GPUS" == "multi" && "$RUNNER_TYPE" == "t4" ]]; then
|
||||
echo "RUNNER=aws-g4dn-12xlarge-cache" >> $GITHUB_ENV
|
||||
elif [[ "$NUM_GPUS" == "single" && "$RUNNER_TYPE" == "a10" ]]; then
|
||||
echo "RUNNER=aws-g5-4xlarge-cache" >> $GITHUB_ENV
|
||||
if [[ "$NUM_GPUS" == "single" && "$RUNNER_TYPE" == "a10" ]]; then
|
||||
echo "RUNNER=aws-g5-4xlarge-cache-ssh" >> $GITHUB_ENV
|
||||
elif [[ "$NUM_GPUS" == "multi" && "$RUNNER_TYPE" == "a10" ]]; then
|
||||
echo "RUNNER=aws-g5-12xlarge-cache" >> $GITHUB_ENV
|
||||
echo "RUNNER=aws-g5-12xlarge-cache-ssh" >> $GITHUB_ENV
|
||||
else
|
||||
echo "RUNNER=" >> $GITHUB_ENV
|
||||
fi
|
||||
@ -61,8 +57,6 @@ jobs:
|
||||
group: ${{ needs.get_runner.outputs.RUNNER }}
|
||||
container:
|
||||
image: ${{ github.event.inputs.docker_image }}
|
||||
options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
|
||||
steps:
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
@ -106,7 +100,7 @@ jobs:
|
||||
else
|
||||
echo "SLACKCHANNEL=${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
|
||||
- name: Tailscale # In order to be able to SSH when a test fails
|
||||
uses: huggingface/tailscale-action@main
|
||||
with:
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
|
||||
KERNELIZATION_AVAILABLE = False
|
||||
try:
|
||||
@ -18,6 +21,16 @@ logger = logging.getLogger(__name__)
|
||||
class BenchmarkConfig:
|
||||
"""Configuration for a single benchmark scenario."""
|
||||
|
||||
all_attn_implementations = [
|
||||
("flash_attention_2", None),
|
||||
("eager", None),
|
||||
("sdpa", "math"),
|
||||
("sdpa", "flash_attention"),
|
||||
("flex_attention", None),
|
||||
]
|
||||
|
||||
all_compiled_modes = [None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warmup_iterations: int = 5,
|
||||
@ -59,6 +72,13 @@ class BenchmarkConfig:
|
||||
def check_validity(self, skip_validity_check: bool = False) -> None:
|
||||
if skip_validity_check:
|
||||
return
|
||||
# Check FA is installed
|
||||
if self.attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
|
||||
logger.warning(
|
||||
"Flash attention does not support compile mode. Defaulting to SDPA w/ flash attention backend."
|
||||
)
|
||||
self.attn_implementation = "sdpa"
|
||||
self.sdpa_backend = "flash_attention"
|
||||
# Flash attention does not support compile mode, so we turn it off # FIXME: it would be better to support it
|
||||
is_fa = self.attn_implementation == "flash_attention_2"
|
||||
is_fa |= self.attn_implementation == "sdpa" and self.sdpa_backend == "flash_attention"
|
||||
@ -127,88 +147,68 @@ class BenchmarkConfig:
|
||||
)
|
||||
|
||||
|
||||
def cross_generate_configs(
|
||||
attn_impl_and_sdpa_backend: list[tuple[str, str | None]],
|
||||
compiled_mode: list[str | None],
|
||||
kernelized: list[bool],
|
||||
warmup_iterations: int = 5,
|
||||
measurement_iterations: int = 20,
|
||||
batch_size: int = 1,
|
||||
sequence_length: int = 128,
|
||||
num_tokens_to_generate: int = 128,
|
||||
gpu_monitoring: bool = True,
|
||||
def adapt_configs(
|
||||
configs: list[BenchmarkConfig],
|
||||
warmup_iterations: int | list[int] = 5,
|
||||
measurement_iterations: int | list[int] = 20,
|
||||
batch_size: int | list[int] = 1,
|
||||
sequence_length: int | list[int] = 128,
|
||||
num_tokens_to_generate: int | list[int] = 128,
|
||||
gpu_monitoring: bool | list[bool] = True,
|
||||
) -> list[BenchmarkConfig]:
|
||||
# Create kwargs common to all configs
|
||||
kwargs = {
|
||||
"warmup_iterations": warmup_iterations,
|
||||
"measurement_iterations": measurement_iterations,
|
||||
"batch_size": batch_size,
|
||||
"sequence_length": sequence_length,
|
||||
"num_tokens_to_generate": num_tokens_to_generate,
|
||||
"gpu_monitoring": gpu_monitoring,
|
||||
}
|
||||
# Cross-generate all combinations of attn_implementation, compiled_mode, and kernelized
|
||||
configs = []
|
||||
for attn_implementation, sdpa_backend in list(dict.fromkeys(attn_impl_and_sdpa_backend)):
|
||||
for cm in list(dict.fromkeys(compiled_mode)):
|
||||
for kernelize_on in list(dict.fromkeys(kernelized)):
|
||||
config = BenchmarkConfig(
|
||||
attn_implementation=attn_implementation,
|
||||
sdpa_backend=sdpa_backend,
|
||||
compile_mode=cm,
|
||||
kernelize=kernelize_on,
|
||||
**kwargs,
|
||||
)
|
||||
configs.append(config)
|
||||
return configs
|
||||
|
||||
|
||||
def generate_all_configs(
|
||||
warmup_iterations: int = 5,
|
||||
measurement_iterations: int = 20,
|
||||
batch_size: int = 1,
|
||||
sequence_length: int = 128,
|
||||
num_tokens_to_generate: int = 128,
|
||||
gpu_monitoring: bool = True,
|
||||
) -> list[BenchmarkConfig]:
|
||||
all_attn_implementations = [
|
||||
("flash_attention_2", None),
|
||||
("eager", None),
|
||||
("sdpa", "math"),
|
||||
("sdpa", "flash_attention"),
|
||||
("flex_attention", None),
|
||||
]
|
||||
return cross_generate_configs(
|
||||
attn_impl_and_sdpa_backend=all_attn_implementations,
|
||||
compiled_mode=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
|
||||
kernelized=[False, KERNELIZATION_AVAILABLE],
|
||||
warmup_iterations=warmup_iterations,
|
||||
measurement_iterations=measurement_iterations,
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
num_tokens_to_generate=num_tokens_to_generate,
|
||||
gpu_monitoring=gpu_monitoring,
|
||||
parameters = (
|
||||
x if isinstance(x, list) else [x]
|
||||
for x in [
|
||||
warmup_iterations,
|
||||
measurement_iterations,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_tokens_to_generate,
|
||||
gpu_monitoring,
|
||||
]
|
||||
)
|
||||
iterator = itertools.product(*parameters)
|
||||
|
||||
adapted_configs = []
|
||||
for warmup_iters, measurement_iters, bs, seqlen, ntok, monitor in iterator:
|
||||
for config in configs:
|
||||
config = config.to_dict()
|
||||
config["warmup_iterations"] = warmup_iters
|
||||
config["measurement_iterations"] = measurement_iters
|
||||
config["batch_size"] = bs
|
||||
config["sequence_length"] = seqlen
|
||||
config["num_tokens_to_generate"] = ntok
|
||||
config["gpu_monitoring"] = monitor
|
||||
adapted_configs.append(BenchmarkConfig.from_dict(config))
|
||||
return adapted_configs
|
||||
|
||||
|
||||
def generate_main_configs(
|
||||
warmup_iterations: int = 5,
|
||||
measurement_iterations: int = 20,
|
||||
batch_size: int = 1,
|
||||
sequence_length: int = 128,
|
||||
num_tokens_to_generate: int = 128,
|
||||
) -> list[BenchmarkConfig]:
|
||||
# Create kwargs common to all configs
|
||||
kwargs = {
|
||||
"warmup_iterations": warmup_iterations,
|
||||
"measurement_iterations": measurement_iterations,
|
||||
"batch_size": batch_size,
|
||||
"sequence_length": sequence_length,
|
||||
"num_tokens_to_generate": num_tokens_to_generate,
|
||||
}
|
||||
return [ # TODO: test max-autotune instead of default
|
||||
BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", gpu_monitoring=False, **kwargs),
|
||||
BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", gpu_monitoring=True, **kwargs),
|
||||
BenchmarkConfig(attn_implementation="eager", compile_mode="default", gpu_monitoring=True, **kwargs),
|
||||
BenchmarkConfig(attn_implementation="flash_attention_2", gpu_monitoring=True, **kwargs),
|
||||
]
|
||||
def get_config_by_level(level: int) -> list[BenchmarkConfig]:
|
||||
configs = []
|
||||
# Early return if level is 3 or above: we generate all combinations of configs, maybe even w/ all compile modes
|
||||
if level >= 3:
|
||||
for attn_implementation, sdpa_backend in BenchmarkConfig.all_attn_implementations:
|
||||
# Usually there is not much to gain by compiling with other modes, but we allow it for level 4
|
||||
compile_modes = BenchmarkConfig.all_compiled_modes if level >= 4 else [None, "default"]
|
||||
for cm in compile_modes:
|
||||
for kernelize_on in [False, KERNELIZATION_AVAILABLE]:
|
||||
configs.append(
|
||||
BenchmarkConfig(
|
||||
attn_implementation=attn_implementation,
|
||||
sdpa_backend=sdpa_backend,
|
||||
compile_mode=cm,
|
||||
kernelize=kernelize_on,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
# Otherwise, we add the configs for the given level
|
||||
if level >= 0:
|
||||
configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default"))
|
||||
if level >= 1:
|
||||
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2"))
|
||||
configs.append(BenchmarkConfig(attn_implementation="eager", compile_mode="default"))
|
||||
if level >= 2:
|
||||
configs.append(BenchmarkConfig(attn_implementation="sdpa", compile_mode="default"))
|
||||
configs.append(BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", kernelize=True))
|
||||
configs.append(BenchmarkConfig(attn_implementation="flash_attention_2", kernelize=True))
|
||||
return configs
|
||||
|
||||
@ -23,7 +23,7 @@ import logging
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from framework.benchmark_config import BenchmarkConfig, generate_all_configs, generate_main_configs
|
||||
from framework.benchmark_config import adapt_configs, get_config_by_level
|
||||
from framework.benchmark_runner import BenchmarkRunner
|
||||
|
||||
|
||||
@ -40,7 +40,14 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--sequence-length", "-s", type=int, nargs="+", help="Sequence length")
|
||||
parser.add_argument("--num-tokens-to-generate", "-n", type=int, nargs="+", help="Number of tokens to generate")
|
||||
|
||||
parser.add_argument("--cross-generate", action="store_true", help="Cross-generate all combinations of configs")
|
||||
parser.add_argument(
|
||||
"--level",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Level of coverage for the benchmark. 0: only the main config, 1: a few important configs, 2: a config for"
|
||||
" each attn implementation an option, 3: cross-generate all combinations of configs, 4: cross-generate all"
|
||||
" combinations of configs w/ all compile modes",
|
||||
)
|
||||
parser.add_argument("--num-tokens-to-profile", "-p", type=int, default=0, help="Number of tokens to profile")
|
||||
|
||||
parser.add_argument("--branch-name", type=str, help="Git branch name")
|
||||
@ -79,64 +86,24 @@ if __name__ == "__main__":
|
||||
"At least one of the arguments --batch-size, --sequence-length, or --num-tokens-to-generate is required"
|
||||
)
|
||||
|
||||
# If there is only one (batch_size, sequence_length, num_tokens_to_generate), we benchmark across configs
|
||||
elif len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 1:
|
||||
if args.cross_generate:
|
||||
benchmark_configs = generate_all_configs(
|
||||
warmup_iterations=args.warmup,
|
||||
measurement_iterations=args.iterations,
|
||||
batch_size=args.batch_size[0],
|
||||
sequence_length=args.sequence_length[0],
|
||||
num_tokens_to_generate=args.num_tokens_to_generate[0],
|
||||
gpu_monitoring=not args.no_gpu_monitoring,
|
||||
)
|
||||
else:
|
||||
benchmark_configs = generate_main_configs(
|
||||
warmup_iterations=args.warmup,
|
||||
measurement_iterations=args.iterations,
|
||||
batch_size=args.batch_size[0],
|
||||
sequence_length=args.sequence_length[0],
|
||||
num_tokens_to_generate=args.num_tokens_to_generate[0],
|
||||
)
|
||||
|
||||
# Otherwise, we benchmark across all combinations of dimensions
|
||||
else:
|
||||
main_config = generate_main_configs(
|
||||
warmup_iterations=args.warmup,
|
||||
measurement_iterations=args.iterations,
|
||||
batch_size=args.batch_size[0],
|
||||
sequence_length=args.sequence_length[0],
|
||||
num_tokens_to_generate=args.num_tokens_to_generate[0],
|
||||
)[0]
|
||||
benchmark_configs = []
|
||||
for num_tokens_to_generate in args.num_tokens_to_generate:
|
||||
for sequence_length in args.sequence_length:
|
||||
for batch_size in args.batch_size:
|
||||
cfg_dict = main_config.to_dict()
|
||||
cfg_dict["batch_size"] = batch_size
|
||||
cfg_dict["sequence_length"] = sequence_length
|
||||
cfg_dict["num_tokens_to_generate"] = num_tokens_to_generate
|
||||
cfg_dict.pop("name")
|
||||
benchmark_configs.append(BenchmarkConfig.from_dict(cfg_dict))
|
||||
|
||||
runner = BenchmarkRunner(
|
||||
logger,
|
||||
args.output_dir,
|
||||
args.branch_name,
|
||||
args.commit_id,
|
||||
args.commit_message,
|
||||
# Get the configs for the given coverage level
|
||||
configs = get_config_by_level(args.level)
|
||||
# Adapt the configs to the given arguments
|
||||
configs = adapt_configs(
|
||||
configs,
|
||||
args.warmup,
|
||||
args.iterations,
|
||||
args.batch_size,
|
||||
args.sequence_length,
|
||||
args.num_tokens_to_generate,
|
||||
not args.no_gpu_monitoring,
|
||||
)
|
||||
|
||||
runner = BenchmarkRunner(logger, args.output_dir, args.branch_name, args.commit_id, args.commit_message)
|
||||
timestamp, results = runner.run_benchmarks(
|
||||
args.model_id,
|
||||
benchmark_configs,
|
||||
args.num_tokens_to_profile,
|
||||
pretty_print_summary=True,
|
||||
args.model_id, configs, args.num_tokens_to_profile, pretty_print_summary=True
|
||||
)
|
||||
|
||||
dataset_id = args.push_result_to_dataset
|
||||
if dataset_id is not None and len(results) > 0:
|
||||
runner.push_results_to_hub(
|
||||
dataset_id,
|
||||
results,
|
||||
timestamp,
|
||||
)
|
||||
runner.push_results_to_hub(dataset_id, results, timestamp)
|
||||
|
||||
@ -39,7 +39,7 @@ RUN python3 -m pip install --no-cache-dir "torchcodec==0.5"
|
||||
# Install flash attention from source. Tested with commit 6387433156558135a998d5568a9d74c1778666d8
|
||||
RUN git clone https://github.com/ROCm/flash-attention/ -b tridao && \
|
||||
cd flash-attention && \
|
||||
GPU_ARCHS="gfx942;gfx950" python setup.py install
|
||||
# GPU_ARCHS builds for MI300, MI325 and MI355
|
||||
GPU_ARCHS="gfx942" python setup.py install
|
||||
# GPU_ARCHS builds for MI300, MI325 but not MI355: we would need to add `;gfx950` but it takes too long to build.
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir einops
|
||||
|
||||
@ -21,7 +21,7 @@ RUN python3 -m pip install --no-cache-dir './transformers[deepspeed-testing]' 'p
|
||||
# Install latest release PyTorch
|
||||
# (PyTorch must be installed before pre-compiling any DeepSpeed c++/cuda ops.)
|
||||
# (https://www.deepspeed.ai/tutorials/advanced-install/#pre-install-deepspeed-ops)
|
||||
RUN python3 -m pip uninstall -y torch torchvision torchaudio && python3 -m pip install --no-cache-dir -U torch==$PYTORCH torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA
|
||||
RUN python3 -m pip uninstall -y torch torchvision torchaudio torchcodec && python3 -m pip install --no-cache-dir -U torch==$PYTORCH torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate
|
||||
|
||||
@ -43,7 +43,7 @@ RUN python3 -m pip uninstall -y deepspeed
|
||||
# This has to be run (again) inside the GPU VMs running the tests.
|
||||
# The installation works here, but some tests fail if we don't pre-build deepspeed again in the VMs running the tests.
|
||||
# TODO: Find out why tests fail.
|
||||
RUN DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1
|
||||
RUN DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 python3 -m pip install deepspeed --no-build-isolation --config-settings="--build-option=build_ext" --config-settings="--build-option=-j8" --no-cache -v --disable-pip-version-check 2>&1
|
||||
|
||||
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
|
||||
RUN python3 -m pip uninstall -y kernels
|
||||
|
||||
@ -24,7 +24,7 @@ RUN [ ${#PYTORCH} -gt 0 ] && VERSION='torch=='$PYTORCH'.*' || VERSION='torch';
|
||||
RUN echo torch=$VERSION
|
||||
# `torchvision` and `torchaudio` should be installed along with `torch`, especially for nightly build.
|
||||
# Currently, let's just use their latest releases (when `torch` is installed with a release version)
|
||||
RUN python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio torchcodec --extra-index-url https://download.pytorch.org/whl/$CUDA
|
||||
RUN python3 -m pip install --no-cache-dir -U $VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
|
||||
|
||||
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate
|
||||
|
||||
|
||||
@ -119,6 +119,8 @@
|
||||
title: Tools
|
||||
- local: transformers_as_backend
|
||||
title: Inference server backends
|
||||
- local: continuous_batching
|
||||
title: Continuous Batching
|
||||
title: Inference
|
||||
- isExpanded: false
|
||||
sections:
|
||||
|
||||
docs/source/en/continuous_batching.md (new file, 194 lines)
@ -0,0 +1,194 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Continuous Batching
|
||||
|
||||
Continuous Batching (CB) is an advanced technique to optimize the inference of transformer models by dynamically grouping multiple requests into batches. This approach maximizes GPU utilization and throughput, specifically for workloads with many variable-length inputs.
|
||||
|
||||
We are particularly interested in having Continuous Batching in transformers for the following use cases:
|
||||
- Evaluation of models on large datasets with variable-length inputs
|
||||
- Generating outputs for multiple sequences for GRPO policies
|
||||
|
||||
CB is what makes inference engines like vLLM or SGLang efficient. That being said, transformers does not aim to be a production-ready inference engine, but a complete framework for model development. Nonetheless, CB is available in `transformers serve`.
|
||||
|
||||
If you are not familiar with some of the core concepts CB is built upon, we invite you to read the associated blog post: [Continuous Batching: Efficient Inference for Large Language Models](https://huggingface.co/blog/continuous-batching). _broken link for now_
|
||||
|
||||
## API Reference
|
||||
|
||||
## Usage Examples
|
||||
|
||||
The main way to use CB in transformers is via the `generate_batch` method.
|
||||
|
||||
Unlike `generate`, CB takes already tokenized inputs, known as input IDs. Each sequence of input IDs is represented as a list of integers (`list[int]` in Python).
|
||||
|
||||
For a more detailed example, please refer to: [examples/continuous_batching](./path/to/example)
|
||||
|
||||
### `generate_batch` example
|
||||
|
||||
We have created a `ContinuousMixin` that is inherited by `GenerationMixin` so that all autoregressive text models support CB.
|
||||
|
||||
This adds the `generate_batch` method to all models that inherit from `GenerationMixin`.
|
||||
|
||||
You can use it as follows:
|
||||
|
||||
```py
import datasets
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

model_id = "Qwen/Qwen3-4B-Instruct-2507"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa_paged",
    device_map="cuda",  # if you need cuda
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

# prepare a batch of inputs
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(100))  # keep a small subset for the example
tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

generation_config = GenerationConfig(
    max_new_tokens=32,
    use_cuda_graph=False,  # not supported for the simple version
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    max_batch_tokens=512,  # max number of tokens in a batch, a default value you should tune based on your hardware
)

batch_outputs = model.generate_batch(
    inputs=simple_batch_inputs,
    generation_config=generation_config,
)

for request_id, output in batch_outputs.items():
    generated_text = tokenizer.decode(output.generated_tokens, skip_special_tokens=True)
    print(f"Request {request_id} output: {generated_text}")
```
|
||||
|
||||
### `ContinuousBatchingManager` example
|
||||
|
||||
If you want more control over how requests are scheduled with CB, you can use the `ContinuousBatchingManager` class directly.
|
||||
|
||||
This is what we use in `transformers serve` because requests arrive asynchronously and we can leverage the asynchronous nature of the CB process to make things more efficient.
|
||||
|
||||
Under the hood, the `ContinuousBatchingManager` spawns a background thread that pulls requests from a Python `queue.Queue` and batches them into each forward pass.
|
||||
|
||||
Note that the manager is thread safe!
|
||||
|
||||
```py
import datasets
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.generation.continuous_batching import RequestStatus

model_id = "Qwen/Qwen3-4B-Instruct-2507"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa_paged",
    device_map="cuda",  # if you need cuda
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

# prepare a batch of inputs
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(100))  # keep a small subset for the example
tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]

# prepare a generation config (see the `generate_batch` example above for more options)
generation_config = GenerationConfig(
    max_new_tokens=32,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    max_batch_tokens=512,
)

# initialize the manager, a method made available by the `ContinuousMixin`
manager = model.init_continuous_batching(generation_config=generation_config)

# start the background thread
manager.start()

# this is for demonstration purposes only; in practice this is most useful done concurrently
for i, input_ids in enumerate(simple_batch_inputs):
    request_id = manager.add_request(input_ids=input_ids, request_id=f"request_{i}")  # if you do not specify a request_id, one will be generated for you

# can be done in another thread
for request_id, request in manager.get_result():
    generated_text = tokenizer.decode(request.generated_tokens, skip_special_tokens=True)
    print(f"Request {request_id} output: {generated_text}")

# you can also get the result for a specific request id
result = manager.get_result(request_id="request_5")  # this is blocking and will wait for the result to be ready

# or get results for a request that is streaming
manager.add_request(
    input_ids=input_ids,
    request_id="streaming_request",
    stream=True,
)
for chunk in manager.request_id_iter(request_id="streaming_request"):
    generated_text = tokenizer.decode(chunk.generated_tokens, skip_special_tokens=True)
    print(generated_text)
    # FIXME: stop iteration in `request_id_iter` when finished instead of doing it externally
    if chunk.status == RequestStatus.FINISHED:
        break

# stop the background thread before exiting the process
manager.stop()
```
|
||||
|
||||
## Supported & Unsupported Features
|
||||
|
||||
### Supported Features
|
||||
|
||||
- Dynamic scheduling of variable-length requests
|
||||
- Chunked prefill
|
||||
- Paged Attention Cache
|
||||
- Sliding window attention
|
||||
- Chat templates
|
||||
|
||||
### Unsupported Features
|
||||
|
||||
At the moment, the following features are not supported with CB. We plan to add support for the following:
|
||||
|
||||
- Prefix caching
|
||||
- Beam search
|
||||
- Tool calling
|
||||
|
||||
The following are unplanned, but we might consider adding them depending on community requests:
|
||||
|
||||
- MTP (multi token prediction)
|
||||
- Medusa
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
|
||||
## Integration with Serving
|
||||
|
||||
You can use CB in `transformers serve` by passing the `--continuous-batching` flag when starting the server.
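For example, to start the server with CB enabled:

```sh
transformers serve --continuous-batching
```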
|
||||
|
||||
## Monitoring
|
||||
|
||||
We have added `opentelemetry` support to Continuous Batching to help you monitor its performance in production. To enable it, you need to install the `open-telemetry` extra when installing `transformers`:
|
||||
|
||||
```sh
|
||||
# this installs `opentelemetry-api`, `opentelemetry-sdk` and `opentelemetry-exporter-otlp`
|
||||
pip install transformers[open-telemetry]
|
||||
```
|
||||
|
||||
This will enable traces and metrics collection in CB. You will then have to set up a backend to collect and visualize the traces and metrics.
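For example, if an OTLP-compatible collector is running locally, the standard OpenTelemetry environment variables are enough to point the exporter at it (a minimal sketch assuming the default gRPC port):

```sh
export OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317
export OTEL_SERVICE_NAME=transformers-serve  # an illustrative service name
```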
|
||||
|
||||
@ -393,3 +393,9 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1", quantization_config=quant_config, device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
## Continuous Batching
|
||||
|
||||
When serving LLMs for inference, you may have multiple requests arriving at different times. Continuous Batching (CB) is a technique that groups incoming requests into batches to maximize GPU utilization and throughput.
|
||||
|
||||
See the [Continuous Batching](./continuous_batching) guide for more details on how to use CB in transformers.
|
||||
|
||||
@ -158,6 +158,24 @@ print("Retrieval scores (query x image):")
|
||||
print(scores)
|
||||
```
|
||||
|
||||
You can also use checkpoints for `ColQwen2.5` that are **compatible with the ColQwen2 architecture**. This version of the model uses [Qwen2_5_VL](./qwen2_5_vl) as the backbone.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import ColQwen2ForRetrieval, ColQwen2Processor
|
||||
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||
|
||||
model_name = "Sahil-Kabir/colqwen2.5-v0.2-hf" # An existing compatible checkpoint
|
||||
|
||||
model = ColQwen2ForRetrieval.from_pretrained(
|
||||
model_name,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa"
|
||||
)
|
||||
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image.
|
||||
|
||||
@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
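A minimal sketch of what that batched matmul can look like in `forward`, assuming the module also defines a `down_proj` parameter and an `act_fn` activation alongside `gate_up_proj`:

```python
# hidden_states: (num_experts, tokens_per_expert, hidden_size)
gate_up = torch.bmm(hidden_states, self.gate_up_proj)            # (num_experts, tokens, 2 * expert_dim)
gate, up = gate_up.chunk(2, dim=-1)                              # unpack the fused projection
next_states = torch.bmm(self.act_fn(gate) * up, self.down_proj)  # (num_experts, tokens, hidden_size)
```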
|
||||
|
||||
@ -40,7 +40,7 @@ You can choose between MXFP4 and NVFP4 with `FPQuantConfig(forward_dtype="mxfp4"
|
||||
|
||||
A **Blackwell-generation GPU is required** to run the kernels. Runtime support for FP-Quant is implemented through the [QuTLASS](https://github.com/IST-DASLab/qutlass) library and a lightweight PyTorch interface lib [`fp_quant`](https://github.com/IST-DASLab/FP-Quant/tree/master/inference_lib). We recommend installing the former **from source** and the latter with `pip install fp_quant`.
|
||||
|
||||
Users **without a Blackwell-generation GPU** , can use the method with `quantization_config=FPQuantConfig(pseudoquant=True)` without having to install [QuTLASS](https://github.com/IST-DASLab/qutlass). This would provide no speedups but would fully emulate the effect of quantization.
|
||||
Users **without a Blackwell-generation GPU** can use the method with `quantization_config=FPQuantConfig(pseudoquantization=True)` without having to install [QuTLASS](https://github.com/IST-DASLab/qutlass). This would provide no speedups but would fully emulate the effect of quantization.
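A minimal sketch of that fallback path, assuming `FPQuantConfig` can be imported from `transformers` and using an illustrative model id:

```python
import torch
from transformers import AutoModelForCausalLM, FPQuantConfig

# pseudo-quantization emulates FP-Quant numerics without QuTLASS or a Blackwell GPU (no speedup)
quant_config = FPQuantConfig(pseudoquantization=True)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
    quantization_config=quant_config,
    dtype=torch.bfloat16,
    device_map="auto",
)
```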
|
||||
|
||||
> [!TIP]
|
||||
> Find models pre-quantized with FP-Quant in the official ISTA-DASLab [collection](https://huggingface.co/collections/ISTA-DASLab/fp-quant-6877c186103a21d3a02568ee).
|
||||
|
||||
@ -187,7 +187,7 @@ from torch import nn
|
||||
from transformers import Trainer
|
||||
|
||||
class CustomTrainer(Trainer):
|
||||
def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False num_items_in_batch: Optional[torch.Tensor] = None):
|
||||
def compute_loss(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, num_items_in_batch: Optional[torch.Tensor] = None):
|
||||
labels = inputs.pop("labels")
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
|
||||
@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
배치 행렬 곱셈을 `forward` 패스에서 사용하여 `gate_up_proj` 모듈의 출력을 계산할 수 있습니다.
|
||||
|
||||
@ -16,6 +16,7 @@ import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
@ -29,7 +30,6 @@ from transformers.generation import GenerationConfig
|
||||
from transformers.generation.continuous_batching.requests import logger
|
||||
|
||||
|
||||
# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
|
||||
SLIDING_WINDOW = 0
|
||||
MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
|
||||
FORCE_MAX_LENGTH = False # should be False unless you are debugging sliding window features
|
||||
@ -193,6 +193,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
|
||||
|
||||
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
|
||||
parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
|
||||
|
||||
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
|
||||
parser.add_argument("--log-level", type=str, default="INFO")
|
||||
parser.add_argument("--output-file", type=str, default=None)
|
||||
@ -242,7 +244,18 @@ if __name__ == "__main__":
|
||||
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
|
||||
dataset = dataset.select(range(args.samples))
|
||||
|
||||
simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
|
||||
def random_prefix() -> str:
|
||||
if not args.add_prefix:
|
||||
return ""
|
||||
prefixes = [
|
||||
"Math and reasonning problems are very important to the world. This is a problem, and then you will find the answer.\n",
|
||||
"We all know that reasonning can be taught by answering questions, often illustrated with examples. Here is one and its solution, hopefully you will enjoy it!\n",
|
||||
"Reasonning a very good metric of intelligence, hence it is regularly trained and tested in both children and AI model like LLMs. This test can look like a math or a logical problem, a riddle or pattern detection task. For instance, this is one of those test. You will find it and the solution associated after. Here it goes:\n",
|
||||
] # fmt: skip
|
||||
return random.choice(prefixes)
|
||||
|
||||
random.seed(0)
|
||||
simple_batch_inputs = [tokenizer(random_prefix() + item["question"])["input_ids"] for item in dataset]
|
||||
|
||||
# Prepare generation config
|
||||
generation_config = GenerationConfig(
|
||||
|
||||
setup.py (1 changed line)
|
||||
extras["benchmark"] = deps_list("optimum-benchmark")
|
||||
|
||||
# OpenTelemetry dependencies for metrics collection in continuous batching
|
||||
# TODO: refactor this to split API and SDK; SDK and exporter should only be needed to run code that collects metrics whereas API is what people will need to instrument their code and handle exporter themselves
|
||||
extras["open-telemetry"] = deps_list("opentelemetry-api") + ["opentelemetry-exporter-otlp", "opentelemetry-sdk"]
|
||||
|
||||
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
|
||||
|
||||
@ -876,7 +876,7 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
if hasattr(self, "quantization_config"):
|
||||
serializable_config_dict["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
else self.quantization_config
|
||||
)
|
||||
self.dict_dtype_to_str(serializable_config_dict)
|
||||
@ -910,7 +910,7 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
if hasattr(self, "quantization_config"):
|
||||
output["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
else self.quantization_config
|
||||
)
|
||||
self.dict_dtype_to_str(output)
|
||||
|
||||
@ -1,141 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
|
||||
from .utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def _build_checkpoint_conversion_mapping():
|
||||
mapping = {
|
||||
"mixtral": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w1.weight",
|
||||
"block_sparse_moe.experts.*.w3.weight",
|
||||
], # you give me a list of 2 keys, I collect a list of a list of tensors
|
||||
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w2.weight",
|
||||
],
|
||||
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
# WeightConverter(
|
||||
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
|
||||
# "self_attn.qkv_proj",
|
||||
# operations=[Concatenate(dim=0)], # more like stack?
|
||||
# ),
|
||||
WeightConverter("*.block_sparse_moe.", "*.mlp."),
|
||||
],
|
||||
"qwen2_moe": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"mlp.experts.*.gate_proj.weight",
|
||||
"mlp.experts.*.up_proj.weight",
|
||||
],
|
||||
target_keys="mlp.experts.gate_up_proj",
|
||||
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=["mlp.experts.*.down_proj.weight"],
|
||||
target_keys="mlp.experts.down_proj",
|
||||
operations=[MergeModulelist(dim=0)],
|
||||
),
|
||||
],
|
||||
"legacy": [
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.gamma",
|
||||
target_keys="LayerNorm.weight",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.beta",
|
||||
target_keys="LayerNorm.bias",
|
||||
),
|
||||
],
|
||||
}
|
||||
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="weight_g",
|
||||
target_keys="parametrizations.weight.original0",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="weight_v",
|
||||
target_keys="parametrizations.weight.original1",
|
||||
),
|
||||
]
|
||||
else:
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original0",
|
||||
target_keys="weight_g",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original1",
|
||||
target_keys="weight_v",
|
||||
),
|
||||
]
|
||||
|
||||
mapping["phimoe"] = mapping["mixtral"].copy()
|
||||
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
|
||||
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
|
||||
mapping["dot1"] = mapping["qwen2_moe"].copy()
|
||||
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["jamba"] = mapping["qwen2_moe"].copy()
|
||||
mapping["lfm2_moe"] = mapping["mixtral"].copy()
|
||||
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["minimax"] = mapping["mixtral"].copy()
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
_checkpoint_conversion_mapping_cache = None
|
||||
|
||||
|
||||
def get_checkpoint_conversion_mapping():
|
||||
global _checkpoint_conversion_mapping_cache
|
||||
if _checkpoint_conversion_mapping_cache is None:
|
||||
_checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
|
||||
globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
|
||||
return _checkpoint_conversion_mapping_cache
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name == "_checkpoint_conversion_mapping":
|
||||
return get_checkpoint_conversion_mapping()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
@ -1,661 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Core helpers for loading model checkpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping, MutableSet, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
|
||||
"""
|
||||
Convert a glob with '*' into a regex *source* string. We don't use `glob.translate`
|
||||
'*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing.
|
||||
"""
|
||||
star = r"(\d+)" if digits_only else r"(.+)"
|
||||
return re.escape(glob).replace(r"\*", star)
|
||||
|
||||
|
||||
def build_glob_alt(
|
||||
globs: list[str],
|
||||
) -> tuple[re.Pattern, dict[str, str]]:
|
||||
r"""
|
||||
Build one compiled regex alternation with a named group per glob. This allows to run a single
|
||||
re.match and get the correct group name to finally get which pattern matched.
|
||||
Returns (compiled_regex, name->glob map).
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"])
|
||||
>>> print(reg)
|
||||
(re.compile(r'(?P<g0>.*mlp\.(\d+)\.w1)|(?P<g1>.*mlp\.(\d+)\.w2)', re.UNICODE),
|
||||
>>> print(map_)
|
||||
{'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'})
|
||||
>>> match_ = reg.match("model.layers.0.mlp.0.w1.weight")
|
||||
>>> print(match_.lastgroup)
|
||||
'g0'
|
||||
>>> print(map_[match_.lastgroup])
|
||||
mlp.*.w1
|
||||
```
|
||||
"""
|
||||
name_map: dict[str, str] = {}
|
||||
parts: list[str] = []
|
||||
prefix_src = r".*"
|
||||
|
||||
for i, g in enumerate(globs):
|
||||
name = f"g{i}"
|
||||
name_map[name] = g
|
||||
pat_src = _glob_to_regex_src(g)
|
||||
parts.append(f"(?P<{name}>{prefix_src}{pat_src})")
|
||||
|
||||
alt_src = "|".join(parts)
|
||||
return re.compile(alt_src), name_map
|
||||
|
||||
|
||||
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Match the key against the alternation; return the original glob string that matched.
|
||||
"""
|
||||
m = alt.match(key)
|
||||
if not m:
|
||||
return None
|
||||
return name_map.get(m.lastgroup)
|
||||
|
||||
|
||||
class ConversionOps:
|
||||
"""Base class for weight conversion operations."""
|
||||
|
||||
# Reusable staging/scratch buffer to avoid reallocations.
|
||||
_buffer: Optional[torch.Tensor] = None
|
||||
# The inverse operation class, will be used when saving the checkpoint
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def _ensure_buffer(
|
||||
self,
|
||||
required_shape: torch.Size,
|
||||
*,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
growth_factor: float = 1.5,
|
||||
) -> torch.Tensor:
|
||||
"""Ensure a pre-allocated buffer large enough for ``required_shape`` exists."""
|
||||
|
||||
required_elems = 1
|
||||
for dim in required_shape:
|
||||
required_elems *= int(dim)
|
||||
|
||||
need_new = (
|
||||
self._buffer is None
|
||||
or self._buffer.dtype != dtype
|
||||
or self._buffer.device != device
|
||||
or self._buffer.numel() < required_elems
|
||||
)
|
||||
|
||||
if need_new:
|
||||
capacity = max(required_elems, int(required_elems * growth_factor))
|
||||
self._buffer = torch.empty(capacity, dtype=dtype, device=device)
|
||||
|
||||
return self._buffer[:required_elems].view(required_shape)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Free any cached buffers."""
|
||||
self._buffer = None
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Chunk(ConversionOps):
|
||||
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
|
||||
if chunks is None and sizes is None:
|
||||
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
|
||||
if chunks is not None and chunks <= 0:
|
||||
raise ValueError("`chunks` must be a strictly positive integer.")
|
||||
self.dim = dim
|
||||
self.chunks = chunks
|
||||
self.sizes = list(sizes) if sizes is not None else None
|
||||
self.reverse_op = Concatenate
|
||||
|
||||
def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
|
||||
if not isinstance(value, torch.Tensor):
|
||||
raise TypeError("Chunk expects a torch.Tensor as input.")
|
||||
if self.sizes is not None:
|
||||
return list(torch.split(value, self.sizes, dim=self.dim))
|
||||
return list(torch.chunk(value, self.chunks, dim=self.dim))
|
||||
|
||||
|
||||
class Concatenate(ConversionOps):
|
||||
"""Concatenate tensors along `dim` using a reusable buffer."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
self.dim = dim
|
||||
self.reverse_op = Chunk
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
|
||||
if isinstance(value[0], list):
|
||||
value = [v[0] for v in value]
|
||||
tensors = value
|
||||
if not tensors:
|
||||
raise ValueError("Fuse requires at least one tensor to concatenate.")
|
||||
|
||||
out_shape = list(tensors[0].shape)
|
||||
out_shape[self.dim] = sum([t.size(self.dim) for t in tensors])
|
||||
|
||||
with torch.no_grad(): # we use staging buffers
|
||||
out = self._ensure_buffer(torch.Size(out_shape), dtype=tensors[0].dtype, device=tensors[0].device)
|
||||
torch.cat(tuple(tensors), dim=self.dim, out=out)
|
||||
# offset = 0
|
||||
# for tensor in tensors:
|
||||
# index = [slice(None)] * tensor.ndim
|
||||
# index[self.dim] = slice(offset, offset + tensor.shape[self.dim])
|
||||
# out[tuple(index)].copy_(tensor, non_blocking=tensor.is_cuda)
|
||||
# offset += tensor.shape[self.dim]
|
||||
        return out.clone()  # clone so the caller owns the result and the staging buffer can be reused
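# Illustrative sketch (not part of the module): `Chunk` and `Concatenate` are
# inverses of each other along the same `dim`, e.g. for a fused QKV projection:
#
#   fused = torch.randn(12, 4)
#   q, k, v = Chunk(dim=0, chunks=3).convert(fused)       # three (4, 4) tensors
#   fused_again = Concatenate(dim=0).convert([q, k, v])   # back to (12, 4)
#   assert torch.equal(fused, fused_again)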
|
||||
|
||||
|
||||
class MergeModulelist(Concatenate):
|
||||
"""
|
||||
Merge a list of tensors into a single tensor along the first dimension.
|
||||
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
super().__init__(dim=dim)
|
||||
self.reverse_op = SplitModulelist
|
||||
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
|
||||
merged = []
|
||||
with torch.no_grad(): # we use staging buffers
|
||||
for group in value:
|
||||
if not isinstance(group, Sequence) or len(group) == 0:
|
||||
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
|
||||
group = [k for k in group if k.ndim]
|
||||
out_shape = list(group[0].shape)
|
||||
out_shape.insert(self.dim, len(group))
|
||||
out = self._ensure_buffer(torch.Size(out_shape), dtype=group[0].dtype, device=group[0].device)
|
||||
torch.stack(tuple(group), dim=self.dim, out=out)
|
||||
# for off, tensor in enumerate(group):
|
||||
# out[off].copy_(tensor, non_blocking=tensor.is_cuda)
|
||||
# torch.as_tensor(numpy.stack(batch))
|
||||
merged.append(out.clone()) # TODO have a single staging tensor here as well!
|
||||
return merged
|
||||
|
||||
|
||||
class SplitModulelist(ConversionOps):
|
||||
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
|
||||
|
||||
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
|
||||
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
|
||||
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
|
||||
self.sizes = [list(sub) for sub in sizes]
|
||||
self.dim = dim
|
||||
self.reverse_op = MergeModulelist
|
||||
|
||||
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
|
||||
if not isinstance(value, Sequence):
|
||||
raise TypeError("SplitModulelist expects a sequence of tensors.")
|
||||
if len(value) != len(self.sizes):
|
||||
raise ValueError("Number of tensors does not match the provided split specifications.")
|
||||
|
||||
result: list[list[torch.Tensor]] = []
|
||||
for tensor, split_sizes in zip(value, self.sizes):
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
|
||||
splits = torch.split(tensor, split_sizes, dim=self.dim)
|
||||
result.append(list(splits))
|
||||
return result
|
||||
|
||||
|
||||
class Cast(ConversionOps):
|
||||
"""
|
||||
Casts the tensor to a given dtype
|
||||
"""
|
||||
|
||||
def __init__(self, dtype):
|
||||
self.dtype = dtype
|
||||
|
||||
def convert(self, value, *args, **kwargs):
|
||||
out = [
|
||||
[x.to(self.dtype) for x in inner] if isinstance(inner, list) else inner.to(self.dtype) for inner in value
|
||||
]
|
||||
return out
|
||||
|
||||
|
||||
class PermuteForRope(ConversionOps):
|
||||
"""
|
||||
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
dim1, dim2 = tensor.shape
|
||||
        n_heads = getattr(self.config, "num_attention_heads", 1)
|
||||
|
||||
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
|
||||
return tensor
|
||||
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config
|
||||
) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]:
|
||||
self.config = config
|
||||
out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value]
|
||||
return out
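# Illustrative sketch (not part of the module): with a single head and dim1 = 4,
# `_apply` reorders rows [r0, r1, r2, r3] -> [r0, r2, r1, r3], i.e. it interleaves
# the two halves of each head so interleaved (complex) RoPE weights line up with
# the split sin/cos layout (hypothetical config stand-in, `from types import SimpleNamespace`):
#
#   w = torch.arange(8.0).reshape(4, 2)                  # rows r0..r3
#   op = PermuteForRope()
#   op.config = SimpleNamespace(num_attention_heads=1)   # normally set by `convert`
#   op._apply(w)                                         # rows come back as r0, r2, r1, r3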
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WeightConverter:
|
||||
r"""
|
||||
    A weight converter that acts on a pattern of source keys.
    The keys need to be collected based on the target keys.

    With wildcards, glob patterns are matched, so you have to be precise about what you match. If you match:
    `model.layers.*.experts.*` -> it will act on all of them
    {"model.layers.*.experts.*": []}
    but
    `model.layers.1.experts.*` will be layer specific.
    {"model.layers.1.experts.*": [], }
    - source_keys: str | list[str] (wildcards '*' match digits)
    - target_keys: str | list[str] | None
    - operations is ALWAYS a list; distributed_operation and quantization_operation are single optional operations.
    """
|
||||
|
||||
source_keys: Union[str, list[str]]
|
||||
target_keys: Optional[Union[str, list[str]]] = None
|
||||
operations: list[ConversionOps] = field(default_factory=list, repr=False)
|
||||
|
||||
distributed_operation: Optional[TensorParallelLayer] = None
|
||||
quantization_operation: Optional[ConversionOps] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.source_keys, list):
|
||||
self.source_keys = [self.source_keys]
|
||||
targets_were_none = False
|
||||
if not isinstance(self.target_keys, list):
|
||||
if self.target_keys is None:
|
||||
self.target_keys = list(self.source_keys)
|
||||
targets_were_none = True
|
||||
else:
|
||||
self.target_keys = [self.target_keys]
|
||||
|
||||
if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2:
|
||||
raise ValueError(
|
||||
f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one."
|
||||
)
|
||||
|
||||
for pattern in self.source_keys:
|
||||
if any(ch in pattern for ch in set("^$+?{}[]|()")):
|
||||
raise AssertionError(f"'{pattern}' is not glob")
|
||||
for pattern in self.target_keys:
|
||||
if any(ch in pattern for ch in set("^$+?{}[]|()")):
|
||||
raise AssertionError(f"'{pattern}' is not glob")
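# Illustrative sketch (not part of the module): typical converters, e.g. stacking
# per-expert weights into a single tensor, or splitting a fused QKV projection.
# The key names below are hypothetical:
#
#   WeightConverter(
#       source_keys="model.layers.*.experts.*.up_proj.weight",
#       target_keys="model.layers.*.experts.up_proj",
#       operations=[MergeModulelist(dim=0)],
#   )
#   WeightConverter(
#       source_keys="attn.qkv.weight",
#       target_keys=["attn.q.weight", "attn.k.weight", "attn.v.weight"],
#       operations=[Chunk(dim=0, chunks=3)],
#   )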
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConversionEntry:
|
||||
weight_converter: WeightConverter
|
||||
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
|
||||
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
|
||||
PER_FILE_LIMIT = 4 # concurrent reads per file
|
||||
|
||||
|
||||
def _materialize_copy(x):
|
||||
# PyTorch: this runs in C and releases the GIL; good for threads.
|
||||
return x[...]
|
||||
|
||||
|
||||
def spawn_materialize(thread_pool, _file_semaphore, file_id, t) -> Future:
|
||||
sem = _file_semaphore[file_id]
|
||||
|
||||
def _job():
|
||||
with sem:
|
||||
return _materialize_copy(t)
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def spawn_tp_materialize(thread_pool, _file_semaphore, file_id, t, sharding_method, tensor_idx) -> Future:
|
||||
sem = _file_semaphore[file_id]
|
||||
|
||||
def _job():
|
||||
with sem:
|
||||
return sharding_method.shard_tensor(t, tensor_idx=tensor_idx)[0]
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def dot_natural_key(s: str):
|
||||
parts = s.split(".")
|
||||
for i, p in enumerate(parts):
|
||||
# whole-segment digits -> int; otherwise leave as str
|
||||
if p.isdigit():
|
||||
parts[i] = int(p)
|
||||
return parts
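# Illustrative sketch (not part of the module): natural ordering keeps layer 2
# before layer 10, which plain lexicographic sorting would not:
#
#   keys = ["model.layers.10.mlp", "model.layers.2.mlp"]
#   sorted(keys, key=dot_natural_key)   # -> ["model.layers.2.mlp", "model.layers.10.mlp"]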
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_to_misc(
|
||||
layer_name: str,
|
||||
misc: MutableMapping[str, str],
|
||||
extras: Any = None,
|
||||
op: Union[list[ConversionOps], ConversionOps, None] = None,
|
||||
):
|
||||
# A simple helper to handle errors with contextual messages.
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
|
||||
def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
|
||||
if curr_op is None:
|
||||
return None
|
||||
if isinstance(curr_op, (list, tuple, set)):
|
||||
names = [o.__class__.__name__ for o in curr_op if o is not None]
|
||||
if not names:
|
||||
return None
|
||||
return ", ".join(names)
|
||||
return curr_op.__class__.__name__
|
||||
|
||||
op_name = _format_op_name(op)
|
||||
if isinstance(extras, tuple) and len(extras) == 2:
|
||||
values, target_keys = extras
|
||||
descriptor = f"{op_name} " if op_name else ""
|
||||
misc[layer_name] = (
|
||||
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
|
||||
)
|
||||
elif isinstance(extras, str):
|
||||
suffix = f" via {op_name}" if op_name else ""
|
||||
misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
|
||||
elif extras is None and op_name:
|
||||
misc[layer_name] = f"{op_name}: {e}"
|
||||
else:
|
||||
misc[layer_name] = f"{extras} |Error: {e}"
|
||||
raise SkipLayer()
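# Illustrative sketch (not part of the module): `log_to_misc` records a contextual
# error message under the layer name and converts any exception into `SkipLayer`,
# so the caller can abort only that layer:
#
#   misc = {}
#   try:
#       with log_to_misc("model.layers.0.q_proj.weight", misc, op=Cast(torch.float16)):
#           raise RuntimeError("boom")
#   except SkipLayer:
#       pass
#   # misc == {"model.layers.0.q_proj.weight": "Cast: boom"}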
|
||||
|
||||
|
||||
def set_param_for_module(
|
||||
model: torch.nn.Module,
|
||||
layer_name: str,
|
||||
param_value: torch.Tensor,
|
||||
meta_model_state_dict: MutableMapping[str, Any],
|
||||
empty_param: torch.Tensor,
|
||||
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
|
||||
missing_keys: MutableSet[str],
|
||||
misc: MutableMapping[str, Any],
|
||||
distributed_operation: Optional[TensorParallelLayer],
|
||||
):
|
||||
with log_to_misc(layer_name, misc, layer_name):
|
||||
module_path, _, param_name = layer_name.rpartition(".")
|
||||
module_obj = model.get_submodule(module_path) if module_path else model
|
||||
param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
|
||||
ref = meta_model_state_dict.get(layer_name, empty_param)
|
||||
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
|
||||
if not isinstance(param_value, torch.nn.Parameter):
|
||||
if distributed_operation is not None and use_dtensor:
|
||||
param_value = DTensor.from_local(
|
||||
param_value,
|
||||
distributed_operation.device_mesh,
|
||||
distributed_operation.shard,
|
||||
run_check=False,
|
||||
shape=ref.size(),
|
||||
stride=ref.stride(),
|
||||
)
|
||||
else:
|
||||
                pass  # TODO for "local" stuff, it will trigger a mismatch, no?
|
||||
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
|
||||
|
||||
if ref is not None and ref.shape != param_value.shape:
|
||||
mismatch_keys.add((layer_name, param_value.shape, ref.shape))
|
||||
missing_keys.discard(layer_name)
|
||||
setattr(module_obj, param_name, param_value)
|
||||
|
||||
|
||||
class SkipLayer(Exception):
|
||||
"""Control-flow sentinel: abort processing of the current layer only."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def convert_and_load_state_dict_in_model(
|
||||
model,
|
||||
state_dict,
|
||||
weight_mapping,
|
||||
tp_plan,
|
||||
quantizer,
|
||||
dtype=None,
|
||||
device_map=None,
|
||||
dtype_plan=None,
|
||||
device_mesh=None,
|
||||
profile: bool = False,
|
||||
):
|
||||
"""
|
||||
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
|
||||
collecting tensors per *layer instance* (the concrete indices captured from '*').
|
||||
"""
|
||||
from .modeling_utils import str_to_torch_dtype
|
||||
|
||||
prefix = model.base_model_prefix
|
||||
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
|
||||
device_map = device_map or {} # {exact_target_key: device}
|
||||
dtype_plan = dtype_plan or {} # {glob_pattern: dtype}
|
||||
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
|
||||
meta_model_state_dict = model.state_dict()
|
||||
missing_keys = set(meta_model_state_dict.keys())
|
||||
|
||||
misc = {}
|
||||
mismatch_keys = set()
|
||||
unexpected_keys = set()
|
||||
    # Global thread pool + per-file semaphores: limit concurrent reads to PER_FILE_LIMIT per file. Should this depend on tensor shapes instead?
|
||||
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
|
||||
_file_semaphore = defaultdict(lambda: threading.Semaphore(PER_FILE_LIMIT))
|
||||
|
||||
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
|
||||
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
|
||||
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
|
||||
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
|
||||
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
|
||||
|
||||
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
|
||||
# 1. Create the conversion entries
|
||||
by_conversion_pattern: dict[str, ConversionEntry] = {}
|
||||
for original_key, (file_id, tensor) in state_dict:
|
||||
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
|
||||
if matched_pattern is not None:
|
||||
converter = source_to_target[matched_pattern] # TODO make sure its the ref
|
||||
sub_with_extractor = partial(re.sub, _glob_to_regex_src(matched_pattern), string=original_key)
|
||||
entry_key = "|".join(converter.target_keys)
|
||||
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
|
||||
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
|
||||
converter_key = sub_with_extractor(matched_pattern)
|
||||
else:
|
||||
converter = WeightConverter(original_key)
|
||||
converter_key = entry_key = target_key = original_key
|
||||
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
|
||||
|
||||
new_target_key = []
|
||||
for t in target_key.split("|"): # let's correct the keys
|
||||
if t.startswith(prefix) and meta_model_state_dict.get(t.replace(f"{prefix}.", "")) is not None:
|
||||
t = t.replace(f"{prefix}.", "")
|
||||
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
|
||||
t = f"{prefix}.{t}"
|
||||
new_target_key.append(t)
|
||||
target_key = "|".join(new_target_key)
|
||||
|
||||
for t in target_key.split("|"):
|
||||
empty_param = meta_model_state_dict.get(t)
|
||||
if empty_param is None:
|
||||
unexpected_keys.add(t)
|
||||
continue
|
||||
|
||||
if quantizer is not None and quantizer.param_needs_quantization(model, t):
|
||||
if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
|
||||
from .integrations.finegrained_fp8 import Fp8Quantize
|
||||
|
||||
converter.quantization_operation = Fp8Quantize() # TODO support other methods
|
||||
else:
|
||||
raise ValueError("This quantization method is gonna be supported SOOOON")
|
||||
else:
|
||||
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
|
||||
if matched_dtype_pattern is not None:
|
||||
_dtype = dtype_plan[matched_dtype_pattern]
|
||||
else:
|
||||
_dtype = dtype
|
||||
tensor_dtype = (
|
||||
tensor.dtype if isinstance(tensor, torch.Tensor) else str_to_torch_dtype[tensor.get_dtype()]
|
||||
)
|
||||
if _dtype != tensor_dtype and _dtype is not None:
|
||||
converter.operations.append(Cast(_dtype)) # can this be slow as well?
|
||||
|
||||
first_target_key = target_key.split("|")[0]
|
||||
future = None
|
||||
if device_mesh:
|
||||
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
|
||||
empty_param = meta_model_state_dict.get(first_target_key)
|
||||
if getattr(converter, "distributed_operation", {}) is None:
|
||||
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
|
||||
converter.distributed_operation = tp_layer(
|
||||
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
|
||||
)
|
||||
                # VERY IMPORTANT: this tells us whether we have already collected shards for this key or not.
|
||||
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
|
||||
future = spawn_tp_materialize(
|
||||
thread_pool,
|
||||
_file_semaphore,
|
||||
file_id,
|
||||
tensor,
|
||||
converter.distributed_operation,
|
||||
shard_index,
|
||||
)
|
||||
|
||||
if future is None: # If not TP, async materialize the tensors. TODO handle disk offload?
|
||||
future = spawn_materialize(thread_pool, _file_semaphore, file_id, tensor)
|
||||
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
|
||||
|
||||
# 2. Actually convert the ckpt
|
||||
inverse_converters = {}
|
||||
keys = list(by_conversion_pattern.keys())
|
||||
total_layers = sum(len(by_conversion_pattern[key].collected_tensors) for key in keys)
|
||||
progress_bar = logging.tqdm(total=total_layers, desc="Converting weights", leave=False) if total_layers else None
|
||||
|
||||
    for key in keys[::-1]:  # reverse to process simple keys first
|
||||
group = by_conversion_pattern.pop(key)
|
||||
converter = group.weight_converter
|
||||
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
|
||||
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
|
||||
concrete_target_keys = layer_name.split("|")
|
||||
try:
|
||||
if bool(set(concrete_target_keys) - unexpected_keys):
|
||||
with log_to_misc(layer_name, misc):
|
||||
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
|
||||
|
||||
for op in operations:
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
values = op.convert(values, model.config)
|
||||
|
||||
values = [values] if not isinstance(values, list) else values
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
realized_value = {
|
||||
k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
|
||||
}
|
||||
|
||||
for k in list(realized_value.keys()).copy():
|
||||
if op := converter.quantization_operation:
|
||||
with log_to_misc(layer_name, misc, op=op):
|
||||
realized_value.update(
|
||||
op.convert({k: realized_value.pop(k)}, quant_config=quantizer.quantization_config)
|
||||
)
|
||||
|
||||
if progress_bar is not None:
|
||||
progress_bar.set_postfix_str(layer_name, refresh=False)
|
||||
progress_bar.update()
|
||||
|
||||
for k, output_value in realized_value.items():
|
||||
for src in converter.source_keys: # what should happen to k when we meet k at saving
|
||||
inverse_converters[k] = {src: converter}
|
||||
set_param_for_module(
|
||||
model,
|
||||
k,
|
||||
output_value,
|
||||
meta_model_state_dict,
|
||||
empty_param,
|
||||
mismatch_keys,
|
||||
missing_keys,
|
||||
misc,
|
||||
converter.distributed_operation,
|
||||
)
|
||||
except SkipLayer:
|
||||
continue
|
||||
del group
|
||||
for op in operations:
|
||||
op.clear_cache()
|
||||
if progress_bar is not None:
|
||||
progress_bar.close()
|
||||
model.inverse_converters = inverse_converters
|
||||
thread_pool.shutdown(wait=True)
|
||||
return missing_keys, unexpected_keys, mismatch_keys, misc
|
||||
|
||||
|
||||
# TODO this is not done yet!
|
||||
def revert_weight_conversion(model, state_dict):
|
||||
mapping = getattr(model, "", {}) # IDK why but setting this will fail all llava.
|
||||
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
|
||||
original_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
for pattern, inverse_converter in reverse_key_mapping:
|
||||
# TODO FIXME you name it
|
||||
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
|
||||
replacement = re.sub(r"\(.*\)", "", replacement)
|
||||
key, n_replace = re.subn(pattern, replacement, key)
|
||||
# Early exit of the loop
|
||||
if n_replace > 0:
|
||||
break
|
||||
original_state_dict[key] = value
|
||||
state_dict = original_state_dict
|
||||
return state_dict
|
||||
@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import deque
|
||||
from math import floor, gcd, sqrt
|
||||
from typing import Optional
|
||||
|
||||
@ -21,8 +20,8 @@ import torch
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...generation.configuration_utils import GenerationConfig
|
||||
from ...utils.metrics import attach_tracer, traced
|
||||
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
|
||||
from .requests import get_device_and_memory_breakdown, logger
|
||||
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
|
||||
from .requests import RequestState, get_device_and_memory_breakdown, logger
|
||||
|
||||
|
||||
def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
|
||||
@ -32,7 +31,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
|
||||
- All groups have the same number of layers
|
||||
|
||||
For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
|
||||
We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
|
||||
We would get four groups: [0, 3], [1, 2], [4,5] and [6,7].
|
||||
"""
|
||||
# If the config has no layer_type attribute, it means all layers are the same attention type
|
||||
layer_types = getattr(config, "layer_types", None)
|
||||
@ -173,10 +172,12 @@ class PagedAttentionCache:
|
||||
page_size = self.head_dim * self.num_key_value_heads
|
||||
|
||||
if "flash" in self.config._attn_implementation:
|
||||
num_attention_masks = 1 # only used to compute the default meme args
|
||||
else:
|
||||
num_attention_masks = 0 # only used to compute the default memory footprint args
|
||||
elif "sliding_attention" in group_types:
|
||||
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
|
||||
num_attention_masks = 2 if "sliding_attention" in group_types else 1
|
||||
num_attention_masks = 2
|
||||
else:
|
||||
num_attention_masks = 1
|
||||
|
||||
memory_handler = PagedAttentionMemoryHandler(
|
||||
block_size=self.block_size,
|
||||
@ -216,7 +217,6 @@ class PagedAttentionCache:
|
||||
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
|
||||
|
||||
# Block management data structures
|
||||
self._free_blocks = deque(range(num_blocks))
|
||||
self.group_cache_managers: list[CacheAllocator] = []
|
||||
for i, group_type in enumerate(group_types):
|
||||
if group_type == "full_attention":
|
||||
@ -227,13 +227,18 @@ class PagedAttentionCache:
|
||||
raise ValueError(f"Invalid group type: {group_type}")
|
||||
self.group_cache_managers.append(cm)
|
||||
|
||||
# We only use prefix sharing if the whole model has only full attention layers
|
||||
self.use_prefix_sharing = (group_types == ["full_attention"])
|
||||
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
|
||||
self.blocks_to_complete: dict[str, int] = {}
|
||||
|
||||
@traced
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
|
||||
def allocate_blocks(self, n_blocks: int, state: RequestState) -> int:
|
||||
"""Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
|
||||
managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
|
||||
max_allocated = 0
|
||||
for cm in self.group_cache_managers:
|
||||
allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
|
||||
allocated = cm.allocate_blocks(n_blocks, state.request_id, self._block_manager)
|
||||
if allocated is None:
|
||||
return None
|
||||
max_allocated = max(max_allocated, allocated)
|
||||
@ -244,11 +249,11 @@ class PagedAttentionCache:
|
||||
"""Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
|
||||
by the cache managers."""
|
||||
for cm in self.group_cache_managers:
|
||||
cm.free_blocks(request_id, self._free_blocks)
|
||||
cm.free_blocks(request_id, self._block_manager)
|
||||
|
||||
def get_num_free_blocks(self) -> int:
|
||||
"""Get the current number of unallocated blocks available for new requests."""
|
||||
return len(self._free_blocks)
|
||||
return self._block_manager.num_free_blocks
|
||||
|
||||
@traced
|
||||
def extend_read_indices(
|
||||
@ -335,6 +340,38 @@ class PagedAttentionCache:
|
||||
# Return the new KV values
|
||||
return key_states_with_cache, value_states_with_cache
|
||||
|
||||
def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
|
||||
current_hash = None
|
||||
allocated_blocks = []
|
||||
for b in range(len(prompt_ids) // self.block_size):
|
||||
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
|
||||
current_hash = self._block_manager.compute_hash(current_hash, tokens)
|
||||
block_id = self._block_manager._hash_to_id.get(current_hash)
|
||||
if block_id is not None:
|
||||
allocated_blocks.append(block_id)
|
||||
self._block_manager.increase_ref_count(block_id)
|
||||
else:
|
||||
break
|
||||
# If we found a matching prefix, we reference the blocks in the request
|
||||
if allocated_blocks:
|
||||
logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
|
||||
cm = self.group_cache_managers[0]
|
||||
cm.block_table[request_id] = allocated_blocks
|
||||
return len(allocated_blocks) * self.block_size
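    # Illustrative sketch (not part of the module): with block_size = 4, a 10-token
    # prompt [t0..t9] is looked up block by block: hash(None, t0..t3), then
    # hash(h0, t4..t7). The trailing 2 tokens never form a full block, so at most
    # 2 blocks (8 tokens) can be reused from an earlier request with the same prefix.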
|
||||
|
||||
def mark_blocks_as_completed(self, state: RequestState) -> None:
|
||||
"""Marks the blocks that have been computed in the forward pass as such. If prefix sharing is off, this is a
|
||||
no-op."""
|
||||
num_completed_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
|
||||
if num_completed_blocks == 0:
|
||||
return None
|
||||
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
|
||||
self._block_manager.mark_blocks_as_computed(
|
||||
num_completed_blocks=num_completed_blocks,
|
||||
allocated_blocks=cm.block_table[state.request_id],
|
||||
prompt_ids=(state.full_prompt_ids + state.static_outputs),
|
||||
)
|
||||
|
||||
|
||||
# TODO: rework computation with the groups and their sizes
|
||||
class PagedAttentionMemoryHandler:
|
||||
@ -469,6 +506,8 @@ class PagedAttentionMemoryHandler:
|
||||
2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
|
||||
m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
|
||||
])
|
||||
|
||||
If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
|
||||
"""
|
||||
cache_memory = self.get_available_memory(max_memory_percent)
|
||||
logger.info(f"Cache memory: {cache_memory}")
|
||||
@ -480,11 +519,16 @@ class PagedAttentionMemoryHandler:
|
||||
c = -cache_memory
|
||||
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
|
||||
|
||||
# Compute discriminant and greatest solution
|
||||
discriminant = b**2 - 4 * a * c
|
||||
if discriminant < 0:
|
||||
raise ValueError(f"Discriminant is negative: {discriminant = }")
|
||||
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
|
||||
# If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
|
||||
if self.num_attention_masks == 0:
|
||||
greatest_solution = -c / b
|
||||
# Otherwise, we solve the quadratic equation
|
||||
else:
|
||||
discriminant = b**2 - 4 * a * c
|
||||
if discriminant < 0:
|
||||
raise ValueError(f"Discriminant is negative: {discriminant = }")
|
||||
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
|
||||
|
||||
if greatest_solution < 0:
|
||||
raise ValueError(f"Greatest solution is negative: {greatest_solution = }")
|
||||
|
||||
|
||||
@ -14,29 +14,191 @@
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Iterator
|
||||
from math import ceil
|
||||
from typing import Optional
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from .requests import logger
|
||||
|
||||
|
||||
T = TypeVar("T")


def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
|
||||
index = len(xs) - 1
|
||||
for x in xs[::-1]:
|
||||
yield index, x
|
||||
index -= 1
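# Illustrative sketch (not part of the module): `reverse_enumerate` walks a list
# from the end while keeping the original indices:
#
#   list(reverse_enumerate(["a", "b", "c"]))   # -> [(2, "c"), (1, "b"), (0, "a")]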
|
||||
|
||||
|
||||
class Block:
|
||||
"""A class to represent a block in the hash table of the block manager. We say that a block is completed when the KV
|
||||
cache it points to is fully computed, otherwise it is partial. A block can have a parent, which is the block that
|
||||
came before in the sequence. Once a block is computed, it is given a hash, which takes into account the tokens ids
|
||||
of the block and its parent's hash."""
|
||||
|
||||
def __init__(self, id_: int, parent_id: int | None) -> None:
|
||||
self.id: int = id_
|
||||
self.parent_id: int | None = parent_id
|
||||
self.hash: int | None = None
|
||||
self.ref_count: int = 1
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
return self.hash is not None
|
||||
|
||||
|
||||
class BlockManager:
|
||||
"""A class to manage the number of free blocks and block re-use."""
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
|
||||
"""Initializes the block manager with a given number of blocks (num_blocks)"""
|
||||
self.num_blocks = num_blocks
|
||||
self.block_size = block_size
|
||||
self._uninit_block_ids = deque(range(num_blocks))
|
||||
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
|
||||
self._use_prefix_sharing = use_prefix_sharing
|
||||
        # TODO: handle de-allocation for those structures
|
||||
self._hash_to_id: dict[int, int] = {}
|
||||
self._id_to_block: dict[int, Block] = {}
|
||||
# NOTE: one of those may be redundant
|
||||
        # TODO: handle case where the last block of a finished request is not complete
|
||||
|
||||
@property
|
||||
def num_free_blocks(self) -> int:
|
||||
"""Returns the number of free blocks left."""
|
||||
return len(self._uninit_block_ids) + len(self._init_block_ids)
|
||||
|
||||
def is_enough_free_blocks(self, n_blocks: int) -> bool:
|
||||
# Exit early if there are enough uninitialized blocks
|
||||
if len(self._uninit_block_ids) >= n_blocks:
|
||||
return True
|
||||
# Exit early if even after uninitializing all initialized blocks, there are not enough free blocks
|
||||
        blocks_to_uninitialize = n_blocks - len(self._uninit_block_ids)
        if len(self._init_block_ids) < blocks_to_uninitialize:
            return False
        # Uninitialize the required amount of blocks
        for _ in range(blocks_to_uninitialize):
            id_to_uninitialize = self._init_block_ids.popitem()[0]
            block = self._id_to_block[id_to_uninitialize]
            self._hash_to_id.pop(block.hash)
            self._uninit_block_ids.append(id_to_uninitialize)
|
||||
return True
|
||||
|
||||
def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
|
||||
"""Returns a free block and mark it as used by removing it from the free blocks queue."""
|
||||
if not self.is_enough_free_blocks(n_blocks):
|
||||
return None
|
||||
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
|
||||
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
|
||||
if self._use_prefix_sharing:
|
||||
for block_id in allocated_block_ids:
|
||||
block = Block(block_id, last_block_id)
|
||||
                self._id_to_block[block_id] = block  # TODO: we could store only partial blocks here and keep the parent referenced as a hash once the block is complete
|
||||
last_block_id = block_id
|
||||
# In both cases, we return the allocated block ids
|
||||
return allocated_block_ids
|
||||
|
||||
def increase_ref_count(self, block_id: int) -> None:
|
||||
"""Increases the reference count of a block."""
|
||||
block = self._id_to_block[block_id]
|
||||
block.ref_count += 1
|
||||
if block.ref_count == 1:
|
||||
self._init_block_ids.pop(block_id)
|
||||
|
||||
def decrease_ref_count(self, block_id: int) -> None:
|
||||
"""Decreases the reference count of a block."""
|
||||
block = self._id_to_block[block_id]
|
||||
block.ref_count -= 1
|
||||
if block.ref_count == 0:
|
||||
if block.is_complete:
|
||||
self._init_block_ids[block_id] = None
|
||||
else:
|
||||
self._id_to_block.pop(block_id)
|
||||
self._uninit_block_ids.append(block_id)
|
||||
|
||||
def free_blocks(self, blocks: list[int]) -> None:
|
||||
"""Marks a list of blocks as free. If there is no prefix sharing, we simply add them to the uninitialized blocks
|
||||
queue. Otherwise, we mark them as initalized but they can be freed in no uninitialized blocks are lefts."""
|
||||
if self._use_prefix_sharing:
|
||||
for block_id in blocks:
|
||||
self.decrease_ref_count(block_id)
|
||||
else:
|
||||
self._uninit_block_ids.extend(blocks)
|
||||
|
||||
|
||||
def mark_blocks_as_computed(
|
||||
self,
|
||||
num_completed_blocks: int,
|
||||
allocated_blocks: list[int],
|
||||
prompt_ids: list[int]
|
||||
) -> None:
|
||||
# Look for the first complete block, starting from the last block
|
||||
parent_hash = None
|
||||
incomplete_blocks: list[Block] = []
|
||||
for i, block_id in reverse_enumerate(allocated_blocks):
|
||||
block = self._id_to_block[block_id]
|
||||
if block.is_complete:
|
||||
parent_hash = block.hash
|
||||
break
|
||||
incomplete_blocks.append((i, block))
|
||||
|
||||
        # Now go through the incomplete blocks and update them
|
||||
new_parent_id = None
|
||||
while incomplete_blocks:
|
||||
i, block = incomplete_blocks.pop()
|
||||
|
||||
# If the parent id has been updated, we apply the change
|
||||
if new_parent_id is not None:
|
||||
block.parent_id = new_parent_id
|
||||
new_parent_id = None
|
||||
|
||||
# If we have set the hash for all complete blocks, we can stop
|
||||
if num_completed_blocks == 0:
|
||||
break
|
||||
|
||||
# Otherwise, we compute the hash
|
||||
num_completed_blocks -= 1
|
||||
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
|
||||
block.hash = self.compute_hash(parent_hash, tokens)
|
||||
|
||||
existing_block_id = self._hash_to_id.get(block.hash)
|
||||
# If the block hash is already in the hash to id mapping, we reference the existing block instead
|
||||
if existing_block_id is not None:
|
||||
logger.debug(f"Found existing block {existing_block_id} for block {block.id}")
|
||||
allocated_blocks[i] = existing_block_id
|
||||
self._id_to_block[existing_block_id].ref_count += 1
|
||||
new_parent_id = existing_block_id
|
||||
self.free_blocks([block.id])
|
||||
|
||||
# Otherwise, we add the completed block to the hash table
|
||||
else:
|
||||
self._hash_to_id[block.hash] = block.id
|
||||
|
||||
# Update loop variables
|
||||
parent_hash = block.hash
|
||||
|
||||
def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
|
||||
return hash((parent_hash, tuple(tokens)))
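    # Illustrative sketch (not part of the module, `manager` is a hypothetical
    # BlockManager instance): block hashes chain through the parent hash, so two
    # requests sharing a prefix produce identical hashes for the shared blocks and
    # can therefore reuse the same physical cache blocks:
    #
    #   h0 = manager.compute_hash(None, [1, 2, 3, 4])
    #   h1 = manager.compute_hash(h0, [5, 6, 7, 8])
    #   # any other request starting with tokens 1..8 recomputes the same h0, h1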
|
||||
|
||||
class CacheAllocator(ABC):
|
||||
"""Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
|
||||
when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""
|
||||
|
||||
_index: int
|
||||
_block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
|
||||
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
|
||||
|
||||
@abstractmethod
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocates n_blocks for a given request_id. Returns the num of blocks allocated if successful and None
|
||||
otherwise."""
|
||||
|
||||
def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
|
||||
def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
|
||||
"""Frees all blocks associated with a request_id."""
|
||||
if request_id in self._block_table:
|
||||
blocks_to_free = self._block_table.pop(request_id)
|
||||
free_blocks.extend(blocks_to_free)
|
||||
if request_id in self.block_table:
|
||||
blocks_to_free = self.block_table.pop(request_id)
|
||||
block_manager.free_blocks(blocks_to_free)
|
||||
else:
|
||||
logger.warning(
|
||||
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
|
||||
@ -54,7 +216,6 @@ class CacheAllocator(ABC):
|
||||
def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
|
||||
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
|
||||
|
||||
|
||||
class FullAttentionCacheAllocator(CacheAllocator):
|
||||
"""Cache manager for a group of full attention layers."""
|
||||
|
||||
@ -66,23 +227,29 @@ class FullAttentionCacheAllocator(CacheAllocator):
|
||||
"""
|
||||
self._index = index
|
||||
self.block_size = block_size
|
||||
self._block_table = {}
|
||||
self.block_table = {}
|
||||
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
|
||||
otherwise. For group of full attention layers, we always allocate the number of requested blocks."""
|
||||
if len(free_blocks) < n_blocks:
|
||||
# Make sure the request_id is in the block table and get the first block id
|
||||
if request_id not in self.block_table:
|
||||
self.block_table[request_id] = [] # TODO: check the impact of making this a deque
|
||||
last_block_id = None
|
||||
else:
|
||||
last_block_id = self.block_table[request_id][-1]
|
||||
# Actual allocation, return early if failed
|
||||
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
|
||||
if allocated_blocks is None:
|
||||
return None
|
||||
if request_id not in self._block_table:
|
||||
self._block_table[request_id] = []
|
||||
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
|
||||
self.block_table[request_id].extend(allocated_blocks)
|
||||
return n_blocks
|
||||
|
||||
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
"""Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
|
||||
first write the new cache to the cache tensor and then read the entire cache from the beginning to the end."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Compute the physical indices
|
||||
@ -97,7 +264,7 @@ class FullAttentionCacheAllocator(CacheAllocator):
|
||||
def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
"""Returns the physical indices for writing to the cache. For a group of full attention layers, we write the new
|
||||
cache as a continuation of the existing cache for the same request."""
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Compute the physical indices
|
||||
@ -129,25 +296,26 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
self.block_size = block_size
|
||||
self.sliding_window = sliding_window
|
||||
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
|
||||
self._block_table = {}
|
||||
self.block_table = {}
|
||||
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
|
||||
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
|
||||
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
|
||||
otherwise. For group of sliding window attention layers, we only allocate up to the point where we can fit an
|
||||
entire sliding window in the cache tensor."""
|
||||
if request_id not in self._block_table:
|
||||
self._block_table[request_id] = []
|
||||
if request_id not in self.block_table:
|
||||
self.block_table[request_id] = []
|
||||
# Early return if we are already at the max number of blocks per request
|
||||
already_allocated = len(self._block_table[request_id])
|
||||
already_allocated = len(self.block_table[request_id])
|
||||
if already_allocated == self._max_blocks_per_request:
|
||||
return 0
|
||||
# Compute actual number of blocks to allocate
|
||||
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
|
||||
actual_n_blocks = after_allocation - already_allocated
|
||||
# Classic allocation
|
||||
if len(free_blocks) < actual_n_blocks:
|
||||
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
|
||||
if allocated_blocks is None:
|
||||
return None
|
||||
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
|
||||
self.block_table[request_id].extend(allocated_blocks)
|
||||
return actual_n_blocks
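    # Illustrative sketch (not part of the module): with sliding_window = 1024 and
    # block_size = 32, at most ceil(1024 / 32) = 32 blocks are ever held per request.
    # If 30 blocks are already allocated and 8 more are requested, only 2 are actually
    # allocated and the method returns 2; once at 32 blocks, it returns 0.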
|
||||
|
||||
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
|
||||
@ -157,7 +325,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
sliding_window - 1 cache page and then manually add the new key / values states after. Hence the -1 indices
|
||||
which indicate where to store the new key or values indices."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Apply sliding window
|
||||
@ -178,7 +346,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
|
||||
sliding window attention layers, we write the new cache in rolling-buffer kind of way: if we reach the end of
|
||||
the allocated physical cache, we start writing from the beginning of the physical cache again."""
|
||||
# Retrieve the block table for the request and raise an error if it doesn't exist
|
||||
block_table = self._block_table.get(request_id)
|
||||
block_table = self.block_table.get(request_id)
|
||||
if block_table is None:
|
||||
raise ValueError(f"No block table found for request {request_id}")
|
||||
# Apply sliding window
|
||||
|
||||
@ -21,7 +21,7 @@ from functools import partial
|
||||
from itertools import count
|
||||
from math import ceil
|
||||
from time import perf_counter
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -446,10 +446,7 @@ class ContinuousBatchProcessor:
|
||||
cumulative_seqlens_q = [0]
|
||||
logits_indices = []
|
||||
|
||||
if isinstance(self.cumulative_seqlens_k, dict):
|
||||
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
|
||||
else:
|
||||
cumulative_seqlens_k = [0]
|
||||
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
|
||||
|
||||
read_index = [[] for _ in range(self.cache.num_groups)]
|
||||
write_index = [[] for _ in range(self.cache.num_groups)]
|
||||
@ -498,10 +495,7 @@ class ContinuousBatchProcessor:
|
||||
self.metrics.record_kv_cache_memory_metrics(self.cache)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
if isinstance(self.cumulative_seqlens_k, dict):
|
||||
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
|
||||
else:
|
||||
ck = cumulative_seqlens_k[-1]
|
||||
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
|
||||
logger.debug(
|
||||
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
|
||||
f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
|
||||
@ -517,7 +511,7 @@ class ContinuousBatchProcessor:
|
||||
read_index: list[list[int]],
|
||||
write_index: list[list[int]],
|
||||
cumulative_seqlens_q: list[int],
|
||||
cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
|
||||
cumulative_seqlens_k: dict[str, list[int]],
|
||||
logits_indices: list[int],
|
||||
) -> None:
|
||||
"""Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
|
||||
@ -561,9 +555,7 @@ class ContinuousBatchProcessor:
|
||||
@traced
|
||||
def _maybe_send_output(self, state: RequestState) -> None:
|
||||
"""Send output to the queue based on streaming mode and request state."""
|
||||
if state.streaming:
|
||||
self.output_queue.put(state.to_generation_output())
|
||||
elif state.status == RequestStatus.FINISHED:
|
||||
if state.streaming or state.status == RequestStatus.FINISHED:
|
||||
self.output_queue.put(state.to_generation_output())
|
||||
|
||||
@traced
|
||||
@ -571,17 +563,30 @@ class ContinuousBatchProcessor:
|
||||
"""Update request states based on generated tokens."""
|
||||
out_tokens = self._sync()
|
||||
for i, state in enumerate(self.requests_in_batch):
|
||||
|
||||
# If the request has no remaining prompt ids, it means prefill has already or just finished
|
||||
if len(state.remaining_prompt_ids) == 0:
|
||||
self.metrics.record_ttft_metric(state.created_time, state.request_id)
|
||||
state.status = RequestStatus.DECODING
|
||||
token = out_tokens[self.logits_indices[i]]
|
||||
state.prompt_ids = [token]
|
||||
if state.update_with_token(token):
|
||||
# Update the request and stop if it is complete
|
||||
is_finished = state.update_and_check_completion(token)
|
||||
# We mark the completed blocks as such
|
||||
self.cache.mark_blocks_as_completed(state)
|
||||
if is_finished:
|
||||
self.metrics.record_request_completion(state.created_time, state.request_id)
|
||||
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
|
||||
self._maybe_send_output(state)
|
||||
# Otherwise, the request is still prefilling, but the prefill has been split
|
||||
elif state.status == RequestStatus.PREFILLING_SPLIT:
|
||||
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
||||
            # DEBUG: there is a dangling case here, but it is unclear whether it can ever happen. Raise an error to catch it.
|
||||
else:
|
||||
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")
|
||||
|
||||
|
||||
# We error out if the cache is full
|
||||
if self.cache.get_num_free_blocks() == 0:
|
||||
raise ValueError("No more free blocks")
|
||||
|
||||
@ -799,7 +804,6 @@ class ContinuousBatchingManager:
|
||||
logger.warning("Manager thread is already running.")
|
||||
return
|
||||
|
||||
self._result_queue = queue.Queue()
|
||||
self._generation_thread = threading.Thread(target=self._run_generation_loop)
|
||||
self._generation_thread.start()
|
||||
|
||||
@ -919,6 +923,7 @@ class ContinuousBatchingManager:
|
||||
if result is not None:
|
||||
yield result
|
||||
|
||||
# FIXME: stop iteration when request status is finished?
|
||||
def request_id_iter(self, request_id: str) -> Generator[GenerationOutput]:
|
||||
"""Iterate over results matching a specific request id as they become available."""
|
||||
request_cancelled = False
|
||||
@ -930,20 +935,6 @@ class ContinuousBatchingManager:
|
||||
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)
|
||||
|
||||
@traced
|
||||
def warmup(self, batch_processor: ContinuousBatchProcessor) -> None:
|
||||
stream = torch.cuda.Stream(device=self.model.device)
|
||||
stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(stream):
|
||||
# Warmup the model with a dummy forward pass
|
||||
self._generation_step(batch_processor)
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph, stream=stream):
|
||||
self._generation_step(batch_processor)
|
||||
|
||||
@traced
|
||||
# @torch.compile
|
||||
def _generation_step(self) -> None:
|
||||
"""Perform a single generation step. This is cuda graphed"""
|
||||
self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)
|
||||
|
||||
@ -105,10 +105,10 @@ class RequestState:
|
||||
error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
|
||||
"""
|
||||
|
||||
# Required fields
|
||||
# Required fields # TODO: come up with better names / not sure prompt_ids and such are not redundant
|
||||
request_id: str
|
||||
full_prompt_ids: Optional[list[int]] = None # Full initial prompt
|
||||
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated)
|
||||
prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed
|
||||
remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
|
||||
static_outputs: list[int] = field(default_factory=list) # Generated tokens
|
||||
allocated_blocks: int = 0 # Number of blocks allocated to the request
|
||||
@ -153,7 +153,7 @@ class RequestState:
|
||||
|
||||
# TODO: this logic seems one token off, check it out
|
||||
@traced
|
||||
def update_with_token(self, token_id: int) -> bool:
|
||||
def update_and_check_completion(self, token_id: int) -> bool:
|
||||
"""Update the request with a newly generated token and check for completion.
|
||||
|
||||
Args:
|
||||
|
||||
@ -104,7 +104,7 @@ class Scheduler(ABC):
|
||||
)
|
||||
|
||||
@traced
|
||||
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
|
||||
def _allocate_blocks_if_needed(self, state: RequestState) -> bool:
|
||||
"""Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
|
||||
accommodate the next tokens. It calculates how many blocks are needed based on the request's current
|
||||
cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
|
||||
@ -113,10 +113,11 @@ class Scheduler(ABC):
|
||||
# 1. we check that the occupancy is less than the requested length
|
||||
# 2. we allocate enough blocks to cover the requested length
|
||||
current_len = state.current_len()
|
||||
len_next_tokens = len(state.prompt_ids)
|
||||
occupancy = state.allocated_blocks * self.cache.block_size - current_len
|
||||
if occupancy < len_next_tokens or state.allocated_blocks == 0:
|
||||
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
|
||||
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
|
||||
allocated = self.cache.allocate_blocks(blocks_needed, state)
|
||||
if allocated is None:
|
||||
return False
|
||||
state.allocated_blocks += allocated
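        # Illustrative sketch (not part of the module, hypothetical numbers): with
        # block_size = 32, allocated_blocks = 2 and current_len() = 40, the request
        # still has 2 * 32 - 40 = 24 free slots. If the next 30 tokens must fit,
        # occupancy (24) < 30, so blocks_needed = ((30 - 24 + 1) // 32) + 1 = 1.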
|
||||
@ -125,11 +126,25 @@ class Scheduler(ABC):
|
||||
@traced(span_name="prepare_request")
|
||||
def _prepare_request_for_processing(
|
||||
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
|
||||
):
|
||||
) -> None:
|
||||
"""Prepares a request for processing in the current batch."""
|
||||
request_tokens = (
|
||||
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
|
||||
)
|
||||
# If prefix sharing is enabled, we look for a prefix match and split the request if found
|
||||
if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
|
||||
prefill_length = self.cache.search_prefix_match(state.request_id, state.prompt_ids)
|
||||
if prefill_length > 0:
|
||||
self.active_requests[state.request_id] = state
|
||||
state.remaining_prompt_ids = state.prompt_ids[prefill_length:]
|
||||
state.prompt_ids = state.prompt_ids[prefill_length:]
|
||||
request_ids_to_remove_from_waiting.add(state.request_id)
|
||||
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
||||
|
||||
# If the request has a split prefill, the tokens to process are the remaining prompt ids
|
||||
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
|
||||
request_tokens = state.remaining_prompt_ids
|
||||
# Otherwise, the tokens to process are the prompt ids, which are the full prompt or the last predicted tokens
|
||||
else:
|
||||
request_tokens = state.prompt_ids
|
||||
|
||||
if len(request_tokens) < token_budget:
|
||||
# Can process the entire prompt/remainder
|
||||
if state.status == RequestStatus.PENDING:
|
||||
@ -152,6 +167,7 @@ class Scheduler(ABC):
|
||||
state.prompt_ids = request_tokens[:token_budget]
|
||||
|
||||
|
||||
# TODO: further common-ize the two classes
|
||||
@attach_tracer()
|
||||
class FIFOScheduler(Scheduler):
|
||||
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
|
||||
@ -195,30 +211,31 @@ class FIFOScheduler(Scheduler):
|
||||
|
||||
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
|
||||
request_len = len(state.prompt_ids)
|
||||
if not self._allocate_blocks_if_needed(
|
||||
state, len(state.prompt_ids)
|
||||
): # don't schedule if we can't allocate blocks
|
||||
if len(self.cache._free_blocks) == 0:
|
||||
# If we can't allocate blocks, do not schedule the request and break if the cache is full
|
||||
if not self._allocate_blocks_if_needed(state):
|
||||
if self.cache.get_num_free_blocks() == 0:
|
||||
break
|
||||
continue
|
||||
|
||||
@traced
|
||||
def _add_to_scheduled_requests(state: RequestState):
|
||||
scheduled_requests.append(state)
|
||||
|
||||
_add_to_scheduled_requests(state)
|
||||
# Add the request to the scheduled requests
|
||||
scheduled_requests.append(state)
|
||||
|
||||
# Update the token budget
|
||||
token_budget -= request_len
|
||||
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
|
||||
if self.cache.use_prefix_sharing:
|
||||
tokens_in_current_block = state.current_len() % self.cache.block_size
|
||||
tokens_after_forward = tokens_in_current_block + request_len
|
||||
computed_blocks = tokens_after_forward // self.cache.block_size
|
||||
self.cache.blocks_to_complete[state.request_id] = computed_blocks
|
||||
|
||||
@traced
|
||||
def _remove_from_waiting_requests(state: RequestState):
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
_remove_from_waiting_requests(state)
|
||||
# Remove the request from the waiting queue and mark it as removed
|
||||
req_id = state.request_id
|
||||
was_waiting = self.waiting_requests.pop(req_id, None) is not None
|
||||
if was_waiting:
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
# Early exit of the loop if we have no token budget left
|
||||
if token_budget == 0:
|
||||
break
|
||||
|
||||
@ -249,6 +266,7 @@ class PrefillFirstScheduler(Scheduler):
|
||||
elif state.status == RequestStatus.DECODING:
|
||||
second_priority_states.append(state)
|
||||
|
||||
# Add waiting requests to second priority
|
||||
for req_id in self.waiting_requests_order:
|
||||
second_priority_states.append(self.waiting_requests[req_id])
|
||||
|
||||
@ -259,30 +277,31 @@ class PrefillFirstScheduler(Scheduler):
|
||||
for state in candidates:
|
||||
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
|
||||
request_len = len(state.prompt_ids)
|
||||
if not self._allocate_blocks_if_needed(
|
||||
state, len(state.prompt_ids)
|
||||
): # don't schedule if we can't allocate blocks
|
||||
if len(self.cache._free_blocks) == 0:
|
||||
# If we can't allocate blocks, do not schedule the request and break if the cache is full
|
||||
if not self._allocate_blocks_if_needed(state):
|
||||
if self.cache.get_num_free_blocks() == 0:
|
||||
break
|
||||
continue
|
||||
|
||||
@traced
|
||||
def _add_to_scheduled_requests(state: RequestState):
|
||||
scheduled_requests.append(state)
|
||||
|
||||
_add_to_scheduled_requests(state)
|
||||
# Add the request to the scheduled requests
|
||||
scheduled_requests.append(state)
|
||||
|
||||
# Update the token budget
|
||||
token_budget -= request_len
|
||||
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
|
||||
if self.cache.use_prefix_sharing:
|
||||
tokens_in_current_block = state.current_len() % self.cache.block_size
|
||||
tokens_after_forward = tokens_in_current_block + request_len
|
||||
computed_blocks = tokens_after_forward // self.cache.block_size
|
||||
self.cache.blocks_to_complete[state.request_id] = computed_blocks
|
||||
|
||||
@traced
|
||||
def _remove_from_waiting_requests(state: RequestState):
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
_remove_from_waiting_requests(state)
|
||||
# Remove the request from the waiting queue and mark it as removed
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
# Early exit of the loop if we have no token budget left
|
||||
if token_budget == 0:
|
||||
break
|
||||
|
||||
|
||||
@ -1635,12 +1635,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
|
||||
# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
|
||||
for key, value in model_kwargs.items():
|
||||
if (
|
||||
value is not None
|
||||
and key not in model_args
|
||||
and key not in TransformersKwargs.__optional_keys__
|
||||
and key != "debug_io"
|
||||
):
|
||||
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
|
||||
unused_model_args.append(key)
|
||||
|
||||
if unused_model_args:
|
||||
|
||||
@ -512,8 +512,10 @@ def accelerate_disk_offload(
|
||||
checkpoint_files,
|
||||
device_map,
|
||||
checkpoint_keys,
|
||||
key_renaming_mapping,
|
||||
sharded_metadata,
|
||||
dtype,
|
||||
reverse_key_renaming_mapping,
|
||||
):
|
||||
disk_only_shard_files = []
|
||||
if disk_offload_folder is not None:
|
||||
@ -532,13 +534,19 @@ def accelerate_disk_offload(
|
||||
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
|
||||
else:
|
||||
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
|
||||
# Fix the weight map keys according to the key mapping
|
||||
weight_map = {
|
||||
key_renaming_mapping[k]: v
|
||||
for k, v in sharded_metadata["weight_map"].items()
|
||||
if k in key_renaming_mapping
|
||||
}
|
||||
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
|
||||
# Find potential checkpoints containing only offloaded weights
|
||||
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
|
||||
disk_offload_index = {
|
||||
name: {
|
||||
"safetensors_file": file,
|
||||
"weight_name": name,
|
||||
"weight_name": reverse_key_renaming_mapping[name],
|
||||
"dtype": str_dtype,
|
||||
}
|
||||
for name, file in weight_map.items()
|
||||
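# Illustrative sketch (not from the diff) of one disk_offload_index entry built above, assuming a checkpoint key
# renamed from "transformer.h.0.attn.weight" to "model.layers.0.attn.weight" stored in "model-00001.safetensors":
#   "model.layers.0.attn.weight": {
#       "safetensors_file": "/path/to/model-00001.safetensors",
#       "weight_name": "transformer.h.0.attn.weight",  # original (pre-renaming) key via reverse_key_renaming_mapping
#       "dtype": str_dtype,
#   }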
|
||||
@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from inspect import signature
|
||||
|
||||
from ..utils import (
|
||||
@ -23,6 +24,7 @@ if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||
from accelerate.utils import find_tied_parameters
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -149,6 +151,52 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
|
||||
return model
|
||||
|
||||
|
||||
def get_keys_to_not_convert(model):
|
||||
r"""
|
||||
A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM
modules we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures,
we want to keep the tied weights of the model. The function will return a list of the keys of the modules to not
convert to int8.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model
|
||||
"""
|
||||
# Create a copy of the model and tie the weights, then
|
||||
# check if it contains tied weights
|
||||
tied_model = deepcopy(model) # this has 0 cost since it is done inside the `init_empty_weights` context manager
|
||||
tied_model.tie_weights()
|
||||
|
||||
tied_params = find_tied_parameters(tied_model)
|
||||
tied_keys = sum(tied_params, [])
|
||||
has_tied_params = len(tied_keys) > 0
|
||||
|
||||
# If there are no tied weights, we want to keep the lm_head (output embedding) in full precision
|
||||
if not has_tied_params:
|
||||
output_emb = model.get_output_embeddings()
|
||||
if output_emb is not None:
|
||||
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
|
||||
return list_last_module
|
||||
|
||||
# Otherwise (tied weights present, or no output embedding defined), simply keep the last module in full precision
|
||||
list_modules = list(model.named_parameters())
|
||||
list_last_module = [list_modules[-1][0]]
|
||||
# add last module together with tied weights
|
||||
intersection = set(list_last_module) - set(tied_keys)
|
||||
list_untouched = list(set(tied_keys)) + list(intersection)
|
||||
|
||||
# remove ".weight" from the keys
|
||||
names_to_remove = [".weight", ".bias"]
|
||||
filtered_module_names = []
|
||||
for name in list_untouched:
|
||||
for name_to_remove in names_to_remove:
|
||||
if name_to_remove in name:
|
||||
name = name.replace(name_to_remove, "")
|
||||
filtered_module_names.append(name)
|
||||
|
||||
return filtered_module_names
|
||||
|
||||
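# Hedged usage sketch (not part of the diff): what get_keys_to_not_convert typically returns. The model name and
# the output shown are illustrative assumptions, not taken from the PR.
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   get_keys_to_not_convert(model)  # e.g. ["lm_head"] -> kept in full precision when quantizing to int8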
|
||||
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
|
||||
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
|
||||
"""
|
||||
|
||||
@ -13,11 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from ..core_model_loading import ConversionOps
|
||||
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
|
||||
|
||||
|
||||
@ -33,18 +30,6 @@ if is_accelerate_available():
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
try:
|
||||
_FP8_DTYPE = torch.float8_e4m3fn
|
||||
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
|
||||
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
|
||||
_FP8_IS_INT = False
|
||||
except AttributeError:
|
||||
_FP8_DTYPE = torch.int8
|
||||
_FP8_MIN, _FP8_MAX = -127, 127
|
||||
_FP8_IS_INT = True
|
||||
logger.warning_once(
|
||||
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
|
||||
)
|
||||
|
||||
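# Minimal sketch (not from the diff) of what the fallback above guards against: on older torch builds
# torch.float8_e4m3fn may not exist, so the code degrades to int8 emulation.
#   import torch
#   if hasattr(torch, "float8_e4m3fn"):
#       print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0 for e4m3fn
#   else:
#       print("falling back to the int8 range [-127, 127]")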
|
||||
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
|
||||
@ -347,12 +332,6 @@ class FP8Linear(nn.Linear):
|
||||
if self.weight.element_size() > 1:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
else:
|
||||
if isinstance(self.weight, torch.distributed.tensor.DTensor):
|
||||
weight = self.weight._local_tensor.contiguous()
|
||||
scale_inv = self.weight_scale_inv._local_tensor.contiguous()
|
||||
else:
|
||||
weight = self.weight.contiguous()
|
||||
scale_inv = self.weight_scale_inv.contiguous()
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
@ -360,9 +339,9 @@ class FP8Linear(nn.Linear):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
weight,
|
||||
self.weight,
|
||||
scale,
|
||||
scale_inv,
|
||||
self.weight_scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
@ -371,124 +350,9 @@ class FP8Linear(nn.Linear):
|
||||
torch_accelerator_module.synchronize()
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
output = torch.nan_to_num(output, nan=0.0)
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
def _ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
class FP8Expert(nn.Module):
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
def __init__(self, config, block_size, device):
|
||||
super().__init__()
|
||||
|
||||
from ..activations import ACT2FN
|
||||
|
||||
self.block_size = block_size
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.intermediate_dim = config.intermediate_size
|
||||
|
||||
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
|
||||
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
|
||||
|
||||
self.gate_up_proj = nn.Parameter(
|
||||
torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
|
||||
)
|
||||
self.down_proj = nn.Parameter(
|
||||
torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
|
||||
)
|
||||
|
||||
# Create inverse scale tiles only when using 1-byte types (fp8)
|
||||
if self.gate_up_proj.element_size() == 1:
|
||||
bo, bi = self.block_size
|
||||
|
||||
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
|
||||
gu_scale_o = _ceil_div(Wg_out, bo)
|
||||
gu_scale_i = _ceil_div(Wg_in, bi)
|
||||
self.gate_up_proj_scales_inv = nn.Parameter(
|
||||
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
|
||||
)
|
||||
|
||||
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
|
||||
dp_scale_o = _ceil_div(Wd_out, bo)
|
||||
dp_scale_i = _ceil_div(Wd_in, bi)
|
||||
self.down_proj_scales_inv = nn.Parameter(
|
||||
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
|
||||
)
|
||||
else:
|
||||
# Match FP8Linear behavior when not using 1-byte weights
|
||||
self.register_parameter("gate_up_proj_scale_inv", None)
|
||||
self.register_parameter("down_proj_scale_inv", None)
|
||||
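# Worked example (illustrative, not from the diff) of the scale-tile shapes created above, assuming
# hidden_size=4096, intermediate_size=11008, num_local_experts=8 and block_size=(128, 128):
#   gate_up_proj is (8, 2*11008, 4096) -> gate_up_proj_scales_inv is (8, ceil(22016/128), ceil(4096/128)) = (8, 172, 32)
#   down_proj    is (8, 4096, 11008)   -> down_proj_scales_inv    is (8, ceil(4096/128), ceil(11008/128)) = (8, 32, 86)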
|
||||
# (Optional) bias per projection: many MoEs omit bias; keep None to match the FP8Linear default
|
||||
self.register_parameter("gate_up_bias", None)
|
||||
self.register_parameter("down_bias", None)
|
||||
|
||||
# Activation used in the MLP (taken from the config via ACT2FN)
# Keep a handle here; actual usage happens in the forward of the MoE block
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
top_k_index: torch.Tensor,
|
||||
top_k_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
num_experts = top_k_weights.shape[1]
|
||||
with torch.no_grad():
|
||||
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
|
||||
expert_mask = expert_mask.permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
for expert_idx in expert_hit:
|
||||
expert_idx = expert_idx[0]
|
||||
if expert_idx == num_experts:
|
||||
continue
|
||||
_, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current_state = hidden_states.index_select(0, token_idx)
|
||||
gate, up = self.linear(
|
||||
current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx]
|
||||
).chunk(2, dim=-1)
|
||||
current_hidden_states = self.act_fn(gate) * up
|
||||
current_hidden_states = self.linear(
|
||||
current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx]
|
||||
)
|
||||
|
||||
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
|
||||
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
|
||||
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
|
||||
if weight.element_size() > 1:
|
||||
return F.linear(input, weight, None)
|
||||
else:
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
with torch_accelerator_module.device(input.device):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
weight,
|
||||
scale,
|
||||
weight_scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# preceding operations are ready before proceeding
|
||||
torch_accelerator_module.synchronize()
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
# TODO: we do need this, but not recursively
|
||||
def _replace_with_fp8_linear(
|
||||
model,
|
||||
tp_plan=None,
|
||||
@ -497,48 +361,40 @@ def _replace_with_fp8_linear(
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
iterator = list(model.named_parameters()).copy()
|
||||
for name, empty_tensor in iterator:
|
||||
current_key_name = name
|
||||
name = name.rsplit(".", 1)[0] if "." in name else name
|
||||
module = model.get_submodule(name)
|
||||
"""Replace Linear layers with FP8Linear."""
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
|
||||
current_key_name_str = re.sub(r"\d+", "*", current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
if (
|
||||
"gate_up_proj" in current_key_name
|
||||
or "down_proj" in current_key_name
|
||||
and "experts" in current_key_name
|
||||
): # Experts!
|
||||
in_features = empty_tensor.size(-2)
|
||||
out_features = empty_tensor.size(-1)
|
||||
model.set_submodule(
|
||||
name,
|
||||
FP8Expert(
|
||||
config=model.config,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
device=empty_tensor.device,
|
||||
),
|
||||
)
|
||||
for name, module in model.named_children():
|
||||
current_key_name.append(name)
|
||||
|
||||
elif isinstance(module, nn.Linear):
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
model.set_submodule(
|
||||
name,
|
||||
FP8Linear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
),
|
||||
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
model._modules[name] = FP8Linear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
)
|
||||
has_been_replaced = True
|
||||
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||
has_been_replaced = True
|
||||
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_fp8_linear(
|
||||
module,
|
||||
tp_plan,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
|
||||
current_key_name.pop(-1)
|
||||
|
||||
return model, has_been_replaced
|
||||
|
||||
@ -549,7 +405,7 @@ def replace_with_fp8_linear(
|
||||
quantization_config=None,
|
||||
):
|
||||
"""Helper function to replace model layers with FP8 versions."""
|
||||
modules_to_not_convert += ["lm_head"]
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
|
||||
if quantization_config.modules_to_not_convert is not None:
|
||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||
@ -568,133 +424,3 @@ def replace_with_fp8_linear(
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class QuantizationOp(ConversionOps):
|
||||
"""Base class for quantization operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Fp8Quantize(QuantizationOp):
|
||||
"""
|
||||
A quantization operation that creates two tensors, a weight and a scale, out of a single weight.
|
||||
"""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, block_size: Optional[tuple[int, int]] = None):
|
||||
self.block_size = block_size
|
||||
self.reverse_op = Fp8Dequantize
|
||||
|
||||
def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
# Unpack single key/value (value may be wrapped in a list)
|
||||
target_keys, value = tuple(input_dict.items())[0]
|
||||
value = value[0] if isinstance(value, list) else value
|
||||
|
||||
# Resolve block size (support dict-like or attr-like quant_config)
|
||||
block_size = None
|
||||
if quant_config is not None:
|
||||
if isinstance(quant_config, dict):
|
||||
block_size = quant_config.get("weight_block_size")
|
||||
else:
|
||||
block_size = getattr(quant_config, "weight_block_size", None)
|
||||
if block_size is None:
|
||||
block_size = (value.shape[-2], value.shape[-1])
|
||||
|
||||
block_m, block_n = block_size
|
||||
rows, cols = value.shape[-2], value.shape[-1]
|
||||
|
||||
# Enforce exact tiling: matrix dimensions must be an integer multiple of the block sizes
|
||||
if rows % block_m != 0 or cols % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
|
||||
)
|
||||
|
||||
# Leading dims can be empty (2D) or include num_experts/... (3D+)
|
||||
leading_shape = value.shape[:-2]
|
||||
rows_tiles = rows // block_m
|
||||
cols_tiles = cols // block_n
|
||||
|
||||
original_shape = value.shape
|
||||
value_fp32 = value.to(torch.float32)
|
||||
|
||||
# Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
|
||||
reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
|
||||
|
||||
# Per-tile max-abs over the block dims
|
||||
# dims: block_m is at -3, block_n is at -1 after the reshape
|
||||
max_abs = reshaped.abs().amax(dim=(-3, -1))
|
||||
safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
|
||||
|
||||
# Tile scale (we store the inverse scale, matching FP8Linear's weight_scale_inv)
|
||||
scales = _FP8_MAX / safe_max_abs
|
||||
scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
|
||||
|
||||
# Broadcast scales back over the block dims and quantize
|
||||
# max_abs/scales shape: (..., rows_tiles, cols_tiles)
|
||||
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
|
||||
scaled = reshaped * scales_broadcast
|
||||
|
||||
if _FP8_IS_INT:
|
||||
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
||||
else:
|
||||
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
quantized = quantized.reshape(original_shape)
|
||||
|
||||
inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
|
||||
if target_keys.endswith("weight"):
|
||||
scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
|
||||
else:
|
||||
scale_key = target_keys + "_scales_inv"
|
||||
|
||||
# Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
|
||||
return {
|
||||
target_keys: quantized,
|
||||
scale_key: inv_scales,
|
||||
}
|
||||
|
||||
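# Shape sketch (illustrative, not from the diff) for the Fp8Quantize.convert op above, assuming a (4096, 4096)
# weight named "model.layers.0.mlp.down_proj.weight" and weight_block_size=(128, 128):
#   returned dict:
#     "model.layers.0.mlp.down_proj.weight"           -> float8_e4m3fn tensor of shape (4096, 4096)
#     "model.layers.0.mlp.down_proj.weight_scale_inv" -> float32 tensor of shape (32, 32), one inverse scale per tile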
|
||||
class Fp8Dequantize(QuantizationOp):
|
||||
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
|
||||
|
||||
def __init__(self, block_size: Optional[tuple[int, int]] = None):
|
||||
self.block_size = block_size
|
||||
self.reverse_op = Fp8Quantize
|
||||
|
||||
def convert(
|
||||
self,
|
||||
value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]],
|
||||
*,
|
||||
context: dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
if isinstance(value, dict):
|
||||
tensors = list(value.values())
|
||||
else:
|
||||
tensors = list(value) if isinstance(value, Sequence) else [value]
|
||||
if len(tensors) != 2:
|
||||
raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
|
||||
quantized, scales = tensors
|
||||
if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
|
||||
raise TypeError("Fp8Dequantize expects tensors as inputs.")
|
||||
|
||||
quantized_fp32 = quantized.to(torch.float32)
|
||||
rows, cols = quantized_fp32.shape[-2:]
|
||||
block_size = self.block_size
|
||||
if block_size is None:
|
||||
quant_config = context.get("quantization_config")
|
||||
block_size = getattr(quant_config, "weight_block_size", None)
|
||||
if block_size is None:
|
||||
block_size = (rows, cols)
|
||||
block_m, block_n = block_size
|
||||
if rows % block_m != 0 or cols % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
|
||||
)
|
||||
|
||||
reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
|
||||
expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
|
||||
expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
|
||||
dequantized = reshaped * expanded_scales
|
||||
return dequantized.reshape(quantized_fp32.shape)
|
||||
|
||||
@ -51,10 +51,13 @@ try:
|
||||
)
|
||||
},
|
||||
"RMSNorm": {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/liger_kernels",
|
||||
layer_name="LigerRMSNorm",
|
||||
),
|
||||
"cuda": {
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="kernels-community/liger_kernels",
|
||||
layer_name="LigerRMSNorm",
|
||||
# revision="pure-layer-test",
|
||||
),
|
||||
},
|
||||
"rocm": {
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="kernels-community/liger_kernels",
|
||||
|
||||
@ -236,7 +236,7 @@ class PeftAdapterMixin:
|
||||
**adapter_kwargs,
|
||||
)
|
||||
peft_config.inference_mode = not is_trainable
|
||||
# TODO: WE NEED TO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
|
||||
|
||||
# Create and add fresh new adapters into the model.
|
||||
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ import operator
|
||||
import os
|
||||
import re
|
||||
from functools import partial, reduce
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -307,7 +306,7 @@ def repack_weights(
|
||||
return final_ordered_tensor
|
||||
|
||||
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None):
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
"""
|
||||
Generalized tensor sharding across a multi-dimensional device mesh.
|
||||
Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
|
||||
@ -359,57 +358,32 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Opt
|
||||
rank (int): Global rank of the current process/device.
|
||||
dim (int): Dimension along which to shard the tensor.
|
||||
"""
|
||||
param_dim = empty_param.ndim
|
||||
param_dim = empty_param.dim()
|
||||
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
# Flatten the mesh to get the total number of devices
|
||||
mesh_shape = device_mesh.shape
|
||||
world_size = reduce(operator.mul, mesh_shape)
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
|
||||
dim = 0
|
||||
elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
|
||||
dim = 0
|
||||
|
||||
shard_size = math.ceil(empty_param.size(dim) / world_size)
|
||||
start = rank * shard_size
|
||||
end = min(start + shard_size, empty_param.size(dim))
|
||||
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
if rank >= world_size:
|
||||
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
|
||||
|
||||
# We may have the full tensor here, not just one part of it.
# In that case, we assume the weight was saved unsharded, and because we use TP, a colwise layer should not
# go through this path: such a layer should be packed_colwise to signal that it needs to read from a packed
# tensor. That style also takes care of the ModuleList case.
# Here we take care of potential chunking / layer splitting / layer chunking.
# The only "hard" case is when we collect q, k, v and merge them into qkv: even then we still shard along
# dim=0, so nothing changes. The only remaining case is when the empty param has 3 dims and the shard dim
# is 0: we then place the whole tensor on a specific device (selected via the input tensor_idx).
|
||||
dimensions = param.get_shape()
|
||||
shard_size = math.ceil(empty_param.shape[dim] / world_size)
|
||||
start = rank * shard_size
|
||||
|
||||
if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
|
||||
# Special case: we don't "shard", we just send this entire tensor to the correct rank.
|
||||
if start <= tensor_idx < end:
|
||||
# this tensor does need to be materialized on this device:
|
||||
return param[:]
|
||||
else:
|
||||
return torch.empty([], dtype=torch.int64, device=rank)
|
||||
|
||||
slice_indices = [slice(None)] * len(param.get_shape())
|
||||
|
||||
if start < param.get_shape()[dim]:
|
||||
# Construct slicing index dynamically
|
||||
end = min(start + shard_size, empty_param.shape[dim])
|
||||
slice_indices = [slice(None)] * param_dim
|
||||
if start < empty_param.shape[dim]:
|
||||
slice_indices[dim] = slice(start, end)
|
||||
param = param[tuple(slice_indices)]
|
||||
if isinstance(param, list): # TODO handle the modulelist case!
|
||||
param = [p[:] for p in param]
|
||||
return param
|
||||
|
||||
return param[tuple(slice_indices)]
|
||||
dimensions = list(param.shape)
|
||||
dimensions[dim] = 0
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory....
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64)
|
||||
|
||||
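# Worked example (illustrative, not from the diff) of the shard arithmetic used above, for the simple 2D case of a
# (1024, 4096) weight sharded along dim=0 over world_size=4 with rank=2:
#   shard_size = ceil(1024 / 4) = 256
#   start, end = 2 * 256, min(3 * 256, 1024) = 512, 768
#   the rank keeps param[512:768, :], i.e. a (256, 4096) slice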
|
||||
def distribute_module(
|
||||
@ -436,19 +410,6 @@ class TensorParallelLayer:
|
||||
"""
|
||||
|
||||
use_dtensor = True
|
||||
device_mesh = None
|
||||
rank = None
|
||||
|
||||
# Used to compare the shape of the original tensor
|
||||
empty_param = None
|
||||
|
||||
# Used to init the corresponding DTensor
|
||||
shard = None
|
||||
|
||||
def __init__(self, device_mesh=None, rank=None, empty_param=None):
|
||||
self.rank = rank
|
||||
self.device_mesh = device_mesh
|
||||
self.empty_param = empty_param
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
|
||||
@ -478,12 +439,12 @@ class GatherParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__()
|
||||
self.input_layouts = (input_layouts or Replicate(),)
|
||||
self.output_layouts = output_layouts
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -504,21 +465,6 @@ class GatherParallel(TensorParallelLayer):
|
||||
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
|
||||
return outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
shard = [Replicate()]
|
||||
parameter = param[...]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
||||
distribute_module(
|
||||
module,
|
||||
@ -547,23 +493,6 @@ class IsolatedParallel(TensorParallelLayer):
|
||||
# TODO: figure out dynamo support for instance method and switch this to instance method
|
||||
return outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
mesh = device_mesh or self.device_mesh
|
||||
parameter = param[...]
|
||||
if mesh is not None:
|
||||
parameter = parameter / mesh.size()
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
@ -586,8 +515,8 @@ class ReplicateParallel(TensorParallelLayer):
|
||||
This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
|
||||
"""
|
||||
|
||||
def __init__(self, use_dtensor=True, use_local_output=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, *, use_dtensor=True, use_local_output=True):
|
||||
super().__init__()
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -608,33 +537,12 @@ class ReplicateParallel(TensorParallelLayer):
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...]
|
||||
shard = [Replicate()]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
parameter, shard = self.shard_tensor(
|
||||
param,
|
||||
param_type=param_type,
|
||||
param_casting_dtype=param_casting_dtype,
|
||||
to_contiguous=to_contiguous,
|
||||
rank=rank,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
if self.use_dtensor:
|
||||
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
||||
return parameter
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
param = param.contiguous()
|
||||
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
|
||||
return param
|
||||
|
||||
|
||||
class ColwiseParallel(TensorParallelLayer):
|
||||
@ -644,13 +552,13 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
use_dtensor=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__()
|
||||
self.input_layouts = (input_layouts or Replicate(),)
|
||||
self.output_layouts = (output_layouts or Shard(-1),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -670,24 +578,17 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
|
||||
return input_tensor
|
||||
|
||||
def shard_tensor(self, param, param_type=None, tensor_idx=None):
|
||||
device_mesh = self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = self.rank
|
||||
if param_type == "bias":
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
|
||||
shard = [Shard(-1)]
|
||||
else:
|
||||
shard = [Shard(-2)]
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
# weight would become Shard(1)
|
||||
parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh)
|
||||
if param_type == "bias":
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
|
||||
shard = [Shard(-1)]
|
||||
else:
|
||||
shard = [Shard(-2)]
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
|
||||
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
parameter = parameter.contiguous()
|
||||
@ -707,26 +608,6 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
|
||||
class PackedColwiseParallel(ColwiseParallel):
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
return get_packed_weights(param, empty_param, device_mesh, rank, -2), [Shard(-2)]
|
||||
|
||||
def create_nn_parameter(
|
||||
self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh
|
||||
):
|
||||
return nn.Parameter(param, requires_grad=param.is_floating_point())
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
@ -761,40 +642,18 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
use_dtensor=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__()
|
||||
self.input_layouts = (input_layouts or Shard(-1),)
|
||||
self.output_layouts = (output_layouts or Replicate(),)
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
if param_type == "bias":
|
||||
shard = [Replicate()]
|
||||
parameter = param[:]
|
||||
else:
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
|
||||
shard = [Shard(-1)]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
|
||||
# means Rowwise as nn.Linear is input * weight^T + bias, where
|
||||
@ -866,21 +725,6 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
|
||||
|
||||
class PackedRowwiseParallel(RowwiseParallel):
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
@ -939,8 +783,8 @@ class SequenceParallel(TensorParallelLayer):
|
||||
to ensure that they are replicated.
|
||||
"""
|
||||
|
||||
def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
|
||||
super().__init__()
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Shard(1),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
@ -949,21 +793,6 @@ class SequenceParallel(TensorParallelLayer):
|
||||
self.sequence_sharding = (Shard(sequence_dim),)
|
||||
self.use_local_output = use_local_output
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...]
|
||||
shard = [Replicate()]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
input_tensor = inputs[0]
|
||||
@ -998,34 +827,10 @@ class GroupedGemmParallel(TensorParallelLayer):
|
||||
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.use_dtensor = False
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
empty_param = self.empty_param
|
||||
ep_rank = self.rank
|
||||
device_mesh = self.device_mesh
|
||||
|
||||
global_num_experts = empty_param.shape[0]
|
||||
if global_num_experts % device_mesh.size() != 0:
|
||||
raise ValueError(
|
||||
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
|
||||
)
|
||||
local_num_experts = global_num_experts // device_mesh.size()
|
||||
parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts]
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
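# Worked example (illustrative, not from the diff) of the expert split above, assuming 64 global experts on an
# expert-parallel mesh of size 8 with ep_rank=3:
#   local_num_experts = 64 // 8 = 8
#   this rank keeps param[3 * 8 : 4 * 8], i.e. experts 24..31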
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
ep_rank = rank
|
||||
global_num_experts = empty_param.shape[0]
|
||||
@ -1046,8 +851,8 @@ class RouterParallel(TensorParallelLayer):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.use_dtensor = False
|
||||
|
||||
@staticmethod
|
||||
@ -1112,20 +917,6 @@ class RouterParallel(TensorParallelLayer):
|
||||
) # masking class for one hot
|
||||
return router_scores, router_indices
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...]
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# TODO: i'd like for this to be the default
|
||||
param = param[...].to(param_casting_dtype)
|
||||
@ -1268,9 +1059,6 @@ def shard_and_distribute_module(
|
||||
if current_shard_plan is not None:
|
||||
try:
|
||||
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
|
||||
tp_layer.empty_param = empty_param
|
||||
tp_layer.device_mesh = device_mesh
|
||||
tp_layer.rank = rank
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
|
||||
@ -23,11 +23,10 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
@ -46,17 +45,17 @@ from torch.distributions import constraints
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from .configuration_utils import PreTrainedConfig
|
||||
from .conversion_mapping import get_checkpoint_conversion_mapping
|
||||
from .core_model_loading import WeightConverter, convert_and_load_state_dict_in_model, revert_weight_conversion
|
||||
from .distributed import DistributedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .generation import CompileConfig, GenerationConfig
|
||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
|
||||
from .integrations.accelerate import (
|
||||
_get_device_map,
|
||||
accelerate_disk_offload,
|
||||
accelerate_dispatch,
|
||||
check_and_set_device_map,
|
||||
expand_device_map,
|
||||
find_tied_parameters,
|
||||
init_empty_weights,
|
||||
)
|
||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||
@ -123,7 +122,6 @@ from .utils.import_utils import (
|
||||
is_sagemaker_mp_enabled,
|
||||
is_tracing,
|
||||
)
|
||||
from .utils.loading_report import log_state_dict_report
|
||||
from .utils.quantization_config import QuantizationMethod
|
||||
|
||||
|
||||
@ -132,6 +130,7 @@ if is_accelerate_available():
|
||||
from accelerate.utils import (
|
||||
extract_model_from_parallel,
|
||||
offload_weight,
|
||||
save_offload_index,
|
||||
)
|
||||
from accelerate.utils.modeling import get_state_dict_from_offload
|
||||
|
||||
@ -697,6 +696,82 @@ def _load_state_dict_into_meta_model(
|
||||
return disk_offload_index
|
||||
|
||||
|
||||
def load_shard_file(args):
|
||||
(
|
||||
shard_file,
|
||||
state_dict,
|
||||
disk_only_shard_files,
|
||||
is_quantized,
|
||||
device_map,
|
||||
hf_quantizer,
|
||||
key_renaming_mapping,
|
||||
weights_only,
|
||||
model,
|
||||
reverse_key_renaming_mapping,
|
||||
disk_offload_folder,
|
||||
disk_offload_index,
|
||||
device_mesh,
|
||||
) = args
|
||||
|
||||
# Skip the load for shards that only contain disk-offloaded weights
|
||||
if shard_file in disk_only_shard_files:
|
||||
return [], disk_offload_index
|
||||
|
||||
map_location = "cpu"
|
||||
if shard_file.endswith(".safetensors") and not (is_deepspeed_zero3_enabled() and not is_quantized):
|
||||
map_location = "meta"
|
||||
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
state_dict = load_state_dict(
|
||||
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
||||
)
|
||||
|
||||
# Fix the key names
|
||||
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
|
||||
|
||||
error_msgs = []
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
|
||||
# Skip it with fsdp on ranks other than 0
|
||||
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
|
||||
disk_offload_index = _load_state_dict_into_meta_model(
|
||||
model,
|
||||
state_dict,
|
||||
shard_file,
|
||||
reverse_key_renaming_mapping,
|
||||
device_map=device_map,
|
||||
disk_offload_folder=disk_offload_folder,
|
||||
disk_offload_index=disk_offload_index,
|
||||
hf_quantizer=hf_quantizer,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
|
||||
return error_msgs, disk_offload_index
|
||||
|
||||
|
||||
def load_shard_files_with_threadpool(args_list):
|
||||
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
|
||||
|
||||
# Do not spawn any more workers than needed
|
||||
num_workers = min(len(args_list), num_workers)
|
||||
|
||||
logger.info(f"Loading model weights in parallel with {num_workers} workers...")
|
||||
|
||||
error_msgs = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
|
||||
futures = [executor.submit(load_shard_file, arg) for arg in args_list]
|
||||
for future in as_completed(futures):
|
||||
_error_msgs, disk_offload_index = future.result()
|
||||
|
||||
error_msgs += _error_msgs
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
return error_msgs, disk_offload_index
|
||||
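# Hedged call sketch (not part of the diff): each element of args_list is one tuple with exactly the fields
# unpacked at the top of load_shard_file, one tuple per checkpoint shard. The surrounding variable names are
# illustrative assumptions.
#   args_list = [
#       (shard_file, None, disk_only_shard_files, is_quantized, device_map, hf_quantizer, key_renaming_mapping,
#        weights_only, model, reverse_key_renaming_mapping, disk_offload_folder, disk_offload_index, device_mesh)
#       for shard_file in checkpoint_files
#   ]
#   error_msgs, disk_offload_index = load_shard_files_with_threadpool(args_list)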
|
||||
|
||||
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
if variant is not None:
|
||||
@ -1099,6 +1174,104 @@ def _get_dtype(
|
||||
return config, dtype, dtype_orig
|
||||
|
||||
|
||||
def _find_missing_and_unexpected_keys(
|
||||
model: "PreTrainedModel",
|
||||
original_checkpoint_keys: list[str],
|
||||
checkpoint_keys: list[str],
|
||||
loading_base_model_from_task_state_dict: bool,
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
|
||||
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
|
||||
"""
|
||||
prefix = model.base_model_prefix
|
||||
|
||||
# Compute expected keys, i.e. keys that the full model expects
|
||||
expected_keys = list(model.state_dict().keys())
|
||||
if hf_quantizer is not None:
|
||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)
|
||||
|
||||
# Adjust prefix of the keys to make them match loaded keys before removing them
|
||||
missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))
|
||||
unexpected_keys = set(checkpoint_keys) - set(expected_keys)
|
||||
# If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys
|
||||
if loading_base_model_from_task_state_dict:
|
||||
task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")]
|
||||
unexpected_keys.update(task_specific_keys)
|
||||
|
||||
# Remove non-persistent buffers from unexpected keys: they are not in the expected keys (model state dict), but
# may be in the loaded keys. Note that removing all buffers does the job, as the persistent ones were part of the expected keys anyway
|
||||
model_buffers = {n for n, _ in model.named_buffers()}
|
||||
unexpected_keys = sorted(unexpected_keys - model_buffers)
|
||||
|
||||
tied_params = find_tied_parameters(model)
|
||||
for group in tied_params:
|
||||
missing_in_group = [k for k in missing_keys if k in group]
|
||||
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
|
||||
missing_keys = [k for k in missing_keys if k not in missing_in_group]
|
||||
|
||||
if hf_quantizer is not None:
|
||||
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
|
||||
unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
|
||||
|
||||
return missing_keys, unexpected_keys
|
||||
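# Tiny worked example (illustrative, not from the diff) of the set arithmetic above, ignoring the quantizer and
# tied-weight adjustments:
#   expected_keys   = {"embed.weight", "lm_head.weight", "norm.weight"}
#   checkpoint_keys = {"embed.weight", "norm.weight", "rotary.inv_freq"}
#   missing_keys    = expected - checkpoint = {"lm_head.weight"}
#   unexpected_keys = checkpoint - expected = {"rotary.inv_freq"}  (then dropped if it is a registered buffer)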
|
||||
|
||||
def _find_mismatched_keys(
|
||||
model: "PreTrainedModel",
|
||||
state_dict: Optional[dict],
|
||||
checkpoint_files: Optional[list[str]],
|
||||
ignore_mismatched_sizes: bool,
|
||||
keys_to_rename_mapping: dict[str, str],
|
||||
is_quantized: bool,
|
||||
weights_only: bool,
|
||||
) -> tuple[list[str], list[tuple[int, int]]]:
|
||||
"""
|
||||
Find potential shape mismatches between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes`
is True. Otherwise, return immediately, and any shape mismatch that may exist will be raised later on. This avoids checking
every parameter in advance, as shape mismatches are extremely rare in practice. If we do want to ignore them, however, we
need to check in advance, as we need to know which parameters must be moved back from meta to cpu and initialized
correctly. Indeed, as our model initialization takes place at the module level, and not the weight level, in the
case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform
this check on each state dict at loading time (after the first loaded checkpoint, there is no way to initialize only the
mismatched weights, if any, without also overwriting the previously loaded weights, because the whole module would be
initialized, not only the weights that are mismatched).
|
||||
"""
|
||||
|
||||
# An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function
|
||||
# if there are no mismatch (which is almost always the case)
|
||||
if not ignore_mismatched_sizes:
|
||||
return [], []
|
||||
|
||||
if state_dict is not None:
|
||||
checkpoint_files = [""]
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
mismatched_keys = []
|
||||
mismatched_shapes = []
|
||||
for shard_file in checkpoint_files:
|
||||
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||
if shard_file != "":
|
||||
state_dict = load_state_dict(
|
||||
shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
|
||||
)
|
||||
|
||||
# Fix the key names
|
||||
new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
|
||||
|
||||
for key, tensor in new_state_dict.items():
|
||||
if key in model_state_dict and tensor.shape != model_state_dict[key].shape:
|
||||
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching on module type or parameter type, this seems like a practical way to detect valid 4-bit weights.
|
||||
if not (
|
||||
is_quantized and tensor.shape[-1] == 1 and tensor.numel() * 2 == model_state_dict[key].numel()
|
||||
):
|
||||
mismatched_keys.append(key)
|
||||
mismatched_shapes.append((tensor.shape, model_state_dict[key].shape))
|
||||
|
||||
return mismatched_keys, mismatched_shapes
|
||||
|
||||
|
||||
class PipelineParallel(Enum):
|
||||
inputs = 0
|
||||
outputs = 1
|
||||
@ -1504,8 +1677,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
|
||||
_keep_in_fp32_modules_strict = None
|
||||
|
||||
dtype_plan: Optional[dict[str, torch.dtype]] = None
|
||||
|
||||
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
|
||||
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
|
||||
_keys_to_ignore_on_load_missing = None
|
||||
@ -1670,18 +1841,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
self.name_or_path = config.name_or_path
|
||||
self.warnings_issued = {}
|
||||
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
||||
|
||||
# Overwrite the class attribute to make it an instance attribute, so models like
|
||||
# `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
|
||||
# when a different component (e.g. language_model) is used.
|
||||
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
|
||||
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
|
||||
self.dtype_plan = {}
|
||||
|
||||
if isinstance(self._keep_in_fp32_modules, list):
|
||||
self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32))
|
||||
if isinstance(self._keep_in_fp32_modules_strict, list):
|
||||
self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32))
|
||||
|
||||
self._no_split_modules = self._no_split_modules or []
|
||||
_CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
|
||||
@ -1697,6 +1861,31 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
self.init_weights()
|
||||
self._backward_compatibility_gradient_checkpointing()
|
||||
|
||||
# Make sure the modules correctly exist if the flag is active
|
||||
if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
|
||||
all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
|
||||
unique_module_names = set()
|
||||
# Get all unique module names in the module graph, without the prefixes
|
||||
for param in all_parameters:
|
||||
unique_module_names.update(
|
||||
[name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
|
||||
)
|
||||
# Check that every module in the keep_in_fp32 list is part of the module graph
|
||||
if self._keep_in_fp32_modules is not None:
|
||||
for module in self._keep_in_fp32_modules:
|
||||
if module not in unique_module_names:
|
||||
raise ValueError(
|
||||
f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
|
||||
f" {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
if self._keep_in_fp32_modules_strict is not None:
|
||||
for module in self._keep_in_fp32_modules_strict:
|
||||
if module not in unique_module_names:
|
||||
raise ValueError(
|
||||
f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
|
||||
f" {self.__class__.__name__}"
|
||||
)
|
||||
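# Sketch of the module-name extraction above (illustrative, not from the diff): for a parameter named
# "model.layers.0.mlp.gate_proj.weight", the numeric index and the trailing "weight" are dropped, so
# unique_module_names gains {"model", "layers", "mlp", "gate_proj"}. A `_keep_in_fp32_modules = ["gate_proj"]`
# entry would therefore pass the check, while a typo such as "gate_prj" raises a ValueError.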
|
||||
self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {}
|
||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||
@@ -2443,41 +2632,34 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
`nn.Parameter`, this method should also be overridden in order to initialize it correctly.
"""
if hasattr(self.config, "initializer_range"):
std = self.config.initializer_range or 0.02
std = self.config.initializer_range
else:
# 0.02 is the standard default value across the library
std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

try:
if isinstance(
module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)
):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.Parameter):
module.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
# between modelings (because they are prefixed with the model name)
elif (
isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
or "LayerNorm" in module.__class__.__name__
or "RMSNorm" in module.__class__.__name__
):
# Norms can exist without weights (in which case they are None from torch primitives)
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
except Exception as e:
logger.warning(f"Failed to init: {str(e)}")
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
# We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
# between modelings (because they are prefixed with the model name)
elif (
isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
or "LayerNorm" in module.__class__.__name__
or "RMSNorm" in module.__class__.__name__
):
# Norms can exist without weights (in which case they are None from torch primitives)
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()

def _initialize_weights(self, module):
"""
@@ -2512,12 +2694,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
else:
module.smart_apply(fn)
fn(self)
if not isinstance(self, nn.Parameter):
for name, param in self.named_parameters(recurse=False):
if param is None:
continue
fn(param)

return self

torch.nn.Module.smart_apply = smart_apply
@@ -3281,7 +3457,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
variant: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
save_peft_format: bool = True,
save_original_format: bool = False,
**kwargs,
):
"""
@@ -3330,10 +3505,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
keys of the state dict of adapters need to be prepended with `base_model.model`. Advanced users can
disable this behaviour by setting `save_peft_format` to `False`.
save_original_format (`bool`, *optional*, defaults to `False`):
For backward compatibility with previous versions of `transformers`, you can save the checkpoint with
its reverse mapping. The reverse mapping needs to exist even if the model was not loaded from a legacy
checkpoint.
kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -3473,18 +3644,24 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()

if (
any(
allowed_name in class_name.__name__.lower()
for class_name in self.__class__.__mro__[:-1]
for allowed_name in VLMS
)
or save_original_format
if any(
allowed_name in class_name.__name__.lower()
for class_name in self.__class__.__mro__[:-1]
for allowed_name in VLMS
):
# MEGA BIG TODO HERE: self._conversion_ops needs to be used to save the final ckpt
# using what was loaded. Actually self._conversion_ops won't work because we need it
# even if the files are not legacy -> thus no conversion happened
state_dict = revert_weight_conversion(self, state_dict)
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}

original_state_dict = {}
for key, value in state_dict.items():
for pattern, replacement in reverse_key_mapping.items():
replacement = replacement.lstrip("^") # strip off unneeded chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
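The loop above applies each reverse-mapping regex at most once per key and stops at the first match. A toy, self-contained version of that mechanism (the pattern and key are hypothetical, not the actual `_checkpoint_conversion_mapping` of any model):

import re

reverse_key_mapping = {r"^model\.language_model\.": "model."}  # assumed example entry
key = "model.language_model.layers.0.self_attn.q_proj.weight"
for pattern, replacement in reverse_key_mapping.items():
    replacement = replacement.lstrip("^")           # strip off unneeded chars and patterns
    replacement = re.sub(r"\(.*\)", "", replacement)
    key, n_replace = re.subn(pattern, replacement, key)
    if n_replace > 0:                               # early exit once a pattern matched
        break
print(key)  # model.layers.0.self_attn.q_proj.weight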

# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
@@ -3652,8 +3829,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH

if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough. # TODO: we should def parallelize this, we are otherwise just waiting
# too much before scheduling the next write when it's in a different file
# joyfulness), but for now this is enough.
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
else:
save_function(shard, os.path.join(save_directory, shard_file))
@ -4219,13 +4395,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
config, quantization_config, dtype, device_map, weights_only, user_agent
|
||||
)
|
||||
|
||||
weight_conversions: Optional[list[WeightConverter]] = None
|
||||
model_type = getattr(config, "model_type", None)
|
||||
if model_type is not None:
|
||||
weight_conversions = get_checkpoint_conversion_mapping().get(model_type)
|
||||
if weight_conversions is None:
|
||||
weight_conversions = get_checkpoint_conversion_mapping()["legacy"]
|
||||
|
||||
if gguf_file:
|
||||
if hf_quantizer is not None:
|
||||
raise ValueError(
|
||||
@ -4281,6 +4450,11 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
# Let's make sure we don't run the init function of buffer modules
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
# Potentially upcast some modules to avoid losing precision
|
||||
model.upcast_modules_in_fp32(hf_quantizer, dtype)
|
||||
# Make sure to tie the weights correctly
|
||||
model.tie_weights()
|
||||
|
||||
# make sure we use the model's config since the __init__ call might have copied it
|
||||
config = model.config
|
||||
|
||||
@ -4288,7 +4462,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
device_map=device_map,
|
||||
keep_in_fp32_modules=model._keep_in_fp32_modules, # TODO prob no longer needed?
|
||||
keep_in_fp32_modules=model._keep_in_fp32_modules,
|
||||
config=config,
|
||||
checkpoint_files=checkpoint_files,
|
||||
use_kernels=use_kernels,
|
||||
@ -4320,15 +4494,15 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
device_mesh=device_mesh,
|
||||
key_mapping=key_mapping,
|
||||
weights_only=weights_only,
|
||||
weight_mapping=weight_conversions,
|
||||
)
|
||||
|
||||
model.tie_weights() # make sure token embedding weights are still tied if needed
|
||||
model.eval() # Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.set_use_kernels(use_kernels, kernel_config)
|
||||
|
||||
# If it is a model with generation capabilities, attempt to load generation files (generation config,
|
||||
# custom generate function)
|
||||
if model.can_generate() and hasattr(model, "adjust_generation_fn") and trust_remote_code:
|
||||
if model.can_generate() and hasattr(model, "adjust_generation_fn"):
|
||||
model.adjust_generation_fn(
|
||||
generation_config,
|
||||
from_auto_class,
|
||||
@ -4339,16 +4513,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# for device_map="auto" : dispatch model with hooks on all devices if necessary
|
||||
# for device_map="auto" : dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
|
||||
# harm performances).
|
||||
if device_map is not None and device_mesh is None:
|
||||
accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
model.hf_quantizer = hf_quantizer
|
||||
hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed
|
||||
hf_quantizer.postprocess_model(model, config=config) # usually a no-op
|
||||
|
||||
if _adapter_model_path is not None:
|
||||
adapter_kwargs["key_mapping"] = weight_conversions # TODO: Dynamic weight loader for adapters
|
||||
adapter_kwargs["key_mapping"] = key_mapping # TODO: Dynamic weight loader for adapters
|
||||
model.load_adapter(
|
||||
_adapter_model_path,
|
||||
adapter_name=adapter_name,
|
||||
@@ -4366,6 +4541,107 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
return model, loading_info
return model

@staticmethod
def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
# This rename is logged.
if key.endswith("LayerNorm.beta"):
return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
if key.endswith("LayerNorm.gamma"):
return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

# Rename weight norm parametrizations to match changes across torch versions.
# Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
# This rename is not logged.
if hasattr(nn.utils.parametrizations, "weight_norm"):
if key.endswith("weight_g"):
return key.replace("weight_g", "parametrizations.weight.original0"), True
if key.endswith("weight_v"):
return key.replace("weight_v", "parametrizations.weight.original1"), True
else:
if key.endswith("parametrizations.weight.original0"):
return key.replace("parametrizations.weight.original0", "weight_g"), True
if key.endswith("parametrizations.weight.original1"):
return key.replace("parametrizations.weight.original1", "weight_v"), True

return key, False
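Since both rename families above are pure string rewrites, here is a quick standalone check of how they behave on a couple of invented checkpoint keys (simplified to the recent-torch branch of the weight-norm logic):

import torch.nn as nn

def fix_key(key):
    if key.endswith("LayerNorm.beta"):
        return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
    if key.endswith("LayerNorm.gamma"):
        return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
    if hasattr(nn.utils.parametrizations, "weight_norm"):  # recent torch: old names -> parametrized names
        if key.endswith("weight_g"):
            return key.replace("weight_g", "parametrizations.weight.original0"), True
        if key.endswith("weight_v"):
            return key.replace("weight_v", "parametrizations.weight.original1"), True
    return key, False

print(fix_key("encoder.layer.0.attention.output.LayerNorm.gamma"))
# ('encoder.layer.0.attention.output.LayerNorm.weight', True)
print(fix_key("conv.weight_g"))
# ('conv.parametrizations.weight.original0', True)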
def _get_key_renaming_mapping(
|
||||
self,
|
||||
checkpoint_keys: list[str],
|
||||
key_mapping: Optional[dict[str, str]] = None,
|
||||
loading_base_model_from_task_state_dict: bool = False,
|
||||
loading_task_model_from_base_state_dict: bool = False,
|
||||
):
|
||||
"""
|
||||
Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model
|
||||
that we are loading expects. This is the single entry point for key renaming that will be used during
|
||||
loading.
|
||||
Log if any parameters have been renamed.
|
||||
"""
|
||||
prefix = self.base_model_prefix
|
||||
_prefix = f"{prefix}."
|
||||
|
||||
if loading_task_model_from_base_state_dict:
|
||||
task_specific_expected_keys, base_model_keys = [], []
|
||||
for key in self.state_dict():
|
||||
if key.startswith(_prefix):
|
||||
base_model_keys.append(key[len(_prefix) :])
|
||||
else:
|
||||
task_specific_expected_keys.append(key)
|
||||
|
||||
renamed_keys = {}
|
||||
key_renaming_mapping = {}
|
||||
for key in checkpoint_keys:
|
||||
# Class specific rename
|
||||
new_key, has_changed = self._fix_state_dict_key_on_load(key)
|
||||
|
||||
# Optionally map the key according to `key_mapping`
|
||||
if key_mapping is not None:
|
||||
for pattern, replacement in key_mapping.items():
|
||||
new_key, n_replace = re.subn(pattern, replacement, new_key)
|
||||
# Early exit of the loop
|
||||
if n_replace > 0:
|
||||
has_changed = True
|
||||
break
|
||||
|
||||
# In this case, we need to add the prefix to the keys, to match them to the expected keys
|
||||
if loading_task_model_from_base_state_dict:
|
||||
# small sanity check: if we find a key that is only part of the task-specific keys, we raise
|
||||
# (if it's also part of the base model, we do not raise and assume it comes from there)
|
||||
if new_key in task_specific_expected_keys and new_key not in base_model_keys:
|
||||
raise ValueError(
|
||||
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
|
||||
"properly saved?"
|
||||
)
|
||||
new_key = ".".join([prefix, new_key])
|
||||
# In this case we need to remove the prefix from the key to match them to the expected keys, and use
|
||||
# only the keys starting with the prefix
|
||||
elif loading_base_model_from_task_state_dict:
|
||||
if not new_key.startswith(_prefix):
|
||||
continue
|
||||
new_key = new_key[len(_prefix) :]
|
||||
|
||||
key_renaming_mapping[key] = new_key
|
||||
|
||||
# track gamma/beta rename for logging
|
||||
if has_changed:
|
||||
if key.endswith("LayerNorm.gamma"):
|
||||
renamed_keys["LayerNorm.gamma"] = (key, new_key)
|
||||
elif key.endswith("LayerNorm.beta"):
|
||||
renamed_keys["LayerNorm.beta"] = (key, new_key)
|
||||
|
||||
if renamed_keys:
|
||||
warning_msg = f"A pretrained model of type `{self.__class__.__name__}` "
|
||||
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
|
||||
for old_key, new_key in renamed_keys.values():
|
||||
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
|
||||
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
|
||||
logger.info_once(warning_msg)
|
||||
|
||||
return key_renaming_mapping
|
||||
|
||||
@staticmethod
|
||||
def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
|
||||
"""
|
||||
@ -4397,16 +4673,97 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||
key_mapping: Optional[dict[str, str]] = None,
|
||||
weights_only: bool = True,
|
||||
weight_mapping: Optional[Sequence[WeightConverter]] = None,
|
||||
):
|
||||
# TODO: we should only be calling hf_quantizer.skip_placement or something like that
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.QUARK,
|
||||
}
|
||||
|
||||
# Model's definition arriving here is final (TP hooks added, quantized layers replaced)
|
||||
# Get all the keys of the state dicts that we have to initialize the model with
|
||||
if sharded_metadata is not None:
|
||||
original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
|
||||
elif state_dict is not None:
|
||||
original_checkpoint_keys = list(state_dict.keys())
|
||||
else:
|
||||
original_checkpoint_keys = list(
|
||||
load_state_dict(checkpoint_files[0], map_location="meta", weights_only=weights_only).keys()
|
||||
)
|
||||
|
||||
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
|
||||
prefix = model.base_model_prefix
|
||||
has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False
|
||||
expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
|
||||
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
|
||||
loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module
|
||||
|
||||
# Find the key names that the model expects from the serialized keys
|
||||
key_renaming_mapping = model._get_key_renaming_mapping(
|
||||
original_checkpoint_keys,
|
||||
key_mapping,
|
||||
loading_base_model_from_task_state_dict,
|
||||
loading_task_model_from_base_state_dict,
|
||||
)
|
||||
checkpoint_keys = list(key_renaming_mapping.values())
|
||||
|
||||
# Find missing and unexpected keys from the state dict
|
||||
missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
|
||||
model, original_checkpoint_keys, checkpoint_keys, loading_base_model_from_task_state_dict, hf_quantizer
|
||||
)
|
||||
# Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
|
||||
# same way as missing keys)
|
||||
mismatched_keys, mismatched_shapes = _find_mismatched_keys(
|
||||
model,
|
||||
state_dict,
|
||||
checkpoint_files,
|
||||
ignore_mismatched_sizes,
|
||||
key_renaming_mapping,
|
||||
is_quantized,
|
||||
weights_only,
|
||||
)
|
||||
|
||||
# We need to update both the mapping and the list of checkpoint keys to remove the mismatched and unexpected ones
|
||||
key_renaming_mapping = {
|
||||
k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys and v not in unexpected_keys
|
||||
}
|
||||
checkpoint_keys = list(key_renaming_mapping.values())
|
||||
|
||||
# Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
|
||||
# loading the weights as they are not in the loaded state dict)
|
||||
model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, dtype, hf_quantizer)
|
||||
|
||||
# correctly initialize the missing (and potentially mismatched) keys
|
||||
model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized)
|
||||
|
||||
# Get reverse key mapping
|
||||
reverse_key_renaming_mapping = {v: k for k, v in key_renaming_mapping.items()}
|
||||
|
||||
is_offloaded_safetensors = False
|
||||
# This offload index is for params explicitly on the "disk" in the device_map
|
||||
disk_offload_index = None
|
||||
disk_only_shard_files = []
|
||||
# Prepare parameters offloading if needed
|
||||
if device_map is not None and "disk" in device_map.values():
|
||||
disk_offload_index, disk_only_shard_files, is_offloaded_safetensors = accelerate_disk_offload(
|
||||
disk_offload_folder,
|
||||
checkpoint_files,
|
||||
device_map,
|
||||
checkpoint_keys,
|
||||
key_renaming_mapping,
|
||||
sharded_metadata,
|
||||
dtype,
|
||||
reverse_key_renaming_mapping,
|
||||
)
|
||||
# To be able to iterate, even if we don't use it if the state_dict is already provided
|
||||
elif state_dict is not None:
|
||||
checkpoint_files = [""]
|
||||
|
||||
# Compute expected model keys
|
||||
expected_keys = list(model.state_dict().keys())
|
||||
if hf_quantizer is not None:
|
||||
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)
|
||||
|
||||
if logger.level >= logging.WARNING:
|
||||
verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
|
||||
|
||||
@ -4415,84 +4772,46 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
|
||||
|
||||
# Now we read all the files to get a pointer on each physical weights
|
||||
merged_state_dict = {}
|
||||
all_pointer = set()
|
||||
|
||||
if device_map is None:
|
||||
device_map = {"": "cpu"}
|
||||
keys = sorted(device_map.keys(), key=len, reverse=True)
|
||||
tp_plan = getattr(model, "_tp_plan", None)
|
||||
error_msgs = []
|
||||
misc = {}
|
||||
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized:
|
||||
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
|
||||
else:
|
||||
if checkpoint_files is not None:
|
||||
pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")")
|
||||
if sharded_metadata is None:
|
||||
k_v_iterator = dict.fromkeys(
|
||||
safe_open(checkpoint_files[0], framework="pt").keys(), "model.safetensors"
|
||||
).items()
|
||||
else:
|
||||
k_v_iterator = sharded_metadata["weight_map"].items()
|
||||
|
||||
for k, v in k_v_iterator:
|
||||
match = pattern.match(k)
|
||||
if match and match.group(1) != "":
|
||||
device = device_map[match.group(1)]
|
||||
else:
|
||||
device = device_map.get("", "cpu")
|
||||
if isinstance(device, torch.device):
|
||||
device = device.index # safetensors only
|
||||
if device == "disk":
|
||||
device = "cpu" # we read to cpu to then write to disk
|
||||
file_pointer = safe_open(
|
||||
os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device
|
||||
)
|
||||
all_pointer.add(file_pointer)
|
||||
merged_state_dict[k] = (v, file_pointer.get_slice(k)) # don't materialize yet
|
||||
elif state_dict is not None:
|
||||
merged_state_dict = {k: ("", v) for k, v in state_dict.items()}
|
||||
else:
|
||||
raise ValueError("Neither a state dict nor checkpoint files were found.")
|
||||
start = time.perf_counter()
|
||||
missing_keys, unexpected_keys, mismatched_keys, misc = convert_and_load_state_dict_in_model(
|
||||
model,
|
||||
merged_state_dict,
|
||||
weight_mapping,
|
||||
tp_plan,
|
||||
hf_quantizer,
|
||||
dtype,
|
||||
# Prepare and normalize arguments for serial and parallel shard loading
|
||||
args_list = [
|
||||
(
|
||||
shard_file,
|
||||
state_dict,
|
||||
disk_only_shard_files,
|
||||
is_quantized,
|
||||
device_map,
|
||||
model.dtype_plan,
|
||||
device_mesh=device_mesh,
|
||||
hf_quantizer,
|
||||
key_renaming_mapping,
|
||||
weights_only,
|
||||
model,
|
||||
reverse_key_renaming_mapping,
|
||||
disk_offload_folder,
|
||||
disk_offload_index,
|
||||
device_mesh,
|
||||
)
|
||||
end = time.perf_counter()
|
||||
for shard_file in checkpoint_files
|
||||
]
|
||||
|
||||
for k in all_pointer: # finally close all opened file pointers
|
||||
k.__exit__(None, None, None)
|
||||
error_msgs = []
|
||||
|
||||
new_state_dict = model.state_dict()
|
||||
if (
|
||||
os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
and not is_deepspeed_zero3_enabled()
|
||||
):
|
||||
_error_msgs, disk_offload_index = load_shard_files_with_threadpool(args_list)
|
||||
error_msgs += _error_msgs
|
||||
else:
|
||||
if len(args_list) > 1:
|
||||
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
|
||||
|
||||
#!!!!!!!!!!!!!!!!!!!!!!! POST PROCESS!!!!!!!!!!!!!!!!!!
|
||||
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
|
||||
prefix = model.base_model_prefix
|
||||
has_prefix_module = any(s.startswith(prefix) for s in new_state_dict.keys()) if len(prefix) > 0 else False
|
||||
expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
|
||||
loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
|
||||
for args in args_list:
|
||||
_error_msgs, disk_offload_index = load_shard_file(args)
|
||||
error_msgs += _error_msgs
|
||||
|
||||
# TODO last TODO here is to tie the weights once and only. If they are missing and False, and if true
|
||||
|
||||
# TODO TODO TODO
|
||||
# Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
|
||||
# loading the weights as they are not in the loaded state dict)
|
||||
miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys}
|
||||
model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer)
|
||||
|
||||
# correctly initialize the missing (and potentially mismatched) keys
|
||||
model._initialize_missing_keys(miss_and_mismatched, is_quantized)
|
||||
# Save offloaded index if needed
|
||||
if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors:
|
||||
save_offload_index(disk_offload_index, disk_offload_folder)
|
||||
disk_offload_index = None
|
||||
|
||||
# Post-processing for tensor parallelism
|
||||
if device_mesh is not None:
|
||||
@ -4500,7 +4819,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
tp_device = list(device_map.values())[0]
|
||||
# This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
|
||||
# not part of the state_dict (persistent=False)
|
||||
for buffer in model.buffers(): # TODO: to avoid this, the buffer could be added to the ckpt
|
||||
for buffer in model.buffers():
|
||||
if buffer.device != tp_device:
|
||||
buffer.data = buffer.to(tp_device)
|
||||
|
||||
@ -4527,24 +4846,52 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
device_mesh,
|
||||
)
|
||||
|
||||
# Remove tied weights keys and etc
|
||||
# Remove potential model-specific exceptions from the warnings
|
||||
missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(
|
||||
missing_keys, unexpected_keys, loading_task_model_from_base_state_dict, model
|
||||
missing_keys, unexpected_keys, loading_task_model_from_base_state_dict
|
||||
)
|
||||
logger.warning(f"Loading the checkpoint files into the model took {end - start}")
|
||||
log_state_dict_report(
|
||||
model=model,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
logger=logger,
|
||||
error_msgs=error_msgs,
|
||||
unexpected_keys=unexpected_keys,
|
||||
missing_keys=missing_keys,
|
||||
mismatched_keys=mismatched_keys,
|
||||
mismatched_shapes=mismatched_keys,
|
||||
misc=misc,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
disk_offload_index = None
|
||||
|
||||
# TODO: separate this in another function: it's not core....
|
||||
# All potential warnings/infos
|
||||
if len(error_msgs) > 0:
|
||||
error_msg = "\n\t".join(error_msgs)
|
||||
if "size mismatch" in error_msg:
|
||||
error_msg += (
|
||||
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
||||
)
|
||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
||||
if len(unexpected_keys) > 0:
|
||||
archs = [] if model.config.architectures is None else model.config.architectures
|
||||
warner = logger.warning if model.__class__.__name__ in archs else logger.info
|
||||
warner(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||
f" initializing {model.__class__.__name__}: {update_key_name(unexpected_keys)}\n- This IS expected if you are"
|
||||
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
||||
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
||||
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
||||
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
|
||||
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
||||
)
|
||||
if len(missing_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized: {update_key_name(missing_keys)}\nYou should probably"
|
||||
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||
)
|
||||
if len(mismatched_keys) > 0:
|
||||
mismatched_warning = "\n".join(
|
||||
[
|
||||
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||
for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes)
|
||||
]
|
||||
)
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
||||
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
|
||||
" to use it for predictions and inference."
|
||||
)
|
||||
|
||||
return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
|
||||
|
||||
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
|
||||
@ -4753,6 +5100,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
value = torch.empty_like(param, dtype=dtype, device="cpu")
|
||||
if not is_quantized or not hf_quantizer.param_needs_quantization(self, key):
|
||||
_load_parameter_into_model(self, key, value)
|
||||
else:
|
||||
hf_quantizer.create_quantized_param(self, value, key, "cpu")
|
||||
|
||||
def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None:
|
||||
"""Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
|
||||
@ -4802,23 +5151,16 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
self.initialize_weights()
|
||||
|
||||
def _adjust_missing_and_unexpected_keys(
|
||||
self, missing_keys: set[str], unexpected_keys: set[str], loading_task_model_from_base_state_dict: bool, model
|
||||
) -> tuple[set[str], set[str]]:
|
||||
self, missing_keys: list[str], unexpected_keys: list[str], loading_task_model_from_base_state_dict: bool
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
|
||||
raising unneeded warnings/errors.
|
||||
"""
|
||||
# Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model
|
||||
# Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
|
||||
# (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
|
||||
# `_keys_to_ignore_on_load_unexpected` as it touches many models -> we add it manually to the existing patterns
|
||||
has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer, _ in self.named_buffers())
|
||||
additional_unexpected_patterns = [r"rotary_emb\.inv_freq"] if has_inv_freq_buffers else []
|
||||
tied_param_names = "|".join(model._tied_weights_keys or [])
|
||||
if tied_param_names:
|
||||
model.tie_weights()
|
||||
if model.config.tie_word_embeddings:
|
||||
for k in missing_keys.copy():
|
||||
if re.match(tied_param_names, k):
|
||||
missing_keys.discard(k)
|
||||
|
||||
missing_patterns = self._keys_to_ignore_on_load_missing or []
|
||||
unexpected_patterns = (self._keys_to_ignore_on_load_unexpected or []) + additional_unexpected_patterns
|
||||
@ -4830,17 +5172,17 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
||||
|
||||
# Clean-up missing keys
|
||||
if ignore_missing_regex is not None:
|
||||
missing_keys = {key for key in missing_keys if ignore_missing_regex.search(key) is None}
|
||||
missing_keys = [key for key in missing_keys if ignore_missing_regex.search(key) is None]
|
||||
|
||||
# Clean-up unexpected keys
|
||||
if ignore_unexpected_regex is not None:
|
||||
unexpected_keys = {key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None}
|
||||
unexpected_keys = [key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None]
|
||||
|
||||
# Note: only the unexpected keys should remove the added prefix here, to correctly display the original name
|
||||
# in the warnings. For missing keys, we should show the prefix in the warning as it's part of the final model
|
||||
if loading_task_model_from_base_state_dict:
|
||||
_prefix = f"{self.base_model_prefix}."
|
||||
unexpected_keys = {k.removeprefix(_prefix) for k in unexpected_keys}
|
||||
unexpected_keys = [k.removeprefix(_prefix) for k in unexpected_keys]
|
||||
|
||||
return missing_keys, unexpected_keys
|
||||
|
||||
@@ -4877,6 +5219,35 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
def eval(self):
return self.train(False)

def upcast_modules_in_fp32(self, hf_quantizer: HfQuantizer | None, dtype: torch.dtype) -> None:
"""
Upcast modules defined in `_keep_in_fp32_modules` and `_keep_in_fp32_modules_strict` in fp32, if
`dtype` is different than fp32.
"""
# If the dtype is already fp32, we can skip
if dtype == torch.float32:
return

keep_in_fp32_modules = []
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
# in case a model that should stay in bf16 is force-loaded in fp16 (which includes a few quantizers, as this is a
# pre-processing step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
if self._keep_in_fp32_modules is not None and (
dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
):
keep_in_fp32_modules.extend(self._keep_in_fp32_modules)

if self._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16):
keep_in_fp32_modules.extend(self._keep_in_fp32_modules_strict)

if len(keep_in_fp32_modules) > 0:
# We need to match exact layers, so we add either `.` on each side, or start/end of string
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
for name, param in self.named_parameters():
if keep_in_fp32_regex.search(name):
# param = param.to(torch.float32) does not work here as it only rebinds in the local scope.
param.data = param.data.to(torch.float32)
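The regex built above anchors each module name on `.` boundaries so a name only matches as a full path component; a small sanity check with invented module names:

import re

keep_in_fp32_modules = ["lm_head", "norm"]  # illustrative
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))

print(bool(keep_in_fp32_regex.search("lm_head.weight")))          # True
print(bool(keep_in_fp32_regex.search("model.norm.weight")))       # True
print(bool(keep_in_fp32_regex.search("model.layernorm.weight")))  # False: 'norm' only matches as a full component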


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:

@ -136,7 +136,7 @@ class ArceeConfig(PreTrainedConfig):
|
||||
bos_token_id: Optional[int] = 128000,
|
||||
eos_token_id: Optional[int] = 128001,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
mlp_bias: Optional[bool] = False,
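The type-hint fix repeated across these configs reflects that a dict annotation needs both a key and a value type; a small illustrative sketch (the `RopeParameters` stand-in and the values below are assumptions, not the library's actual definition):

from typing import TypedDict

class RopeParameters(TypedDict, total=False):  # stand-in for transformers' RopeParameters typed dict
    rope_type: str
    rope_theta: float

# `dict[RopeParameters]` is missing the key type; the corrected hint, dict[str, RopeParameters],
# covers the per-layer-type form, e.g.:
layered_rope: dict[str, RopeParameters] = {
    "full_attention": {"rope_type": "default", "rope_theta": 10000.0},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
}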
|
||||
|
||||
@ -137,7 +137,7 @@ class ArceeConfig(LlamaConfig):
|
||||
bos_token_id: Optional[int] = 128000,
|
||||
eos_token_id: Optional[int] = 128001,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
mlp_bias: Optional[bool] = False,
|
||||
|
||||
@ -134,7 +134,7 @@ class AriaTextConfig(PreTrainedConfig):
|
||||
eos_token_id: Optional[int] = 2,
|
||||
pretraining_tp: Optional[int] = 1,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
mlp_bias: Optional[bool] = False,
|
||||
|
||||
@ -117,7 +117,7 @@ class BitNetConfig(PreTrainedConfig):
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[str] = 0.0,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
@ -44,7 +44,7 @@ class BltLocalEncoderConfig(PreTrainedConfig):
|
||||
rms_norm_eps: Optional[float] = 1e-5,
|
||||
dropout: Optional[float] = 0.0,
|
||||
max_position_embeddings: Optional[int] = 24576,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
hidden_act: Optional[str] = "silu",
|
||||
intermediate_size: Optional[int] = 2816,
|
||||
initializer_range: Optional[float] = 0.02,
|
||||
@ -99,7 +99,7 @@ class BltLocalDecoderConfig(PreTrainedConfig):
|
||||
rms_norm_eps: Optional[float] = 1e-5,
|
||||
dropout: Optional[float] = 0.0,
|
||||
max_position_embeddings: Optional[int] = 24576,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
hidden_act: Optional[str] = "silu",
|
||||
intermediate_size: Optional[int] = 2816,
|
||||
initializer_range: Optional[float] = 0.02,
|
||||
@ -150,7 +150,7 @@ class BltGlobalTransformerConfig(PreTrainedConfig):
|
||||
rms_norm_eps: Optional[float] = 1e-5,
|
||||
dropout: Optional[float] = 0.0,
|
||||
max_position_embeddings: Optional[int] = 4096,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
hidden_act: Optional[str] = "silu",
|
||||
intermediate_size: Optional[int] = 5632,
|
||||
initializer_range: Optional[float] = 0.02,
|
||||
@ -231,7 +231,7 @@ class BltPatcherConfig(PreTrainedConfig):
|
||||
rms_norm_eps: Optional[float] = 1e-5,
|
||||
dropout: Optional[float] = 0.0,
|
||||
intermediate_size: Optional[int] = 2048,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
initializer_range: Optional[float] = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
@ -356,7 +356,7 @@ class BltConfig(PreTrainedConfig):
|
||||
global_config: Optional[dict] = None,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
initializer_range: Optional[float] = 0.02,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Basic model configuration
|
||||
|
||||
@ -203,7 +203,7 @@ class ChameleonConfig(PreTrainedConfig):
|
||||
bos_token_id: Optional[int] = 1,
|
||||
eos_token_id: Optional[int] = 2,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[int] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
model_parallel_size: Optional[int] = 1,
|
||||
|
||||
@ -139,7 +139,7 @@ class CohereConfig(PreTrainedConfig):
|
||||
bos_token_id: Optional[int] = 5,
|
||||
eos_token_id: Optional[int] = 255001,
|
||||
tie_word_embeddings: Optional[bool] = True,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
use_qk_norm: Optional[bool] = False,
|
||||
|
||||
@ -138,7 +138,7 @@ class Cohere2Config(PreTrainedConfig):
|
||||
bos_token_id: Optional[int] = 5,
|
||||
eos_token_id: Optional[int] = 255001,
|
||||
tie_word_embeddings: Optional[bool] = True,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
sliding_window: Optional[int] = 4096,
|
||||
|
||||
@ -162,7 +162,7 @@ class Cohere2Config(PreTrainedConfig):
|
||||
bos_token_id: Optional[int] = 5,
|
||||
eos_token_id: Optional[int] = 255001,
|
||||
tie_word_embeddings: Optional[bool] = True,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
sliding_window: Optional[int] = 4096,
|
||||
|
||||
@ -39,9 +39,10 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from peft import PeftModel
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers import AutoConfig, AutoModel
|
||||
from transformers.models.colqwen2 import ColQwen2ForRetrieval
|
||||
from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
|
||||
from transformers.utils import logging
|
||||
@ -69,7 +70,7 @@ def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> d
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
# Some weights are tied, so `lm_head` is not saved. Let's clone it to load the state dict.
|
||||
if "lm_head.weight" not in original_state_dict:
|
||||
if "lm_head.weight" not in original_state_dict and "model.embed_tokens.weight" in original_state_dict:
|
||||
original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
|
||||
|
||||
return original_state_dict
|
||||
@ -124,7 +125,21 @@ def convert_colqwen2_weights_to_hf(
|
||||
config.is_composition = False
|
||||
|
||||
# Load the untrained model
|
||||
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
|
||||
vlm_name_or_path = getattr(config.vlm_config, "_name_or_path", None)
|
||||
if vlm_name_or_path and "2.5" in str(vlm_name_or_path):
|
||||
print(
|
||||
"Detected colqwen2.5 adapters in vlm_config; loading base model %s and merging PEFT weights."
|
||||
% vlm_name_or_path
|
||||
)
|
||||
base_model = AutoModel.from_pretrained(
|
||||
vlm_name_or_path,
|
||||
device_map="cpu",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
peft_model = PeftModel.from_pretrained(base_model, model_id)
|
||||
model = peft_model.merge_and_unload()
|
||||
else:
|
||||
model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
|
||||
print("Created model with new config and randomly initialized weights")
|
||||
|
||||
# NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
|
||||
@ -201,6 +216,7 @@ if __name__ == "__main__":
|
||||
help="Name or path of the original VLM backbone model",
|
||||
default=None,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_colqwen2_weights_to_hf(
|
||||
|
||||
@ -172,7 +172,6 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
|
||||
inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
|
||||
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
|
||||
@ -359,7 +359,6 @@ class ColQwen2ForRetrieval(ColPaliForRetrieval):
|
||||
inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
|
||||
image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
|
||||
@ -122,7 +122,7 @@ class CsmDepthDecoderConfig(PreTrainedConfig):
|
||||
pad_token_id: Optional[int] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
mlp_bias: Optional[bool] = False,
|
||||
@ -291,7 +291,7 @@ class CsmConfig(PreTrainedConfig):
|
||||
eos_token_id: Optional[int] = None,
|
||||
audio_token_id: Optional[int] = 128002,
|
||||
audio_eos_token_id: Optional[int] = 128003,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
mlp_bias: Optional[bool] = False,
|
||||
|
||||
@ -189,7 +189,7 @@ class DbrxConfig(PreTrainedConfig):
|
||||
use_cache: Optional[bool] = True,
|
||||
initializer_range: Optional[float] = 0.02,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if attn_config is None:
|
||||
|
||||
@ -394,11 +394,10 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
if isinstance(module, PreTrainedModel):
|
||||
for name, p in module.named_parameters():
|
||||
if "c_proj" in name and "weight" in name:
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
|
||||
for name, p in module.named_parameters():
|
||||
if "c_proj" in name and "weight" in name:
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
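The "Special Scaled Initialization" above just divides the usual std by sqrt(2 * n_layer); a small numeric illustration with assumed config values (0.02 and 12 layers are typical GPT-2 defaults, used here only as an example):

import math

initializer_range = 0.02  # assumed config value
n_layer = 12              # assumed config value
print(initializer_range / math.sqrt(2 * n_layer))  # ~0.004082, std used for the c_proj weights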
|
||||
|
||||
|
||||
class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
|
||||
@ -127,7 +127,8 @@ class DeepseekV2Config(PreTrainedConfig):
|
||||
"layers.*.self_attn.q_b_proj": "colwise",
|
||||
"layers.*.self_attn.kv_b_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_up_proj": "colwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
@ -153,7 +154,7 @@ class DeepseekV2Config(PreTrainedConfig):
|
||||
bos_token_id: Optional[int] = 1,
|
||||
eos_token_id: Optional[int] = 2,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
mlp_bias: Optional[bool] = False,
|
||||
|
||||
@ -42,43 +42,37 @@ from ...utils.generic import check_model_inputs
|
||||
from .configuration_deepseek_v2 import DeepseekV2Config
|
||||
|
||||
|
||||
class DeepseekV2Experts(nn.Module):
|
||||
"""Collection of expert weights stored as 3D tensors."""
|
||||
class DeepseekV2Experts(nn.ModuleList):
|
||||
"""
|
||||
ModuleList of experts.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_experts = config.n_routed_experts
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.intermediate_dim = config.moe_intermediate_size
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
|
||||
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
for _ in range(config.n_routed_experts):
|
||||
self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
top_k_index: torch.Tensor,
|
||||
top_k_weights: torch.Tensor,
|
||||
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: (batch_size * sequence_length, hidden_dim)
|
||||
selected_experts: (batch_size * sequence_length, top_k)
|
||||
routing_weights: (batch_size * sequence_length, top_k)
|
||||
Returns:
|
||||
(batch_size * sequence_length, hidden_dim)
|
||||
"""
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
num_experts = top_k_weights.shape[1]
|
||||
with torch.no_grad():
|
||||
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
|
||||
expert_mask = expert_mask.permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
for expert_idx in expert_hit:
|
||||
expert_idx = expert_idx[0]
|
||||
if expert_idx == num_experts:
|
||||
continue
|
||||
_, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current_state = hidden_states[token_idx]
|
||||
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
|
||||
current_hidden_states = self.act_fn(gate) * up
|
||||
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
|
||||
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
|
||||
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
|
||||
|
||||
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
|
||||
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
|
||||
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
return final_hidden_states
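The one_hot / permute / torch.where bookkeeping in the forward above is easier to see on a toy trace; the token count, expert count and assignments below are invented:

import torch

# Toy routing: 4 tokens, 3 experts, top-2 per token.
top_k_index = torch.tensor([[0, 1], [1, 2], [0, 2], [1, 0]])
num_experts = 3
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
    idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
    print(int(expert_idx), top_x.tolist())  # which token rows this expert processes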
|
||||
|
||||
|
||||
@ -117,7 +111,6 @@ class DeepseekV2Moe(nn.Module):
|
||||
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
topk_weight = topk_weight * self.routed_scaling_factor
|
||||
topk_weight = torch.zeros_like(router_logits).scatter_(1, topk_idx, topk_weight)
|
||||
return topk_idx, topk_weight
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -142,7 +142,8 @@ class DeepseekV2Config(LlamaConfig):
|
||||
"layers.*.self_attn.q_b_proj": "colwise",
|
||||
"layers.*.self_attn.kv_b_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_up_proj": "colwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
@ -166,7 +167,7 @@ class DeepseekV2Config(LlamaConfig):
|
||||
bos_token_id: Optional[int] = 1,
|
||||
eos_token_id: Optional[int] = 2,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
|
||||
mlp_bias: Optional[bool] = False,
|
||||
@ -223,10 +224,12 @@ def apply_rotary_emb(
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
class DeepseekV2Experts(Qwen2MoeExperts):
|
||||
class DeepseekV2Experts(Qwen2MoeExperts, nn.ModuleList):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
nn.ModuleList.__init__(self)
|
||||
self.num_experts = config.n_routed_experts
|
||||
for _ in range(config.n_routed_experts):
|
||||
self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size))
|
||||
|
||||
|
||||
class DeepseekV2Moe(nn.Module):
|
||||
@ -264,7 +267,6 @@ class DeepseekV2Moe(nn.Module):
|
||||
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
topk_weight = topk_weight * self.routed_scaling_factor
|
||||
topk_weight = torch.zeros_like(router_logits).scatter_(1, topk_idx, topk_weight)
|
||||
return topk_idx, topk_weight
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -186,7 +186,7 @@ class DeepseekV3Config(PreTrainedConfig):
|
||||
eos_token_id: Optional[int] = 1,
|
||||
pretraining_tp: Optional[int] = 1,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
|
||||
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
|
||||
rope_interleave: Optional[bool] = True,
|
||||
attention_bias: Optional[bool] = False,
|
||||
attention_dropout: Optional[float] = 0.0,
@ -149,43 +149,37 @@ class DeepseekV3TopkRouter(nn.Module):
return router_logits

class DeepseekV3NaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class DeepseekV3NaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""

def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size))

def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
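For orientation, the `*NaiveMoe` refactor above swaps 3D expert weight tensors for an `nn.ModuleList` of per-expert MLPs and keeps the same mask-and-`index_add_` dispatch loop. Below is a small self-contained sketch of that dispatch pattern with toy sizes and plain `nn.Linear` experts standing in for the model MLPs; it illustrates the routing loop only and is not the library code.

```python
import torch
from torch import nn


class ToyExperts(nn.ModuleList):
    """Toy ModuleList-of-experts dispatch mirroring the pattern in the diff above."""

    def __init__(self, num_experts: int, hidden_dim: int):
        super().__init__([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
        self.num_experts = num_experts

    def forward(self, hidden_states, top_k_index, top_k_weights):
        # hidden_states: (tokens, hidden); top_k_index / top_k_weights: (tokens, top_k)
        final_hidden_states = torch.zeros_like(hidden_states)
        # (num_experts, top_k, tokens) mask of which expert serves which routing slot
        expert_mask = nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit:
            expert_idx = int(expert_idx)
            idx, top_x = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[top_x]
            weighted = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, weighted)
        return final_hidden_states


tokens, hidden, num_experts, top_k = 6, 8, 4, 2
x = torch.randn(tokens, hidden)
top_w, top_i = torch.topk(torch.rand(tokens, num_experts).softmax(dim=-1), top_k, dim=-1)
out = ToyExperts(num_experts, hidden)(x, top_i, top_w)
print(out.shape)  # torch.Size([6, 8])
```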
@ -102,10 +102,12 @@ class DeepseekV3TopkRouter(nn.Module):
return router_logits

class DeepseekV3NaiveMoe(MixtralExperts):
class DeepseekV3NaiveMoe(MixtralExperts, nn.ModuleList):
def __init__(self, config):
super().__init__(config)
nn.ModuleList.__init__(self)
self.num_experts = config.num_local_experts
for _ in range(self.num_experts):
self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size))

class DeepseekV3MoE(nn.Module):

@ -118,7 +118,7 @@ class DiffLlamaConfig(PreTrainedConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
lambda_std_dev: Optional[float] = 0.1,

@ -148,7 +148,7 @@ class DogeConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
max_position_embeddings: Optional[int] = 2048,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
num_attention_heads: Optional[int] = 8,
num_key_value_heads: Optional[int] = None,
attention_bias: Optional[bool] = False,

@ -176,7 +176,7 @@ class DogeConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
max_position_embeddings: Optional[int] = 2048,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
num_attention_heads: Optional[int] = 8,
num_key_value_heads: Optional[int] = None,
attention_bias: Optional[bool] = False,

@ -159,7 +159,7 @@ class Dots1Config(PreTrainedConfig):
rms_norm_eps: Optional[int] = 1e-6,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
routed_scaling_factor: Optional[float] = 1.0,

@ -305,43 +305,37 @@ class Dots1TopkRouter(nn.Module):
return router_logits

class Dots1NaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Dots1NaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""

def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Dots1MLP(config, intermediate_size=config.moe_intermediate_size))

def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
@ -125,7 +125,7 @@ class Ernie4_5Config(PreTrainedConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
use_bias: Optional[bool] = False,
head_dim: Optional[int] = 128,
**kwargs,

@ -161,7 +161,7 @@ class Ernie4_5_MoeConfig(PreTrainedConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
use_bias: Optional[int] = False,
moe_intermediate_size: Optional[int] = 1536,
moe_k: Optional[int] = 6,

@ -315,95 +315,62 @@ class Ernie4_5_MoeStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()

class Ernie4_5_MoeExperts(nn.Module):
class Ernie4_5_MoeExperts(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.num_experts = config.moe_num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.use_bias = config.use_bias
self.act_fn = ACT2FN[config.hidden_act]

self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
if self.use_bias:
self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim))
self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
else:
self.gate_up_proj_bias = None
self.down_proj_bias = None
for _ in range(self.num_experts):
self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))

def forward(
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
if selected_experts.numel() == 0:
return final_hidden_states

expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = int(expert_idx.item())
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
gate_inputs = F.linear(
current_state,
self.gate_up_proj[expert_idx],
None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
)
gate, up = gate_inputs.chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = F.linear(
current_hidden_states,
self.down_proj[expert_idx],
None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
)
current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states

class Ernie4_5_MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.weight = nn.Parameter(torch.zeros(config.moe_num_experts, config.hidden_size, dtype=torch.float32))
self.moe_statics = Ernie4_5_MoeStatics(config)
self.top_k = config.moe_k
self.norm_min = config.moe_norm_min

def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)

with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = F.linear(hidden_states.float(), self.weight)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return routing_weights, selected_experts

class Ernie4_5_MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.num_experts = config.moe_num_experts
self.top_k = config.moe_k
self.gate = Ernie4_5_MoeTopKRouter(config)
self.norm_min = config.moe_norm_min

self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.moe_statics = Ernie4_5_MoeStatics(config)
self.experts = Ernie4_5_MoeExperts(config)

self.shared_experts = None
if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)

def route_tokens_to_experts(self, hidden_states):
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)

with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = self.gate(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)

@ -411,14 +378,14 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

routing_weights, selected_experts = self.gate(hidden_states)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)

if self.shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output

final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, self.hidden_dim)
return final_hidden_states.to(hidden_states.dtype)
return final_hidden_states
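In the Ernie 4.5 MoE block above, the dedicated router module is folded back into the sparse block: the gate becomes a plain fp32 `nn.Linear` and top-k selection plus re-normalization moves into `route_tokens_to_experts`. A rough sketch of that selection step with invented sizes, leaving out the `moe_statics` bias correction:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tokens, num_experts, top_k, norm_min = 4, 8, 2, 1e-12

router_logits = torch.randn(tokens, num_experts)  # stand-in for the fp32 gate output
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# In the real block, top-k is taken over bias-corrected scores (moe_statics); omitted here.
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
# Re-normalize so the kept top-k weights of each token sum to 1 (clamped to avoid division by zero).
routing_weights = routing_weights / torch.clamp(routing_weights.sum(dim=-1, keepdim=True), min=norm_min)
print(selected_experts.shape, routing_weights.sum(dim=-1))  # torch.Size([4, 2]), all ones
```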
class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer):

@ -487,11 +454,11 @@ class Ernie4_5_MoePreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.gate", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}
_keep_in_fp32_modules_strict = ["gate.weight", "moe_statics"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]

@ -19,7 +19,6 @@ import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import MoeModelOutputWithPast

@ -97,95 +96,62 @@ class Ernie4_5_MoeStatics(nn.Module):
return hidden_states + self.e_score_correction_bias.squeeze()

class Ernie4_5_MoeExperts(nn.Module):
class Ernie4_5_MoeExperts(nn.ModuleList):
def __init__(self, config):
super().__init__()
self.num_experts = config.moe_num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.use_bias = config.use_bias
self.act_fn = ACT2FN[config.hidden_act]

self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
if self.use_bias:
self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim))
self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
else:
self.gate_up_proj_bias = None
self.down_proj_bias = None
for _ in range(self.num_experts):
self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))

def forward(
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
if selected_experts.numel() == 0:
return final_hidden_states

expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = int(expert_idx.item())
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
gate_inputs = F.linear(
current_state,
self.gate_up_proj[expert_idx],
None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx],
)
gate, up = gate_inputs.chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = F.linear(
current_hidden_states,
self.down_proj[expert_idx],
None if self.down_proj_bias is None else self.down_proj_bias[expert_idx],
)
current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None]
current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states

class Ernie4_5_MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.weight = nn.Parameter(torch.zeros(config.hidden_size, config.moe_num_experts, dtype=torch.float32))
self.moe_statics = Ernie4_5_MoeStatics(config)
self.top_k = config.moe_k
self.norm_min = config.moe_norm_min

def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)

with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = F.linear(hidden_states.float(), self.weight)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return routing_weights, selected_experts
class Ernie4_5_MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.num_experts = config.moe_num_experts
self.top_k = config.moe_k
self.gate = Ernie4_5_MoeTopKRouter(config)
self.norm_min = config.moe_norm_min

self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
self.moe_statics = Ernie4_5_MoeStatics(config)
self.experts = Ernie4_5_MoeExperts(config)

self.shared_experts = None
if config.moe_num_shared_experts > 0:
self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)

def route_tokens_to_experts(self, hidden_states):
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)

with torch.autocast(device_type=device_type, enabled=False): # Force float32
router_logits = self.gate(hidden_states.float())
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
_, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
)
routing_weights = routing_weights.to(router_logits.dtype)
return selected_experts, routing_weights

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)

@ -193,7 +159,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

routing_weights, selected_experts = self.gate(hidden_states)
selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)

if self.shared_experts is not None:

@ -227,11 +193,11 @@ class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer):
class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel):
config: Ernie4_5_MoeConfig
_no_split_modules = ["Ernie4_5_MoeDecoderLayer"]
_keep_in_fp32_modules_strict = ["router"]
_keep_in_fp32_modules_strict = ["gate", "moe_statics"]
# Not supporting multi-token prediction (MTP) atm
_keys_to_ignore_on_load_unexpected = ["mtp"]
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.router", index=0),
"router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
"hidden_states": Ernie4_5_MoeDecoderLayer,
"attentions": Ernie4_5_MoeAttention,
}
@ -203,7 +203,7 @@ class EvollaConfig(PreTrainedConfig):
hidden_act: Optional[str] = "silu", # llama activation function
max_position_embeddings: Optional[int] = 8192, # llama rope max length
rms_norm_eps: Optional[int] = 1e-05,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,

@ -143,7 +143,7 @@ class Exaone4Config(PreTrainedConfig):
bos_token_id: Optional[int] = 0,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_dropout: Optional[float] = 0.0,
sliding_window: Optional[int] = 4096,
sliding_window_pattern: Optional[int] = 4,

@ -176,7 +176,7 @@ class Exaone4Config(PreTrainedConfig):
bos_token_id: Optional[int] = 0,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_dropout: Optional[float] = 0.0,
sliding_window: Optional[int] = 4096,
sliding_window_pattern: Optional[int] = 4,

@ -128,7 +128,7 @@ class FalconConfig(PreTrainedConfig):
parallel_attn: Optional[bool] = True,
bias: Optional[bool] = False,
max_position_embeddings: Optional[int] = 2048,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
bos_token_id: Optional[int] = 11,
eos_token_id: Optional[int] = 11,
ffn_hidden_size: Optional[int] = None,

@ -164,7 +164,7 @@ class FalconH1Config(PreTrainedConfig):
mamba_norm_before_gate: Optional[bool] = True,
mamba_rms_norm: Optional[bool] = False,
projectors_bias: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
lm_head_multiplier: Optional[float] = 1.0,
embedding_multiplier: Optional[float] = 1.0,
mlp_multipliers: Optional[int] = None,
@ -1196,20 +1196,19 @@ class FalconH1PreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Module):
for name, param in module.named_parameters(recurse=True):
if not param.requires_grad:
continue
if "layernorm" in name.lower() and "weight" in name:
# LayerNorm weights usually initialized to 1
param.data.fill_(1.0)
elif "bias" in name:
param.data.zero_()
else:
try:
param.data.normal_(mean=0.0, std=std)
except Exception as e:
print(f"Skipping init for {name} due to error: {e}")
for name, param in module.named_parameters(recurse=True):
if not param.requires_grad:
continue
if "layernorm" in name.lower() and "weight" in name:
# LayerNorm weights usually initialized to 1
param.data.fill_(1.0)
elif "bias" in name:
param.data.zero_()
else:
try:
param.data.normal_(mean=0.0, std=std)
except Exception as e:
print(f"Skipping init for {name} due to error: {e}")

def compute_mup_vector(config):

@ -922,20 +922,19 @@ class FalconH1PreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Module):
for name, param in module.named_parameters(recurse=True):
if not param.requires_grad:
continue
if "layernorm" in name.lower() and "weight" in name:
# LayerNorm weights usually initialized to 1
param.data.fill_(1.0)
elif "bias" in name:
param.data.zero_()
else:
try:
param.data.normal_(mean=0.0, std=std)
except Exception as e:
print(f"Skipping init for {name} due to error: {e}")
for name, param in module.named_parameters(recurse=True):
if not param.requires_grad:
continue
if "layernorm" in name.lower() and "weight" in name:
# LayerNorm weights usually initialized to 1
param.data.fill_(1.0)
elif "bias" in name:
param.data.zero_()
else:
try:
param.data.normal_(mean=0.0, std=std)
except Exception as e:
print(f"Skipping init for {name} due to error: {e}")

def compute_mup_vector(config):
@ -109,7 +109,6 @@ class FlexOlmoConfig(PreTrainedConfig):

model_type = "flex_olmo"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_local_experts": "num_experts"}
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k

@ -142,7 +141,7 @@ class FlexOlmoConfig(PreTrainedConfig):
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = 100257,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
num_experts_per_tok: Optional[int] = 5,

@ -23,7 +23,6 @@ from collections.abc import Callable
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ...activations import ACT2FN
@ -292,77 +291,64 @@ class FlexOlmoAttention(nn.Module):
return attn_output, attn_weights

class FlexOlmoExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class FlexOlmoExperts(nn.ModuleList):
"""
ModuleList of experts.
"""

def __init__(self, config: FlexOlmoConfig):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]

def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states

class FlexOlmoTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
for _ in range(config.num_experts):
self.append(FlexOlmoMLP(config))
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))

def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states

class FlexOlmoSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = FlexOlmoTopKRouter(config)
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
self.experts = FlexOlmoExperts(config)

def route_tokens_to_experts(self, hidden_states, router_logits):
routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1)
top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True)
top_k_weights = top_k_weights.to(hidden_states.dtype)
return top_k_index, top_k_weights

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
top_k_weights, top_k_index = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
batch_size, sequence_length, hidden_dim
)
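The FlexOlmo block above follows the same shape: a bias-free `nn.Linear` gate plus a `route_tokens_to_experts` helper that applies softmax, top-k and, when `norm_topk_prob` is set, re-normalization of the kept weights. A tiny worked example of that re-normalization, with invented numbers:

```python
import torch

# Softmax probabilities for one token over 4 experts (invented values).
probs = torch.tensor([0.50, 0.30, 0.15, 0.05])
top_k_weights, top_k_index = torch.topk(probs, k=2)  # keeps experts [0, 1] with weights [0.50, 0.30]

# With norm_topk_prob=True the kept weights are rescaled to sum to 1:
normalized = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
print(top_k_index.tolist(), normalized.tolist())  # experts [0, 1], weights ~[0.625, 0.375]
```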
@ -152,7 +152,7 @@ class FlexOlmoConfig(OlmoeConfig):
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = 100257,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
num_experts_per_tok: Optional[int] = 5,

@ -118,7 +118,7 @@ class FuyuConfig(PreTrainedConfig):
layer_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
qk_layernorm: Optional[bool] = True,
hidden_dropout: Optional[float] = 0.0,
attention_dropout: Optional[float] = 0.0,

@ -131,7 +131,7 @@ class GemmaConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
bos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
use_bidirectional_attention: Optional[bool] = None,

@ -158,7 +158,7 @@ class GemmaConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
bos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
use_bidirectional_attention: Optional[bool] = None,

@ -142,7 +142,7 @@ class Gemma2Config(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
bos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
query_pre_attn_scalar: Optional[int] = 256,

@ -171,7 +171,7 @@ class Gemma2Config(PreTrainedConfig):
eos_token_id: Optional[int] = 1,
bos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
query_pre_attn_scalar: Optional[int] = 256,

@ -177,7 +177,7 @@ class Gemma3nTextConfig(PreTrainedConfig):
pad_token_id: int = 0,
eos_token_id: int = 1,
bos_token_id: int = 2,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: bool = False,
attention_dropout: float = 0.0,
sliding_window: int = 512,

@ -187,7 +187,7 @@ class Gemma3nTextConfig(Gemma2Config, PreTrainedConfig):
pad_token_id: int = 0,
eos_token_id: int = 1,
bos_token_id: int = 2,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: bool = False,
attention_dropout: float = 0.0,
sliding_window: int = 512,

@ -121,7 +121,7 @@ class GlmConfig(PreTrainedConfig):
rms_norm_eps: Optional[float] = 0.00000015625,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
pad_token_id: Optional[int] = 151329,
eos_token_id: Optional[list[int]] = [151329, 151336, 151338],
bos_token_id: Optional[int] = None,

@ -122,7 +122,7 @@ class Glm4Config(PreTrainedConfig):
rms_norm_eps: Optional[float] = 0.00000015625,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
pad_token_id: Optional[int] = 151329,
eos_token_id: Optional[list[int]] = [151329, 151336, 151338],
bos_token_id: Optional[int] = None,

@ -152,7 +152,7 @@ class Glm4MoeConfig(PreTrainedConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
moe_intermediate_size: Optional[int] = 1408,
@ -330,43 +330,37 @@ class Glm4MoeRMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

class Glm4MoeNaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Glm4MoeNaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""

def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size))

def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
@ -166,7 +166,7 @@ class Glm4MoeConfig(PreTrainedConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
moe_intermediate_size: Optional[int] = 1408,

@ -220,7 +220,7 @@ class Glm4vTextConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
image_token_id: Optional[int] = None,
video_token_id: Optional[int] = None,
**kwargs,

@ -257,7 +257,7 @@ class Glm4vTextConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
image_token_id: Optional[int] = None,
video_token_id: Optional[int] = None,
**kwargs,

@ -242,7 +242,7 @@ class Glm4vMoeTextConfig(PreTrainedConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = True,
attention_dropout: Optional[float] = 0.0,
moe_intermediate_size: Optional[int] = 1408,
@ -351,43 +351,37 @@ class Glm4vMoeTextTopkRouter(nn.Module):
return router_logits

class Glm4vMoeTextNaiveMoe(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
class Glm4vMoeTextNaiveMoe(nn.ModuleList):
"""
ModuleList of experts.
"""

def __init__(self, config):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]
for _ in range(self.num_experts):
self.append(Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size))

def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0)

expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
return final_hidden_states
@ -183,7 +183,7 @@ class Glm4vMoeTextConfig(Glm4MoeConfig):
rms_norm_eps: Optional[int] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = True,
attention_dropout: Optional[float] = 0.0,
moe_intermediate_size: Optional[int] = 1408,

@ -503,11 +503,10 @@ class GPT2PreTrainedModel(PreTrainedModel):
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
if isinstance(module, PreTrainedModel):
for name, p in module.named_parameters():
if name == "c_proj.weight":
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
for name, p in module.named_parameters():
if name == "c_proj.weight":
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))

@dataclass

@ -131,7 +131,7 @@ class GPTNeoXConfig(PreTrainedConfig):
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
use_parallel_residual: Optional[bool] = True,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = True,
**kwargs,
):
@ -100,7 +100,7 @@ class GPTNeoXJapaneseConfig(PreTrainedConfig):
use_cache: Optional[bool] = True,
bos_token_id: Optional[int] = 31996,
eos_token_id: Optional[int] = 31999,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_dropout: Optional[float] = 0.1,
hidden_dropout: Optional[float] = 0.0,
**kwargs,

@ -71,10 +71,10 @@ class GptOssExperts(nn.Module):
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size))
self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
self.alpha = 1.702
self.limit = 7.0

@ -146,8 +146,8 @@ class GptOssTopKRouter(nn.Module):
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
self.bias = nn.Parameter(torch.zeros(self.num_experts))
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
self.bias = nn.Parameter(torch.empty(self.num_experts))

def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)

@ -69,10 +69,10 @@ class GptOssExperts(nn.Module):
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size))
self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
self.alpha = 1.702
self.limit = 7.0

@ -144,8 +144,8 @@ class GptOssTopKRouter(nn.Module):
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
self.bias = nn.Parameter(torch.zeros(self.num_experts))
self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
self.bias = nn.Parameter(torch.empty(self.num_experts))

def forward(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
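On the `torch.zeros` to `torch.empty` switch in the GPT-OSS experts and router above: both allocate parameters of the right shape, but `empty` skips the zero fill, presumably because the values are always overwritten later (by weight init or by loading a checkpoint). A minimal sketch of that pattern, not the library code:

```python
import torch
from torch import nn


class TinyRouter(nn.Module):
    def __init__(self, num_experts: int = 4, hidden_dim: int = 8):
        super().__init__()
        # Uninitialized storage: contents are arbitrary until explicitly initialized or loaded.
        self.weight = nn.Parameter(torch.empty(num_experts, hidden_dim))
        self.bias = nn.Parameter(torch.empty(num_experts))
        self.reset_parameters()

    def reset_parameters(self):
        # Stand-in for a model's _init_weights; any deliberate init overwrites the empty buffers.
        nn.init.normal_(self.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.bias)


router = TinyRouter()
print(router.weight.std().item(), router.bias.abs().max().item())
```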
@ -141,7 +141,7 @@ class GraniteConfig(PreTrainedConfig):
bos_token_id: Optional[int] = 1,
eos_token_id: Optional[int] = 2,
tie_word_embeddings: Optional[bool] = False,
rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
attention_bias: Optional[bool] = False,
attention_dropout: Optional[float] = 0.0,
mlp_bias: Optional[bool] = False,
Some files were not shown because too many files have changed in this diff.